From af2a5544875b23b3b62fb6d4f3bf432828720008 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 8 Oct 2015 12:42:10 -0700 Subject: [PATCH 001/139] [SPARK-10337] [SQL] fix hive views on non-hive-compatible tables. add a new config to deal with this special case. Author: Wenchen Fan Closes #8990 from cloud-fan/view-master. --- .../scala/org/apache/spark/sql/SQLConf.scala | 15 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 23 +++ .../org/apache/spark/sql/hive/HiveQl.scala | 164 +++++++++++++++--- .../sql/hive/client/ClientInterface.scala | 13 +- .../spark/sql/hive/client/ClientWrapper.scala | 31 ++++ .../hive/execution/CreateViewAsSelect.scala | 97 +++++++++++ .../sql/hive/execution/SQLQuerySuite.scala | 117 +++++++++++++ 7 files changed, 433 insertions(+), 27 deletions(-) create mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index e7bbc7d5db493..8f0f8910b36ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -319,6 +319,15 @@ private[spark] object SQLConf { doc = "When true, some predicates will be pushed down into the Hive metastore so that " + "unmatching partitions can be eliminated earlier.") + val CANONICALIZE_VIEW = booleanConf("spark.sql.canonicalizeView", + defaultValue = Some(false), + doc = "When true, CREATE VIEW will be handled by Spark SQL instead of Hive native commands. " + + "Note that this function is experimental and should ony be used when you are using " + + "non-hive-compatible tables written by Spark SQL. The SQL string used to create " + + "view should be fully qualified, i.e. use `tbl1`.`col1` instead of `*` whenever " + + "possible, or you may get wrong result.", + isPublic = false) + val COLUMN_NAME_OF_CORRUPT_RECORD = stringConf("spark.sql.columnNameOfCorruptRecord", defaultValue = Some("_corrupt_record"), doc = "") @@ -362,7 +371,7 @@ private[spark] object SQLConf { val PARTITION_DISCOVERY_ENABLED = booleanConf("spark.sql.sources.partitionDiscovery.enabled", defaultValue = Some(true), - doc = "When true, automtically discover data partitions.") + doc = "When true, automatically discover data partitions.") val PARTITION_COLUMN_TYPE_INFERENCE = booleanConf("spark.sql.sources.partitionColumnTypeInference.enabled", @@ -372,7 +381,7 @@ private[spark] object SQLConf { val PARTITION_MAX_FILES = intConf("spark.sql.sources.maxConcurrentWrites", defaultValue = Some(5), - doc = "The maximum number of concurent files to open before falling back on sorting when " + + doc = "The maximum number of concurrent files to open before falling back on sorting when " + "writing out files using dynamic partitioning.") // The output committer class used by HadoopFsRelation. The specified class needs to be a @@ -471,6 +480,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def metastorePartitionPruning: Boolean = getConf(HIVE_METASTORE_PARTITION_PRUNING) + private[spark] def canonicalizeView: Boolean = getConf(CANONICALIZE_VIEW) + private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN) private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED, getConf(TUNGSTEN_ENABLED)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index ea1521a48c8a7..cf59bc0d590b0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -41,6 +41,7 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, ResolvedDataSource} import org.apache.spark.sql.execution.{FileRelation, datasources} import org.apache.spark.sql.hive.client._ +import org.apache.spark.sql.hive.execution.HiveNativeCommand import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.sql.{AnalysisException, SQLContext, SaveMode} @@ -588,6 +589,28 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive // Wait until children are resolved. case p: LogicalPlan if !p.childrenResolved => p case p: LogicalPlan if p.resolved => p + + case CreateViewAsSelect(table, child, allowExisting, replace, sql) => + if (conf.canonicalizeView) { + if (allowExisting && replace) { + throw new AnalysisException( + "It is not allowed to define a view with both IF NOT EXISTS and OR REPLACE.") + } + + val (dbName, tblName) = processDatabaseAndTableName( + table.specifiedDatabase.getOrElse(client.currentDatabase), table.name) + + execution.CreateViewAsSelect( + table.copy( + specifiedDatabase = Some(dbName), + name = tblName), + child.output, + allowExisting, + replace) + } else { + HiveNativeCommand(sql) + } + case p @ CreateTableAsSelect(table, child, allowExisting) => val schema = if (table.schema.nonEmpty) { table.schema diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 256440a9a2e97..2bf22f5449641 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -77,6 +77,16 @@ private[hive] case class CreateTableAsSelect( childrenResolved } +private[hive] case class CreateViewAsSelect( + tableDesc: HiveTable, + child: LogicalPlan, + allowExisting: Boolean, + replace: Boolean, + sql: String) extends UnaryNode with Command { + override def output: Seq[Attribute] = Seq.empty[Attribute] + override lazy val resolved: Boolean = false +} + /** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. */ private[hive] object HiveQl extends Logging { protected val nativeCommands = Seq( @@ -99,7 +109,6 @@ private[hive] object HiveQl extends Logging { "TOK_ALTERTABLE_SKEWED", "TOK_ALTERTABLE_TOUCH", "TOK_ALTERTABLE_UNARCHIVE", - "TOK_ALTERVIEW", "TOK_ALTERVIEW_ADDPARTS", "TOK_ALTERVIEW_AS", "TOK_ALTERVIEW_DROPPARTS", @@ -110,7 +119,6 @@ private[hive] object HiveQl extends Logging { "TOK_CREATEFUNCTION", "TOK_CREATEINDEX", "TOK_CREATEROLE", - "TOK_CREATEVIEW", "TOK_DESCDATABASE", "TOK_DESCFUNCTION", @@ -254,12 +262,17 @@ private[hive] object HiveQl extends Logging { * Otherwise, there will be Null pointer exception, * when retrieving properties form HiveConf. */ - val hContext = new Context(SessionState.get().getConf()) - val node = ParseUtils.findRootNonNullToken((new ParseDriver).parse(sql, hContext)) + val hContext = createContext() + val node = getAst(sql, hContext) hContext.clear() node } + private def createContext(): Context = new Context(SessionState.get().getConf()) + + private def getAst(sql: String, context: Context) = + ParseUtils.findRootNonNullToken((new ParseDriver).parse(sql, context)) + /** * Returns the HiveConf */ @@ -280,15 +293,18 @@ private[hive] object HiveQl extends Logging { /** Creates LogicalPlan for a given HiveQL string. */ def createPlan(sql: String): LogicalPlan = { try { - val tree = getAst(sql) - if (nativeCommands contains tree.getText) { + val context = createContext() + val tree = getAst(sql, context) + val plan = if (nativeCommands contains tree.getText) { HiveNativeCommand(sql) } else { - nodeToPlan(tree) match { + nodeToPlan(tree, context) match { case NativePlaceholder => HiveNativeCommand(sql) case other => other } } + context.clear() + plan } catch { case pe: org.apache.hadoop.hive.ql.parse.ParseException => pe.getMessage match { @@ -342,7 +358,9 @@ private[hive] object HiveQl extends Logging { } } - protected def getClauses(clauseNames: Seq[String], nodeList: Seq[ASTNode]): Seq[Option[Node]] = { + protected def getClauses( + clauseNames: Seq[String], + nodeList: Seq[ASTNode]): Seq[Option[ASTNode]] = { var remainingNodes = nodeList val clauses = clauseNames.map { clauseName => val (matches, nonMatches) = remainingNodes.partition(_.getText.toUpperCase == clauseName) @@ -489,7 +507,43 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } } - protected def nodeToPlan(node: Node): LogicalPlan = node match { + private def createView( + view: ASTNode, + context: Context, + viewNameParts: ASTNode, + query: ASTNode, + schema: Seq[HiveColumn], + properties: Map[String, String], + allowExist: Boolean, + replace: Boolean): CreateViewAsSelect = { + val (db, viewName) = extractDbNameTableName(viewNameParts) + + val originalText = context.getTokenRewriteStream + .toString(query.getTokenStartIndex, query.getTokenStopIndex) + + val tableDesc = HiveTable( + specifiedDatabase = db, + name = viewName, + schema = schema, + partitionColumns = Seq.empty[HiveColumn], + properties = properties, + serdeProperties = Map[String, String](), + tableType = VirtualView, + location = None, + inputFormat = None, + outputFormat = None, + serde = None, + viewText = Some(originalText)) + + // We need to keep the original SQL string so that if `spark.sql.canonicalizeView` is + // false, we can fall back to use hive native command later. + // We can remove this when parser is configurable(can access SQLConf) in the future. + val sql = context.getTokenRewriteStream + .toString(view.getTokenStartIndex, view.getTokenStopIndex) + CreateViewAsSelect(tableDesc, nodeToPlan(query, context), allowExist, replace, sql) + } + + protected def nodeToPlan(node: ASTNode, context: Context): LogicalPlan = node match { // Special drop table that also uncaches. case Token("TOK_DROPTABLE", Token("TOK_TABNAME", tableNameParts) :: @@ -521,14 +575,14 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val Some(crtTbl) :: _ :: extended :: Nil = getClauses(Seq("TOK_CREATETABLE", "FORMATTED", "EXTENDED"), explainArgs) ExplainCommand( - nodeToPlan(crtTbl), + nodeToPlan(crtTbl, context), extended = extended.isDefined) case Token("TOK_EXPLAIN", explainArgs) => // Ignore FORMATTED if present. val Some(query) :: _ :: extended :: Nil = getClauses(Seq("TOK_QUERY", "FORMATTED", "EXTENDED"), explainArgs) ExplainCommand( - nodeToPlan(query), + nodeToPlan(query, context), extended = extended.isDefined) case Token("TOK_DESCTABLE", describeArgs) => @@ -563,6 +617,73 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } } + case view @ Token("TOK_ALTERVIEW", children) => + val Some(viewNameParts) :: maybeQuery :: ignores = + getClauses(Seq( + "TOK_TABNAME", + "TOK_QUERY", + "TOK_ALTERVIEW_ADDPARTS", + "TOK_ALTERVIEW_DROPPARTS", + "TOK_ALTERVIEW_PROPERTIES", + "TOK_ALTERVIEW_RENAME"), children) + + // if ALTER VIEW doesn't have query part, let hive to handle it. + maybeQuery.map { query => + createView(view, context, viewNameParts, query, Nil, Map(), false, true) + }.getOrElse(NativePlaceholder) + + case view @ Token("TOK_CREATEVIEW", children) + if children.collect { case t @ Token("TOK_QUERY", _) => t }.nonEmpty => + val Seq( + Some(viewNameParts), + Some(query), + maybeComment, + replace, + allowExisting, + maybeProperties, + maybeColumns, + maybePartCols + ) = getClauses(Seq( + "TOK_TABNAME", + "TOK_QUERY", + "TOK_TABLECOMMENT", + "TOK_ORREPLACE", + "TOK_IFNOTEXISTS", + "TOK_TABLEPROPERTIES", + "TOK_TABCOLNAME", + "TOK_VIEWPARTCOLS"), children) + + // If the view is partitioned, we let hive handle it. + if (maybePartCols.isDefined) { + NativePlaceholder + } else { + val schema = maybeColumns.map { cols => + BaseSemanticAnalyzer.getColumns(cols, true).asScala.map { field => + // We can't specify column types when create view, so fill it with null first, and + // update it after the schema has been resolved later. + HiveColumn(field.getName, null, field.getComment) + } + }.getOrElse(Seq.empty[HiveColumn]) + + val properties = scala.collection.mutable.Map.empty[String, String] + + maybeProperties.foreach { + case Token("TOK_TABLEPROPERTIES", list :: Nil) => + properties ++= getProperties(list) + } + + maybeComment.foreach { + case Token("TOK_TABLECOMMENT", child :: Nil) => + val comment = BaseSemanticAnalyzer.unescapeSQLString(child.getText) + if (comment ne null) { + properties += ("comment" -> comment) + } + } + + createView(view, context, viewNameParts, query, schema, properties.toMap, + allowExisting.isDefined, replace.isDefined) + } + case Token("TOK_CREATETABLE", children) if children.collect { case t @ Token("TOK_QUERY", _) => t }.nonEmpty => // Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL @@ -774,7 +895,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case _ => // Unsupport features } - CreateTableAsSelect(tableDesc, nodeToPlan(query), allowExisting != None) + CreateTableAsSelect(tableDesc, nodeToPlan(query, context), allowExisting != None) // If its not a "CTAS" like above then take it as a native command case Token("TOK_CREATETABLE", _) => NativePlaceholder @@ -793,7 +914,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C insertClauses.last match { case Token("TOK_CTE", cteClauses) => val cteRelations = cteClauses.map(node => { - val relation = nodeToRelation(node).asInstanceOf[Subquery] + val relation = nodeToRelation(node, context).asInstanceOf[Subquery] (relation.alias, relation) }).toMap (Some(args.head), insertClauses.init, Some(cteRelations)) @@ -847,7 +968,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } val relations = fromClause match { - case Some(f) => nodeToRelation(f) + case Some(f) => nodeToRelation(f, context) case None => OneRowRelation } @@ -1094,7 +1215,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C cteRelations.map(With(query, _)).getOrElse(query) // HIVE-9039 renamed TOK_UNION => TOK_UNIONALL while adding TOK_UNIONDISTINCT - case Token("TOK_UNIONALL", left :: right :: Nil) => Union(nodeToPlan(left), nodeToPlan(right)) + case Token("TOK_UNIONALL", left :: right :: Nil) => + Union(nodeToPlan(left, context), nodeToPlan(right, context)) case a: ASTNode => throw new NotImplementedError(s"No parse rules for $node:\n ${dumpTree(a).toString} ") @@ -1102,10 +1224,10 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val allJoinTokens = "(TOK_.*JOIN)".r val laterViewToken = "TOK_LATERAL_VIEW(.*)".r - def nodeToRelation(node: Node): LogicalPlan = node match { + def nodeToRelation(node: Node, context: Context): LogicalPlan = node match { case Token("TOK_SUBQUERY", query :: Token(alias, Nil) :: Nil) => - Subquery(cleanIdentifier(alias), nodeToPlan(query)) + Subquery(cleanIdentifier(alias), nodeToPlan(query, context)) case Token(laterViewToken(isOuter), selectClause :: relationClause :: Nil) => val Token("TOK_SELECT", @@ -1121,7 +1243,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C outer = isOuter.nonEmpty, Some(alias.toLowerCase), attributes.map(UnresolvedAttribute(_)), - nodeToRelation(relationClause)) + nodeToRelation(relationClause, context)) /* All relations, possibly with aliases or sampling clauses. */ case Token("TOK_TABREF", clauses) => @@ -1189,7 +1311,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C }.map(_._2) val isPreserved = tableOrdinals.map(i => (i - 1 < 0) || joinArgs(i - 1).getText == "PRESERVE") - val tables = tableOrdinals.map(i => nodeToRelation(joinArgs(i))) + val tables = tableOrdinals.map(i => nodeToRelation(joinArgs(i), context)) val joinExpressions = tableOrdinals.map(i => joinArgs(i + 1).getChildren.asScala.map(nodeToExpr)) @@ -1244,8 +1366,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case "TOK_FULLOUTERJOIN" => FullOuter case "TOK_LEFTSEMIJOIN" => LeftSemi } - Join(nodeToRelation(relation1), - nodeToRelation(relation2), + Join(nodeToRelation(relation1, context), + nodeToRelation(relation2, context), joinType, other.headOption.map(nodeToExpr)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala index 3811c152a7ae6..915eae9d21e23 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala @@ -19,13 +19,12 @@ package org.apache.spark.sql.hive.client import java.io.PrintStream import java.util.{Map => JMap} +import javax.annotation.Nullable import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchTableException} import org.apache.spark.sql.catalyst.expressions.Expression -private[hive] case class HiveDatabase( - name: String, - location: String) +private[hive] case class HiveDatabase(name: String, location: String) private[hive] abstract class TableType { val name: String } private[hive] case object ExternalTable extends TableType { override val name = "EXTERNAL_TABLE" } @@ -45,7 +44,7 @@ private[hive] case class HivePartition( values: Seq[String], storage: HiveStorageDescriptor) -private[hive] case class HiveColumn(name: String, hiveType: String, comment: String) +private[hive] case class HiveColumn(name: String, @Nullable hiveType: String, comment: String) private[hive] case class HiveTable( specifiedDatabase: Option[String], name: String, @@ -126,6 +125,12 @@ private[hive] trait ClientInterface { /** Returns the metadata for the specified table or None if it doens't exist. */ def getTableOption(dbName: String, tableName: String): Option[HiveTable] + /** Creates a view with the given metadata. */ + def createView(view: HiveTable): Unit + + /** Updates the given view with new metadata. */ + def alertView(view: HiveTable): Unit + /** Creates a table with the given metadata. */ def createTable(table: HiveTable): Unit diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index 4d1e3ed9198e6..8f6d448b2aef4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -354,6 +354,37 @@ private[hive] class ClientWrapper( qlTable } + private def toViewTable(view: HiveTable): metadata.Table = { + // TODO: this is duplicated with `toQlTable` except the table type stuff. + val tbl = new metadata.Table(view.database, view.name) + tbl.setTableType(HTableType.VIRTUAL_VIEW) + tbl.setSerializationLib(null) + tbl.clearSerDeInfo() + + // TODO: we will save the same SQL string to original and expanded text, which is different + // from Hive. + tbl.setViewOriginalText(view.viewText.get) + tbl.setViewExpandedText(view.viewText.get) + + tbl.setFields(view.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment)).asJava) + view.properties.foreach { case (k, v) => tbl.setProperty(k, v) } + + // set owner + tbl.setOwner(conf.getUser) + // set create time + tbl.setCreateTime((System.currentTimeMillis() / 1000).asInstanceOf[Int]) + + tbl + } + + override def createView(view: HiveTable): Unit = withHiveState { + client.createTable(toViewTable(view)) + } + + override def alertView(view: HiveTable): Unit = withHiveState { + client.alterTable(view.qualifiedName, toViewTable(view)) + } + override def createTable(table: HiveTable): Unit = withHiveState { val qlTable = toQlTable(table) client.createTable(qlTable) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala new file mode 100644 index 0000000000000..2b504ac974f07 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.hive.{HiveMetastoreTypes, HiveContext} +import org.apache.spark.sql.{AnalysisException, Row, SQLContext} +import org.apache.spark.sql.execution.RunnableCommand +import org.apache.spark.sql.hive.client.{HiveColumn, HiveTable} + +/** + * Create Hive view on non-hive-compatible tables by specifying schema ourselves instead of + * depending on Hive meta-store. + */ +// TODO: Note that this class can NOT canonicalize the view SQL string entirely, which is different +// from Hive and may not work for some cases like create view on self join. +private[hive] case class CreateViewAsSelect( + tableDesc: HiveTable, + childSchema: Seq[Attribute], + allowExisting: Boolean, + orReplace: Boolean) extends RunnableCommand { + + assert(tableDesc.schema == Nil || tableDesc.schema.length == childSchema.length) + assert(tableDesc.viewText.isDefined) + + override def run(sqlContext: SQLContext): Seq[Row] = { + val hiveContext = sqlContext.asInstanceOf[HiveContext] + val database = tableDesc.database + val viewName = tableDesc.name + + if (hiveContext.catalog.tableExists(Seq(database, viewName))) { + if (allowExisting) { + // view already exists, will do nothing, to keep consistent with Hive + } else if (orReplace) { + hiveContext.catalog.client.alertView(prepareTable()) + } else { + throw new AnalysisException(s"View $database.$viewName already exists. " + + "If you want to update the view definition, please use ALTER VIEW AS or " + + "CREATE OR REPLACE VIEW AS") + } + } else { + hiveContext.catalog.client.createView(prepareTable()) + } + + Seq.empty[Row] + } + + private def prepareTable(): HiveTable = { + // setup column types according to the schema of child. + val schema = if (tableDesc.schema == Nil) { + childSchema.map { attr => + HiveColumn(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), null) + } + } else { + childSchema.zip(tableDesc.schema).map { case (attr, col) => + HiveColumn(col.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), col.comment) + } + } + + val columnNames = childSchema.map(f => verbose(f.name)) + + // When user specified column names for view, we should create a project to do the renaming. + // When no column name specified, we still need to create a project to declare the columns + // we need, to make us more robust to top level `*`s. + val projectList = if (tableDesc.schema == Nil) { + columnNames.mkString(", ") + } else { + columnNames.zip(tableDesc.schema.map(f => verbose(f.name))).map { + case (name, alias) => s"$name AS $alias" + }.mkString(", ") + } + + val viewName = verbose(tableDesc.name) + + val expandedText = s"SELECT $projectList FROM (${tableDesc.viewText.get}) $viewName" + + tableDesc.copy(schema = schema, viewText = Some(expandedText)) + } + + // escape backtick with double-backtick in column name and wrap it with backtick. + private def verbose(name: String) = s"`${name.replaceAll("`", "``")}`" +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 8c3f9ac202637..ec5b83b98e401 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1248,4 +1248,121 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { """.stripMargin), Row("b", 6.0) :: Row("a", 7.0) :: Nil) } } + + test("correctly parse CREATE VIEW statement") { + withSQLConf(SQLConf.CANONICALIZE_VIEW.key -> "true") { + withTable("jt") { + val df = (1 until 10).map(i => i -> i).toDF("i", "j") + df.write.format("json").saveAsTable("jt") + sql( + """CREATE VIEW IF NOT EXISTS + |default.testView (c1 COMMENT 'blabla', c2 COMMENT 'blabla') + |COMMENT 'blabla' + |TBLPROPERTIES ('a' = 'b') + |AS SELECT * FROM jt""".stripMargin) + checkAnswer(sql("SELECT c1, c2 FROM testView ORDER BY c1"), (1 to 9).map(i => Row(i, i))) + sql("DROP VIEW testView") + } + } + } + + test("correctly handle CREATE VIEW IF NOT EXISTS") { + withSQLConf(SQLConf.CANONICALIZE_VIEW.key -> "true") { + withTable("jt", "jt2") { + sqlContext.range(1, 10).write.format("json").saveAsTable("jt") + sql("CREATE VIEW testView AS SELECT id FROM jt") + + val df = (1 until 10).map(i => i -> i).toDF("i", "j") + df.write.format("json").saveAsTable("jt2") + sql("CREATE VIEW IF NOT EXISTS testView AS SELECT * FROM jt2") + + // make sure our view doesn't change. + checkAnswer(sql("SELECT * FROM testView ORDER BY id"), (1 to 9).map(i => Row(i))) + sql("DROP VIEW testView") + } + } + } + + test("correctly handle CREATE OR REPLACE VIEW") { + withSQLConf(SQLConf.CANONICALIZE_VIEW.key -> "true") { + withTable("jt", "jt2") { + sqlContext.range(1, 10).write.format("json").saveAsTable("jt") + sql("CREATE OR REPLACE VIEW testView AS SELECT id FROM jt") + checkAnswer(sql("SELECT * FROM testView ORDER BY id"), (1 to 9).map(i => Row(i))) + + val df = (1 until 10).map(i => i -> i).toDF("i", "j") + df.write.format("json").saveAsTable("jt2") + sql("CREATE OR REPLACE VIEW testView AS SELECT * FROM jt2") + // make sure the view has been changed. + checkAnswer(sql("SELECT * FROM testView ORDER BY i"), (1 to 9).map(i => Row(i, i))) + + sql("DROP VIEW testView") + + val e = intercept[AnalysisException] { + sql("CREATE OR REPLACE VIEW IF NOT EXISTS testView AS SELECT id FROM jt") + } + assert(e.message.contains("not allowed to define a view")) + } + } + } + + test("correctly handle ALTER VIEW") { + withSQLConf(SQLConf.CANONICALIZE_VIEW.key -> "true") { + withTable("jt", "jt2") { + sqlContext.range(1, 10).write.format("json").saveAsTable("jt") + sql("CREATE VIEW testView AS SELECT id FROM jt") + + val df = (1 until 10).map(i => i -> i).toDF("i", "j") + df.write.format("json").saveAsTable("jt2") + sql("ALTER VIEW testView AS SELECT * FROM jt2") + // make sure the view has been changed. + checkAnswer(sql("SELECT * FROM testView ORDER BY i"), (1 to 9).map(i => Row(i, i))) + + sql("DROP VIEW testView") + } + } + } + + test("create hive view for json table") { + // json table is not hive-compatible, make sure the new flag fix it. + withSQLConf(SQLConf.CANONICALIZE_VIEW.key -> "true") { + withTable("jt") { + sqlContext.range(1, 10).write.format("json").saveAsTable("jt") + sql("CREATE VIEW testView AS SELECT id FROM jt") + checkAnswer(sql("SELECT * FROM testView ORDER BY id"), (1 to 9).map(i => Row(i))) + sql("DROP VIEW testView") + } + } + } + + test("create hive view for partitioned parquet table") { + // partitioned parquet table is not hive-compatible, make sure the new flag fix it. + withSQLConf(SQLConf.CANONICALIZE_VIEW.key -> "true") { + withTable("parTable") { + val df = Seq(1 -> "a").toDF("i", "j") + df.write.format("parquet").partitionBy("i").saveAsTable("parTable") + sql("CREATE VIEW testView AS SELECT i, j FROM parTable") + checkAnswer(sql("SELECT * FROM testView"), Row(1, "a")) + sql("DROP VIEW testView") + } + } + } + + test("create hive view for joined tables") { + // make sure the new flag can handle some complex cases like join and schema change. + withSQLConf(SQLConf.CANONICALIZE_VIEW.key -> "true") { + withTable("jt1", "jt2") { + sqlContext.range(1, 10).toDF("id1").write.format("json").saveAsTable("jt1") + sqlContext.range(1, 10).toDF("id2").write.format("json").saveAsTable("jt2") + sql("CREATE VIEW testView AS SELECT * FROM jt1 JOIN jt2 ON id1 == id2") + checkAnswer(sql("SELECT * FROM testView ORDER BY id1"), (1 to 9).map(i => Row(i, i))) + + val df = (1 until 10).map(i => i -> i).toDF("id1", "newCol") + df.write.format("json").mode(SaveMode.Overwrite).saveAsTable("jt1") + checkAnswer(sql("SELECT * FROM testView ORDER BY id1"), (1 to 9).map(i => Row(i, i))) + + sql("DROP VIEW testView") + } + } + } } From a8226a9f14e81c0b6712a30f1a60276200faebac Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 8 Oct 2015 13:49:10 -0700 Subject: [PATCH 002/139] Revert [SPARK-8654] [SQL] Fix Analysis exception when using NULL IN This reverts commit dcbd58a929be0058b1cfa59b14898c4c428a7680 from #8983 Author: Michael Armbrust Closes #9034 from marmbrus/revert8654. --- .../catalyst/analysis/HiveTypeCoercion.scala | 10 ++------- .../sql/catalyst/analysis/AnalysisSuite.scala | 21 ------------------- 2 files changed, 2 insertions(+), 29 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 7192c931d2e51..87a3845b2d9e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -304,10 +304,7 @@ object HiveTypeCoercion { } /** - * Convert the value and in list expressions to the common operator type - * by looking at all the argument types and finding the closest one that - * all the arguments can be cast to. When no common operator type is found - * an Analysis Exception is raised. + * Convert all expressions in in() list to the left operator type */ object InConversion extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { @@ -315,10 +312,7 @@ object HiveTypeCoercion { case e if !e.childrenResolved => e case i @ In(a, b) if b.exists(_.dataType != a.dataType) => - findWiderCommonType(i.children.map(_.dataType)) match { - case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType))) - case None => i - } + i.makeCopy(Array(a, b.map(Cast(_, a.dataType)))) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 77a4765e7751c..820b336aac759 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -135,25 +135,4 @@ class AnalysisSuite extends AnalysisTest { plan = testRelation.select(CreateStructUnsafe(Seq(a, (a + 1).as("a+1"))).as("col")) checkAnalysis(plan, plan) } - - test("SPARK-8654: invalid CAST in NULL IN(...) expression") { - val plan = Project(Alias(In(Literal(null), Seq(Literal(1), Literal(2))), "a")() :: Nil, - LocalRelation() - ) - assertAnalysisSuccess(plan) - } - - test("SPARK-8654: different types in inlist but can be converted to a commmon type") { - val plan = Project(Alias(In(Literal(null), Seq(Literal(1), Literal(1.2345))), "a")() :: Nil, - LocalRelation() - ) - assertAnalysisSuccess(plan) - } - - test("SPARK-8654: check type compatibility error") { - val plan = Project(Alias(In(Literal(null), Seq(Literal(true), Literal(1))), "a")() :: Nil, - LocalRelation() - ) - assertAnalysisError(plan, Seq("data type mismatch: Arguments must be same type")) - } } From 9e66a53c9955285a85c19f55c3ef62db2e1b868a Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 8 Oct 2015 14:28:14 -0700 Subject: [PATCH 003/139] [SPARK-10993] [SQL] Inital code generated encoder for product types This PR is a first cut at code generating an encoder that takes a Scala `Product` type and converts it directly into the tungsten binary format. This is done through the addition of a new set of expression that can be used to invoke methods on raw JVM objects, extracting fields and converting the result into the required format. These can then be used directly in an `UnsafeProjection` allowing us to leverage the existing encoding logic. According to some simple benchmarks, this can significantly speed up conversion (~4x). However, replacing CatalystConverters is deferred to a later PR to keep this PR at a reasonable size. ```scala case class SomeInts(a: Int, b: Int, c: Int, d: Int, e: Int) val data = SomeInts(1, 2, 3, 4, 5) val encoder = ProductEncoder[SomeInts] val converter = CatalystTypeConverters.createToCatalystConverter(ScalaReflection.schemaFor[SomeInts].dataType) (1 to 5).foreach {iter => benchmark(s"converter $iter") { var i = 100000000 while (i > 0) { val res = converter(data).asInstanceOf[InternalRow] assert(res.getInt(0) == 1) assert(res.getInt(1) == 2) i -= 1 } } benchmark(s"encoder $iter") { var i = 100000000 while (i > 0) { val res = encoder.toRow(data) assert(res.getInt(0) == 1) assert(res.getInt(1) == 2) i -= 1 } } } ``` Results: ``` [info] converter 1: 7170ms [info] encoder 1: 1888ms [info] converter 2: 6763ms [info] encoder 2: 1824ms [info] converter 3: 6912ms [info] encoder 3: 1802ms [info] converter 4: 7131ms [info] encoder 4: 1798ms [info] converter 5: 7350ms [info] encoder 5: 1912ms ``` Author: Michael Armbrust Closes #9019 from marmbrus/productEncoder. --- .../spark/sql/catalyst/ScalaReflection.scala | 238 ++++++++++++- .../spark/sql/catalyst/encoders/Encoder.scala | 44 +++ .../catalyst/encoders/ProductEncoder.scala | 67 ++++ .../expressions/codegen/CodeGenerator.scala | 4 +- .../sql/catalyst/expressions/objects.scala | 334 ++++++++++++++++++ .../spark/sql/types/GenericArrayData.scala | 9 + .../apache/spark/sql/types/ObjectType.scala | 42 +++ .../encoders/ProductEncoderSuite.scala | 174 +++++++++ 8 files changed, 910 insertions(+), 2 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 2442341da106d..8b733f2a0b91f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils import org.apache.spark.sql.catalyst.expressions._ @@ -75,6 +76,242 @@ trait ScalaReflection { */ private def localTypeOf[T: TypeTag]: `Type` = typeTag[T].in(mirror).tpe + /** + * Returns the Spark SQL DataType for a given scala type. Where this is not an exact mapping + * to a native type, an ObjectType is returned. Special handling is also used for Arrays including + * those that hold primitive types. + */ + def dataTypeFor(tpe: `Type`): DataType = tpe match { + case t if t <:< definitions.IntTpe => IntegerType + case t if t <:< definitions.LongTpe => LongType + case t if t <:< definitions.DoubleTpe => DoubleType + case t if t <:< definitions.FloatTpe => FloatType + case t if t <:< definitions.ShortTpe => ShortType + case t if t <:< definitions.ByteTpe => ByteType + case t if t <:< definitions.BooleanTpe => BooleanType + case t if t <:< localTypeOf[Array[Byte]] => BinaryType + case _ => + val className: String = tpe.erasure.typeSymbol.asClass.fullName + className match { + case "scala.Array" => + val TypeRef(_, _, Seq(arrayType)) = tpe + val cls = arrayType match { + case t if t <:< definitions.IntTpe => classOf[Array[Int]] + case t if t <:< definitions.LongTpe => classOf[Array[Long]] + case t if t <:< definitions.DoubleTpe => classOf[Array[Double]] + case t if t <:< definitions.FloatTpe => classOf[Array[Float]] + case t if t <:< definitions.ShortTpe => classOf[Array[Short]] + case t if t <:< definitions.ByteTpe => classOf[Array[Byte]] + case t if t <:< definitions.BooleanTpe => classOf[Array[Boolean]] + case other => + // There is probably a better way to do this, but I couldn't find it... + val elementType = dataTypeFor(other).asInstanceOf[ObjectType].cls + java.lang.reflect.Array.newInstance(elementType, 1).getClass + + } + ObjectType(cls) + case other => ObjectType(Utils.classForName(className)) + } + } + + /** Returns expressions for extracting all the fields from the given type. */ + def extractorsFor[T : TypeTag](inputObject: Expression): Seq[Expression] = { + ScalaReflectionLock.synchronized { + extractorFor(inputObject, typeTag[T].tpe).asInstanceOf[CreateStruct].children + } + } + + /** Helper for extracting internal fields from a case class. */ + protected def extractorFor( + inputObject: Expression, + tpe: `Type`): Expression = ScalaReflectionLock.synchronized { + if (!inputObject.dataType.isInstanceOf[ObjectType]) { + inputObject + } else { + tpe match { + case t if t <:< localTypeOf[Option[_]] => + val TypeRef(_, _, Seq(optType)) = t + optType match { + // For primitive types we must manually unbox the value of the object. + case t if t <:< definitions.IntTpe => + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Integer]), inputObject), + "intValue", + IntegerType) + case t if t <:< definitions.LongTpe => + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Long]), inputObject), + "longValue", + LongType) + case t if t <:< definitions.DoubleTpe => + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Double]), inputObject), + "doubleValue", + DoubleType) + case t if t <:< definitions.FloatTpe => + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Float]), inputObject), + "floatValue", + FloatType) + case t if t <:< definitions.ShortTpe => + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Short]), inputObject), + "shortValue", + ShortType) + case t if t <:< definitions.ByteTpe => + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Byte]), inputObject), + "byteValue", + ByteType) + case t if t <:< definitions.BooleanTpe => + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Boolean]), inputObject), + "booleanValue", + BooleanType) + + // For non-primitives, we can just extract the object from the Option and then recurse. + case other => + val className: String = optType.erasure.typeSymbol.asClass.fullName + val classObj = Utils.classForName(className) + val optionObjectType = ObjectType(classObj) + + val unwrapped = UnwrapOption(optionObjectType, inputObject) + expressions.If( + IsNull(unwrapped), + expressions.Literal.create(null, schemaFor(optType).dataType), + extractorFor(unwrapped, optType)) + } + + case t if t <:< localTypeOf[Product] => + val formalTypeArgs = t.typeSymbol.asClass.typeParams + val TypeRef(_, _, actualTypeArgs) = t + val constructorSymbol = t.member(nme.CONSTRUCTOR) + val params = if (constructorSymbol.isMethod) { + constructorSymbol.asMethod.paramss + } else { + // Find the primary constructor, and use its parameter ordering. + val primaryConstructorSymbol: Option[Symbol] = + constructorSymbol.asTerm.alternatives.find(s => + s.isMethod && s.asMethod.isPrimaryConstructor) + + if (primaryConstructorSymbol.isEmpty) { + sys.error("Internal SQL error: Product object did not have a primary constructor.") + } else { + primaryConstructorSymbol.get.asMethod.paramss + } + } + + CreateStruct(params.head.map { p => + val fieldName = p.name.toString + val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) + val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType)) + extractorFor(fieldValue, fieldType) + }) + + case t if t <:< localTypeOf[Array[_]] => + val TypeRef(_, _, Seq(elementType)) = t + val elementDataType = dataTypeFor(elementType) + val Schema(dataType, nullable) = schemaFor(elementType) + + if (!elementDataType.isInstanceOf[AtomicType]) { + MapObjects(extractorFor(_, elementType), inputObject, elementDataType) + } else { + NewInstance( + classOf[GenericArrayData], + inputObject :: Nil, + dataType = ArrayType(dataType, nullable)) + } + + case t if t <:< localTypeOf[Seq[_]] => + val TypeRef(_, _, Seq(elementType)) = t + val elementDataType = dataTypeFor(elementType) + val Schema(dataType, nullable) = schemaFor(elementType) + + if (!elementDataType.isInstanceOf[AtomicType]) { + MapObjects(extractorFor(_, elementType), inputObject, elementDataType) + } else { + NewInstance( + classOf[GenericArrayData], + inputObject :: Nil, + dataType = ArrayType(dataType, nullable)) + } + + case t if t <:< localTypeOf[Map[_, _]] => + val TypeRef(_, _, Seq(keyType, valueType)) = t + val Schema(keyDataType, _) = schemaFor(keyType) + val Schema(valueDataType, valueNullable) = schemaFor(valueType) + + val rawMap = inputObject + val keys = + NewInstance( + classOf[GenericArrayData], + Invoke(rawMap, "keys", ObjectType(classOf[scala.collection.GenIterable[_]])) :: Nil, + dataType = ObjectType(classOf[ArrayData])) + val values = + NewInstance( + classOf[GenericArrayData], + Invoke(rawMap, "values", ObjectType(classOf[scala.collection.GenIterable[_]])) :: Nil, + dataType = ObjectType(classOf[ArrayData])) + NewInstance( + classOf[ArrayBasedMapData], + keys :: values :: Nil, + dataType = MapType(keyDataType, valueDataType, valueNullable)) + + case t if t <:< localTypeOf[String] => + StaticInvoke( + classOf[UTF8String], + StringType, + "fromString", + inputObject :: Nil) + + case t if t <:< localTypeOf[java.sql.Timestamp] => + StaticInvoke( + DateTimeUtils, + TimestampType, + "fromJavaTimestamp", + inputObject :: Nil) + + case t if t <:< localTypeOf[java.sql.Date] => + StaticInvoke( + DateTimeUtils, + DateType, + "fromJavaDate", + inputObject :: Nil) + case t if t <:< localTypeOf[BigDecimal] => + StaticInvoke( + Decimal, + DecimalType.SYSTEM_DEFAULT, + "apply", + inputObject :: Nil) + + case t if t <:< localTypeOf[java.math.BigDecimal] => + StaticInvoke( + Decimal, + DecimalType.SYSTEM_DEFAULT, + "apply", + inputObject :: Nil) + + case t if t <:< localTypeOf[java.lang.Integer] => + Invoke(inputObject, "intValue", IntegerType) + case t if t <:< localTypeOf[java.lang.Long] => + Invoke(inputObject, "longValue", LongType) + case t if t <:< localTypeOf[java.lang.Double] => + Invoke(inputObject, "doubleValue", DoubleType) + case t if t <:< localTypeOf[java.lang.Float] => + Invoke(inputObject, "floatValue", FloatType) + case t if t <:< localTypeOf[java.lang.Short] => + Invoke(inputObject, "shortValue", ShortType) + case t if t <:< localTypeOf[java.lang.Byte] => + Invoke(inputObject, "byteValue", ByteType) + case t if t <:< localTypeOf[java.lang.Boolean] => + Invoke(inputObject, "booleanValue", BooleanType) + + case other => + throw new UnsupportedOperationException(s"Extractor for type $other is not supported") + } + } + } + /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized { val className: String = tpe.erasure.typeSymbol.asClass.fullName @@ -91,7 +328,6 @@ trait ScalaReflection { case t if t <:< localTypeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t Schema(schemaFor(optType).dataType, nullable = true) - // Need to decide if we actually need a special type here. case t if t <:< localTypeOf[Array[Byte]] => Schema(BinaryType, nullable = true) case t if t <:< localTypeOf[Array[_]] => val TypeRef(_, _, Seq(elementType)) = t diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala new file mode 100644 index 0000000000000..8dacfa9477ee6 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.encoders + +import scala.reflect.ClassTag + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types.StructType + +/** + * Used to convert a JVM object of type `T` to and from the internal Spark SQL representation. + * + * Encoders are not intended to be thread-safe and thus they are allow to avoid internal locking + * and reuse internal buffers to improve performance. + */ +trait Encoder[T] { + /** Returns the schema of encoding this type of object as a Row. */ + def schema: StructType + + /** A ClassTag that can be used to construct and Array to contain a collection of `T`. */ + def clsTag: ClassTag[T] + + /** + * Returns an encoded version of `t` as a Spark SQL row. Note that multiple calls to + * toRow are allowed to return the same actual [[InternalRow]] object. Thus, the caller should + * copy the result before making another call if required. + */ + def toRow(t: T): InternalRow +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala new file mode 100644 index 0000000000000..a23613673ebb5 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.encoders + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection + +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.{typeTag, TypeTag} + +import org.apache.spark.sql.catalyst.{ScalaReflection, InternalRow} +import org.apache.spark.sql.types.{ObjectType, StructType} + +/** + * A factory for constructing encoders that convert Scala's product type to/from the Spark SQL + * internal binary representation. + */ +object ProductEncoder { + def apply[T <: Product : TypeTag]: Encoder[T] = { + // We convert the not-serializable TypeTag into StructType and ClassTag. + val schema = ScalaReflection.schemaFor[T].dataType.asInstanceOf[StructType] + val mirror = typeTag[T].mirror + val cls = mirror.runtimeClass(typeTag[T].tpe) + + val inputObject = BoundReference(0, ObjectType(cls), nullable = true) + val extractExpressions = ScalaReflection.extractorsFor[T](inputObject) + new ClassEncoder[T](schema, extractExpressions, ClassTag[T](cls)) + } +} + +/** + * A generic encoder for JVM objects. + * + * @param schema The schema after converting `T` to a Spark SQL row. + * @param extractExpressions A set of expressions, one for each top-level field that can be used to + * extract the values from a raw object. + * @param clsTag A classtag for `T`. + */ +case class ClassEncoder[T]( + schema: StructType, + extractExpressions: Seq[Expression], + clsTag: ClassTag[T]) + extends Encoder[T] { + + private val extractProjection = GenerateUnsafeProjection.generate(extractExpressions) + private val inputRow = new GenericMutableRow(1) + + override def toRow(t: T): InternalRow = { + inputRow(0) = t + extractProjection(inputRow) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 2dd680454b4cf..a0fe5bd77e3aa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -177,6 +177,8 @@ class CodeGenContext { case _: MapType => "MapData" case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName + case ObjectType(cls) if cls.isArray => s"${javaType(ObjectType(cls.getComponentType))}[]" + case ObjectType(cls) => cls.getName case _ => "Object" } @@ -395,7 +397,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin logDebug({ // Only add extra debugging info to byte code when we are going to print the source code. - evaluator.setDebuggingInformation(false, true, false) + evaluator.setDebuggingInformation(true, true, false) withLineNums }) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala new file mode 100644 index 0000000000000..e1f960a6e605c --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -0,0 +1,334 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import scala.language.existentials + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.types._ + +/** + * Invokes a static function, returning the result. By default, any of the arguments being null + * will result in returning null instead of calling the function. + * + * @param staticObject The target of the static call. This can either be the object itself + * (methods defined on scala objects), or the class object + * (static methods defined in java). + * @param dataType The expected return type of the function call + * @param functionName The name of the method to call. + * @param arguments An optional list of expressions to pass as arguments to the function. + * @param propagateNull When true, and any of the arguments is null, null will be returned instead + * of calling the function. + */ +case class StaticInvoke( + staticObject: Any, + dataType: DataType, + functionName: String, + arguments: Seq[Expression] = Nil, + propagateNull: Boolean = true) extends Expression { + + val objectName = staticObject match { + case c: Class[_] => c.getName + case other => other.getClass.getName.stripSuffix("$") + } + override def nullable: Boolean = true + override def children: Seq[Expression] = Nil + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val javaType = ctx.javaType(dataType) + val argGen = arguments.map(_.gen(ctx)) + val argString = argGen.map(_.value).mkString(", ") + + if (propagateNull) { + val objNullCheck = if (ctx.defaultValue(dataType) == "null") { + s"${ev.isNull} = ${ev.value} == null;" + } else { + "" + } + + val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})" + s""" + ${argGen.map(_.code).mkString("\n")} + + boolean ${ev.isNull} = true; + $javaType ${ev.value} = ${ctx.defaultValue(dataType)}; + + if ($argsNonNull) { + ${ev.value} = $objectName.$functionName($argString); + $objNullCheck + } + """ + } else { + s""" + ${argGen.map(_.code).mkString("\n")} + + final boolean ${ev.isNull} = ${ev.value} == null; + $javaType ${ev.value} = $objectName.$functionName($argString); + """ + } + } +} + +/** + * Calls the specified function on an object, optionally passing arguments. If the `targetObject` + * expression evaluates to null then null will be returned. + * + * @param targetObject An expression that will return the object to call the method on. + * @param functionName The name of the method to call. + * @param dataType The expected return type of the function. + * @param arguments An optional list of expressions, whos evaluation will be passed to the function. + */ +case class Invoke( + targetObject: Expression, + functionName: String, + dataType: DataType, + arguments: Seq[Expression] = Nil) extends Expression { + + override def nullable: Boolean = true + override def children: Seq[Expression] = targetObject :: Nil + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val javaType = ctx.javaType(dataType) + val obj = targetObject.gen(ctx) + val argGen = arguments.map(_.gen(ctx)) + val argString = argGen.map(_.value).mkString(", ") + + // If the function can return null, we do an extra check to make sure our null bit is still set + // correctly. + val objNullCheck = if (ctx.defaultValue(dataType) == "null") { + s"${ev.isNull} = ${ev.value} == null;" + } else { + "" + } + + s""" + ${obj.code} + ${argGen.map(_.code).mkString("\n")} + + boolean ${ev.isNull} = ${obj.value} == null; + $javaType ${ev.value} = + ${ev.isNull} ? + ${ctx.defaultValue(dataType)} : ($javaType) ${obj.value}.$functionName($argString); + $objNullCheck + """ + } +} + +/** + * Constructs a new instance of the given class, using the result of evaluating the specified + * expressions as arguments. + * + * @param cls The class to construct. + * @param arguments A list of expression to use as arguments to the constructor. + * @param propagateNull When true, if any of the arguments is null, then null will be returned + * instead of trying to construct the object. + * @param dataType The type of object being constructed, as a Spark SQL datatype. This allows you + * to manually specify the type when the object in question is a valid internal + * representation (i.e. ArrayData) instead of an object. + */ +case class NewInstance( + cls: Class[_], + arguments: Seq[Expression], + propagateNull: Boolean = true, + dataType: DataType) extends Expression { + private val className = cls.getName + + override def nullable: Boolean = propagateNull + + override def children: Seq[Expression] = arguments + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val javaType = ctx.javaType(dataType) + val argGen = arguments.map(_.gen(ctx)) + val argString = argGen.map(_.value).mkString(", ") + + if (propagateNull) { + val objNullCheck = if (ctx.defaultValue(dataType) == "null") { + s"${ev.isNull} = ${ev.value} == null;" + } else { + "" + } + + val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})" + s""" + ${argGen.map(_.code).mkString("\n")} + + boolean ${ev.isNull} = true; + $javaType ${ev.value} = ${ctx.defaultValue(dataType)}; + + if ($argsNonNull) { + ${ev.value} = new $className($argString); + ${ev.isNull} = false; + } + """ + } else { + s""" + ${argGen.map(_.code).mkString("\n")} + + final boolean ${ev.isNull} = ${ev.value} == null; + $javaType ${ev.value} = new $className($argString); + """ + } + } +} + +/** + * Given an expression that returns on object of type `Option[_]`, this expression unwraps the + * option into the specified Spark SQL datatype. In the case of `None`, the nullbit is set instead. + * + * @param dataType The expected unwrapped option type. + * @param child An expression that returns an `Option` + */ +case class UnwrapOption( + dataType: DataType, + child: Expression) extends UnaryExpression with ExpectsInputTypes { + + override def nullable: Boolean = true + + override def children: Seq[Expression] = Nil + + override def inputTypes: Seq[AbstractDataType] = ObjectType :: Nil + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported") + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val javaType = ctx.javaType(dataType) + val inputObject = child.gen(ctx) + + s""" + ${inputObject.code} + + boolean ${ev.isNull} = ${inputObject.value} == null || ${inputObject.value}.isEmpty(); + $javaType ${ev.value} = + ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($javaType)${inputObject.value}.get(); + """ + } +} + +case class LambdaVariable(value: String, isNull: String, dataType: DataType) extends Expression { + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = + throw new UnsupportedOperationException("Only calling gen() is supported.") + + override def children: Seq[Expression] = Nil + override def gen(ctx: CodeGenContext): GeneratedExpressionCode = + GeneratedExpressionCode(code = "", value = value, isNull = isNull) + + override def nullable: Boolean = false + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + +} + +/** + * Applies the given expression to every element of a collection of items, returning the result + * as an ArrayType. This is similar to a typical map operation, but where the lambda function + * is expressed using catalyst expressions. + * + * The following collection ObjectTypes are currently supported: Seq, Array + * + * @param function A function that returns an expression, given an attribute that can be used + * to access the current value. This is does as a lambda function so that + * a unique attribute reference can be provided for each expression (thus allowing + * us to nest multiple MapObject calls). + * @param inputData An expression that when evaluted returns a collection object. + * @param elementType The type of element in the collection, expressed as a DataType. + */ +case class MapObjects( + function: AttributeReference => Expression, + inputData: Expression, + elementType: DataType) extends Expression { + + private val loopAttribute = AttributeReference("loopVar", elementType)() + private val completeFunction = function(loopAttribute) + + private val (lengthFunction, itemAccessor) = inputData.dataType match { + case ObjectType(cls) if cls.isAssignableFrom(classOf[Seq[_]]) => + (".size()", (i: String) => s".apply($i)") + case ObjectType(cls) if cls.isArray => + (".length", (i: String) => s"[$i]") + } + + override def nullable: Boolean = true + + override def children: Seq[Expression] = completeFunction :: inputData :: Nil + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported") + + override def dataType: DataType = ArrayType(completeFunction.dataType) + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val javaType = ctx.javaType(dataType) + val elementJavaType = ctx.javaType(elementType) + val genInputData = inputData.gen(ctx) + + // Variables to hold the element that is currently being processed. + val loopValue = ctx.freshName("loopValue") + val loopIsNull = ctx.freshName("loopIsNull") + + val loopVariable = LambdaVariable(loopValue, loopIsNull, elementType) + val boundFunction = completeFunction transform { + case a: AttributeReference if a == loopAttribute => loopVariable + } + + val genFunction = boundFunction.gen(ctx) + val dataLength = ctx.freshName("dataLength") + val convertedArray = ctx.freshName("convertedArray") + val loopIndex = ctx.freshName("loopIndex") + + s""" + ${genInputData.code} + + boolean ${ev.isNull} = ${genInputData.value} == null; + $javaType ${ev.value} = ${ctx.defaultValue(dataType)}; + + if (!${ev.isNull}) { + Object[] $convertedArray = null; + int $dataLength = ${genInputData.value}$lengthFunction; + $convertedArray = new Object[$dataLength]; + + int $loopIndex = 0; + while ($loopIndex < $dataLength) { + $elementJavaType $loopValue = + ($elementJavaType)${genInputData.value}${itemAccessor(loopIndex)}; + boolean $loopIsNull = $loopValue == null; + + ${genFunction.code} + + $convertedArray[$loopIndex] = ${genFunction.value}; + $loopIndex += 1; + } + + ${ev.isNull} = false; + ${ev.value} = new ${classOf[GenericArrayData].getName}($convertedArray); + } + """ + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala index 459fcb6fc0acc..c3816033275d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala @@ -22,6 +22,15 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} class GenericArrayData(private[sql] val array: Array[Any]) extends ArrayData { + def this(seq: scala.collection.GenIterable[Any]) = this(seq.toArray) + + // TODO: This is boxing. We should specialize. + def this(primitiveArray: Array[Int]) = this(primitiveArray.toSeq) + def this(primitiveArray: Array[Long]) = this(primitiveArray.toSeq) + def this(primitiveArray: Array[Float]) = this(primitiveArray.toSeq) + def this(primitiveArray: Array[Double]) = this(primitiveArray.toSeq) + def this(primitiveArray: Array[Boolean]) = this(primitiveArray.toSeq) + override def copy(): ArrayData = new GenericArrayData(array.clone()) override def numElements(): Int = array.length diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala new file mode 100644 index 0000000000000..fca0b799eb809 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.language.existentials + +private[sql] object ObjectType extends AbstractDataType { + override private[sql] def defaultConcreteType: DataType = + throw new UnsupportedOperationException("null literals can't be casted to ObjectType") + + // No casting or comparison is supported. + override private[sql] def acceptsType(other: DataType): Boolean = false + + override private[sql] def simpleString: String = "Object" +} + +/** + * Represents a JVM object that is passing through Spark SQL expression evaluation. Note this + * is only used internally while converting into the internal format and is not intended for use + * outside of the execution engine. + */ +private[sql] case class ObjectType(cls: Class[_]) extends DataType { + override def defaultSize: Int = + throw new UnsupportedOperationException("No size estimation available for objects.") + + def asNullable: DataType = this +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala new file mode 100644 index 0000000000000..99c993d3febc2 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.encoders + +import java.sql.{Date, Timestamp} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.ScalaReflection._ +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.catalyst._ + + +case class RepeatedStruct(s: Seq[PrimitiveData]) + +case class NestedArray(a: Array[Array[Int]]) + +class ProductEncoderSuite extends SparkFunSuite { + + test("convert PrimitiveData to InternalRow") { + val inputData = PrimitiveData(1, 1, 1, 1, 1, 1, true) + val encoder = ProductEncoder[PrimitiveData] + val convertedData = encoder.toRow(inputData) + + assert(convertedData.getInt(0) == 1) + assert(convertedData.getLong(1) == 1.toLong) + assert(convertedData.getDouble(2) == 1.toDouble) + assert(convertedData.getFloat(3) == 1.toFloat) + assert(convertedData.getShort(4) == 1.toShort) + assert(convertedData.getByte(5) == 1.toByte) + assert(convertedData.getBoolean(6) == true) + } + + test("convert Some[_] to InternalRow") { + val primitiveData = PrimitiveData(1, 1, 1, 1, 1, 1, true) + val inputData = OptionalData(Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true), + Some(primitiveData)) + + val encoder = ProductEncoder[OptionalData] + val convertedData = encoder.toRow(inputData) + + assert(convertedData.getInt(0) == 2) + assert(convertedData.getLong(1) == 2.toLong) + assert(convertedData.getDouble(2) == 2.toDouble) + assert(convertedData.getFloat(3) == 2.toFloat) + assert(convertedData.getShort(4) == 2.toShort) + assert(convertedData.getByte(5) == 2.toByte) + assert(convertedData.getBoolean(6) == true) + + val nestedRow = convertedData.getStruct(7, 7) + assert(nestedRow.getInt(0) == 1) + assert(nestedRow.getLong(1) == 1.toLong) + assert(nestedRow.getDouble(2) == 1.toDouble) + assert(nestedRow.getFloat(3) == 1.toFloat) + assert(nestedRow.getShort(4) == 1.toShort) + assert(nestedRow.getByte(5) == 1.toByte) + assert(nestedRow.getBoolean(6) == true) + } + + test("convert None to InternalRow") { + val inputData = OptionalData(None, None, None, None, None, None, None, None) + val encoder = ProductEncoder[OptionalData] + val convertedData = encoder.toRow(inputData) + + assert(convertedData.isNullAt(0)) + assert(convertedData.isNullAt(1)) + assert(convertedData.isNullAt(2)) + assert(convertedData.isNullAt(3)) + assert(convertedData.isNullAt(4)) + assert(convertedData.isNullAt(5)) + assert(convertedData.isNullAt(6)) + assert(convertedData.isNullAt(7)) + } + + test("convert nullable but present data to InternalRow") { + val inputData = NullableData( + 1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true, "test", new java.math.BigDecimal(1), new Date(0), + new Timestamp(0), Array[Byte](1, 2, 3)) + + val encoder = ProductEncoder[NullableData] + val convertedData = encoder.toRow(inputData) + + assert(convertedData.getInt(0) == 1) + assert(convertedData.getLong(1) == 1.toLong) + assert(convertedData.getDouble(2) == 1.toDouble) + assert(convertedData.getFloat(3) == 1.toFloat) + assert(convertedData.getShort(4) == 1.toShort) + assert(convertedData.getByte(5) == 1.toByte) + assert(convertedData.getBoolean(6) == true) + } + + test("convert nullable data to InternalRow") { + val inputData = + NullableData(null, null, null, null, null, null, null, null, null, null, null, null) + + val encoder = ProductEncoder[NullableData] + val convertedData = encoder.toRow(inputData) + + assert(convertedData.isNullAt(0)) + assert(convertedData.isNullAt(1)) + assert(convertedData.isNullAt(2)) + assert(convertedData.isNullAt(3)) + assert(convertedData.isNullAt(4)) + assert(convertedData.isNullAt(5)) + assert(convertedData.isNullAt(6)) + assert(convertedData.isNullAt(7)) + assert(convertedData.isNullAt(8)) + assert(convertedData.isNullAt(9)) + assert(convertedData.isNullAt(10)) + assert(convertedData.isNullAt(11)) + } + + test("convert repeated struct") { + val inputData = RepeatedStruct(PrimitiveData(1, 1, 1, 1, 1, 1, true) :: Nil) + val encoder = ProductEncoder[RepeatedStruct] + + val converted = encoder.toRow(inputData) + val convertedStruct = converted.getArray(0).getStruct(0, 7) + assert(convertedStruct.getInt(0) == 1) + assert(convertedStruct.getLong(1) == 1.toLong) + assert(convertedStruct.getDouble(2) == 1.toDouble) + assert(convertedStruct.getFloat(3) == 1.toFloat) + assert(convertedStruct.getShort(4) == 1.toShort) + assert(convertedStruct.getByte(5) == 1.toByte) + assert(convertedStruct.getBoolean(6) == true) + } + + test("convert nested seq") { + val convertedData = ProductEncoder[Tuple1[Seq[Seq[Int]]]].toRow(Tuple1(Seq(Seq(1)))) + assert(convertedData.getArray(0).getArray(0).getInt(0) == 1) + + val convertedData2 = ProductEncoder[Tuple1[Seq[Seq[Seq[Int]]]]].toRow(Tuple1(Seq(Seq(Seq(1))))) + assert(convertedData2.getArray(0).getArray(0).getArray(0).getInt(0) == 1) + } + + test("convert nested array") { + val convertedData = ProductEncoder[Tuple1[Array[Array[Int]]]].toRow(Tuple1(Array(Array(1)))) + } + + test("convert complex") { + val inputData = ComplexData( + Seq(1, 2), + Array(1, 2), + 1 :: 2 :: Nil, + Seq(new Integer(1), null, new Integer(2)), + Map(1 -> 2L), + Map(1 -> new java.lang.Long(2)), + PrimitiveData(1, 1, 1, 1, 1, 1, true), + Array(Array(1))) + + val encoder = ProductEncoder[ComplexData] + val convertedData = encoder.toRow(inputData) + + assert(!convertedData.isNullAt(0)) + val seq = convertedData.getArray(0) + assert(seq.numElements() == 2) + assert(seq.getInt(0) == 1) + assert(seq.getInt(1) == 2) + } +} From 2816c89b6a304cb0b5214e14ebbc320158e88260 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 8 Oct 2015 14:53:21 -0700 Subject: [PATCH 004/139] [SPARK-10988] [SQL] Reduce duplication in Aggregate2's expression rewriting logic In `aggregate/utils.scala`, there is a substantial amount of duplication in the expression-rewriting logic. As a prerequisite to supporting imperative aggregate functions in `TungstenAggregate`, this patch refactors this file so that the same expression-rewriting logic is used for both `SortAggregate` and `TungstenAggregate`. In order to allow both operators to use the same rewriting logic, `TungstenAggregationIterator. generateResultProjection()` has been updated so that it first evaluates all declarative aggregate functions' `evaluateExpression`s and writes the results into a temporary buffer, and then uses this temporary buffer and the grouping expressions to evaluate the final resultExpressions. This matches the logic in SortAggregateIterator, where this two-pass approach is necessary in order to support imperative aggregates. If this change turns out to cause performance regressions, then we can look into re-implementing the single-pass evaluation in a cleaner way as part of a followup patch. Since the rewriting logic is now shared across both operators, this patch also extracts that logic and places it in `SparkStrategies`. This makes the rewriting logic a bit easier to follow, I think. Author: Josh Rosen Closes #9015 from JoshRosen/SPARK-10988. --- .../spark/sql/execution/SparkStrategies.scala | 67 +++-- .../aggregate/TungstenAggregate.scala | 4 + .../TungstenAggregationIterator.scala | 22 +- .../spark/sql/execution/aggregate/utils.scala | 244 +++++------------- .../TungstenAggregationIteratorSuite.scala | 2 +- 5 files changed, 143 insertions(+), 196 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index d1bbf2e20fcf4..79bd1a41808de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -195,19 +195,22 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { converted match { case None => Nil // Cannot convert to new aggregation code path. case Some(logical.Aggregate(groupingExpressions, resultExpressions, child)) => - // Extracts all distinct aggregate expressions from the resultExpressions. + // A single aggregate expression might appear multiple times in resultExpressions. + // In order to avoid evaluating an individual aggregate function multiple times, we'll + // build a set of the distinct aggregate expressions and build a function which can + // be used to re-write expressions so that they reference the single copy of the + // aggregate function which actually gets computed. val aggregateExpressions = resultExpressions.flatMap { expr => expr.collect { case agg: AggregateExpression2 => agg } - }.toSet.toSeq + }.distinct // For those distinct aggregate expressions, we create a map from the // aggregate function to the corresponding attribute of the function. - val aggregateFunctionMap = aggregateExpressions.map { agg => + val aggregateFunctionToAttribute = aggregateExpressions.map { agg => val aggregateFunction = agg.aggregateFunction - val attribtue = Alias(aggregateFunction, aggregateFunction.toString)().toAttribute - (aggregateFunction, agg.isDistinct) -> - (aggregateFunction -> attribtue) + val attribute = Alias(aggregateFunction, aggregateFunction.toString)().toAttribute + (aggregateFunction, agg.isDistinct) -> attribute }.toMap val (functionsWithDistinct, functionsWithoutDistinct) = @@ -220,6 +223,40 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { "code path.") } + val namedGroupingExpressions = groupingExpressions.map { + case ne: NamedExpression => ne -> ne + // If the expression is not a NamedExpressions, we add an alias. + // So, when we generate the result of the operator, the Aggregate Operator + // can directly get the Seq of attributes representing the grouping expressions. + case other => + val withAlias = Alias(other, other.toString)() + other -> withAlias + } + val groupExpressionMap = namedGroupingExpressions.toMap + + // The original `resultExpressions` are a set of expressions which may reference + // aggregate expressions, grouping column values, and constants. When aggregate operator + // emits output rows, we will use `resultExpressions` to generate an output projection + // which takes the grouping columns and final aggregate result buffer as input. + // Thus, we must re-write the result expressions so that their attributes match up with + // the attributes of the final result projection's input row: + val rewrittenResultExpressions = resultExpressions.map { expr => + expr.transformDown { + case AggregateExpression2(aggregateFunction, _, isDistinct) => + // The final aggregation buffer's attributes will be `finalAggregationAttributes`, + // so replace each aggregate expression by its corresponding attribute in the set: + aggregateFunctionToAttribute(aggregateFunction, isDistinct) + case expression => + // Since we're using `namedGroupingAttributes` to extract the grouping key + // columns, we need to replace grouping key expressions with their corresponding + // attributes. We do not rely on the equality check at here since attributes may + // differ cosmetically. Instead, we use semanticEquals. + groupExpressionMap.collectFirst { + case (expr, ne) if expr semanticEquals expression => ne.toAttribute + }.getOrElse(expression) + }.asInstanceOf[NamedExpression] + } + val aggregateOperator = if (aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) { if (functionsWithDistinct.nonEmpty) { @@ -227,26 +264,26 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { "aggregate functions which don't support partial aggregation.") } else { aggregate.Utils.planAggregateWithoutPartial( - groupingExpressions, + namedGroupingExpressions.map(_._2), aggregateExpressions, - aggregateFunctionMap, - resultExpressions, + aggregateFunctionToAttribute, + rewrittenResultExpressions, planLater(child)) } } else if (functionsWithDistinct.isEmpty) { aggregate.Utils.planAggregateWithoutDistinct( - groupingExpressions, + namedGroupingExpressions.map(_._2), aggregateExpressions, - aggregateFunctionMap, - resultExpressions, + aggregateFunctionToAttribute, + rewrittenResultExpressions, planLater(child)) } else { aggregate.Utils.planAggregateWithOneDistinct( - groupingExpressions, + namedGroupingExpressions.map(_._2), functionsWithDistinct, functionsWithoutDistinct, - aggregateFunctionMap, - resultExpressions, + aggregateFunctionToAttribute, + rewrittenResultExpressions, planLater(child)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 3cd22af30592c..7b3d072b2e067 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -31,7 +31,9 @@ case class TungstenAggregate( requiredChildDistributionExpressions: Option[Seq[Expression]], groupingExpressions: Seq[NamedExpression], nonCompleteAggregateExpressions: Seq[AggregateExpression2], + nonCompleteAggregateAttributes: Seq[Attribute], completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateAttributes: Seq[Attribute], resultExpressions: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { @@ -77,7 +79,9 @@ case class TungstenAggregate( new TungstenAggregationIterator( groupingExpressions, nonCompleteAggregateExpressions, + nonCompleteAggregateAttributes, completeAggregateExpressions, + completeAggregateAttributes, resultExpressions, newMutableProjection, child.output, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index a6f4c1d92f6dc..4bb95c9eb7f3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -60,8 +60,12 @@ import org.apache.spark.sql.types.StructType * @param nonCompleteAggregateExpressions * [[AggregateExpression2]] containing [[AggregateFunction2]]s with mode [[Partial]], * [[PartialMerge]], or [[Final]]. + * @param nonCompleteAggregateAttributes the attributes of the nonCompleteAggregateExpressions' + * outputs when they are stored in the final aggregation buffer. * @param completeAggregateExpressions * [[AggregateExpression2]] containing [[AggregateFunction2]]s with mode [[Complete]]. + * @param completeAggregateAttributes the attributes of completeAggregateExpressions' outputs + * when they are stored in the final aggregation buffer. * @param resultExpressions * expressions for generating output rows. * @param newMutableProjection @@ -72,7 +76,9 @@ import org.apache.spark.sql.types.StructType class TungstenAggregationIterator( groupingExpressions: Seq[NamedExpression], nonCompleteAggregateExpressions: Seq[AggregateExpression2], + nonCompleteAggregateAttributes: Seq[Attribute], completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateAttributes: Seq[Attribute], resultExpressions: Seq[NamedExpression], newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), originalInputAttributes: Seq[Attribute], @@ -280,17 +286,25 @@ class TungstenAggregationIterator( // resultExpressions. case (Some(Final), None) | (Some(Final) | None, Some(Complete)) => val joinedRow = new JoinedRow() + val evalExpressions = allAggregateFunctions.map { + case ae: DeclarativeAggregate => ae.evaluateExpression + // case agg: AggregateFunction2 => Literal.create(null, agg.dataType) + } + val expressionAggEvalProjection = UnsafeProjection.create(evalExpressions, bufferAttributes) + // These are the attributes of the row produced by `expressionAggEvalProjection` + val aggregateResultSchema = nonCompleteAggregateAttributes ++ completeAggregateAttributes val resultProjection = - UnsafeProjection.create(resultExpressions, groupingAttributes ++ bufferAttributes) + UnsafeProjection.create(resultExpressions, groupingAttributes ++ aggregateResultSchema) (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => { - resultProjection(joinedRow(currentGroupingKey, currentBuffer)) + // Generate results for all expression-based aggregate functions. + val aggregateResult = expressionAggEvalProjection.apply(currentBuffer) + resultProjection(joinedRow(currentGroupingKey, aggregateResult)) } // Grouping-only: a output row is generated from values of grouping expressions. case (None, None) => - val resultProjection = - UnsafeProjection.create(resultExpressions, groupingAttributes) + val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes) (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => { resultProjection(currentGroupingKey) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala index e1c2d9475a10f..cf6e7ed0d337f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution.aggregate -import scala.collection.mutable - import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan} @@ -38,60 +36,35 @@ object Utils { } def planAggregateWithoutPartial( - groupingExpressions: Seq[Expression], + groupingExpressions: Seq[NamedExpression], aggregateExpressions: Seq[AggregateExpression2], - aggregateFunctionMap: Map[(AggregateFunction2, Boolean), (AggregateFunction2, Attribute)], + aggregateFunctionToAttribute: Map[(AggregateFunction2, Boolean), Attribute], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { - val namedGroupingExpressions = groupingExpressions.map { - case ne: NamedExpression => ne -> ne - // If the expression is not a NamedExpressions, we add an alias. - // So, when we generate the result of the operator, the Aggregate Operator - // can directly get the Seq of attributes representing the grouping expressions. - case other => - val withAlias = Alias(other, other.toString)() - other -> withAlias - } - val groupExpressionMap = namedGroupingExpressions.toMap - val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute) - + val groupingAttributes = groupingExpressions.map(_.toAttribute) val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete)) - val completeAggregateAttributes = - completeAggregateExpressions.map { - expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)._2 - } - - val rewrittenResultExpressions = resultExpressions.map { expr => - expr.transformDown { - case agg: AggregateExpression2 => - aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._2 - case expression => - // We do not rely on the equality check at here since attributes may - // different cosmetically. Instead, we use semanticEquals. - groupExpressionMap.collectFirst { - case (expr, ne) if expr semanticEquals expression => ne.toAttribute - }.getOrElse(expression) - }.asInstanceOf[NamedExpression] + val completeAggregateAttributes = completeAggregateExpressions.map { + expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) } SortBasedAggregate( - requiredChildDistributionExpressions = Some(namedGroupingAttributes), - groupingExpressions = namedGroupingExpressions.map(_._2), + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes, nonCompleteAggregateExpressions = Nil, nonCompleteAggregateAttributes = Nil, completeAggregateExpressions = completeAggregateExpressions, completeAggregateAttributes = completeAggregateAttributes, initialInputBufferOffset = 0, - resultExpressions = rewrittenResultExpressions, + resultExpressions = resultExpressions, child = child ) :: Nil } def planAggregateWithoutDistinct( - groupingExpressions: Seq[Expression], + groupingExpressions: Seq[NamedExpression], aggregateExpressions: Seq[AggregateExpression2], - aggregateFunctionMap: Map[(AggregateFunction2, Boolean), (AggregateFunction2, Attribute)], + aggregateFunctionToAttribute: Map[(AggregateFunction2, Boolean), Attribute], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { // Check if we can use TungstenAggregate. @@ -104,36 +77,29 @@ object Utils { // 1. Create an Aggregate Operator for partial aggregations. - val namedGroupingExpressions = groupingExpressions.map { - case ne: NamedExpression => ne -> ne - // If the expression is not a NamedExpressions, we add an alias. - // So, when we generate the result of the operator, the Aggregate Operator - // can directly get the Seq of attributes representing the grouping expressions. - case other => - val withAlias = Alias(other, other.toString)() - other -> withAlias - } - val groupExpressionMap = namedGroupingExpressions.toMap - val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute) + + val groupingAttributes = groupingExpressions.map(_.toAttribute) val partialAggregateExpressions = aggregateExpressions.map(_.copy(mode = Partial)) val partialAggregateAttributes = partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) val partialResultExpressions = - namedGroupingAttributes ++ + groupingAttributes ++ partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) val partialAggregate = if (usesTungstenAggregate) { TungstenAggregate( requiredChildDistributionExpressions = None: Option[Seq[Expression]], - groupingExpressions = namedGroupingExpressions.map(_._2), + groupingExpressions = groupingExpressions, nonCompleteAggregateExpressions = partialAggregateExpressions, + nonCompleteAggregateAttributes = partialAggregateAttributes, completeAggregateExpressions = Nil, + completeAggregateAttributes = Nil, resultExpressions = partialResultExpressions, child = child) } else { SortBasedAggregate( requiredChildDistributionExpressions = None: Option[Seq[Expression]], - groupingExpressions = namedGroupingExpressions.map(_._2), + groupingExpressions = groupingExpressions, nonCompleteAggregateExpressions = partialAggregateExpressions, nonCompleteAggregateAttributes = partialAggregateAttributes, completeAggregateExpressions = Nil, @@ -145,58 +111,32 @@ object Utils { // 2. Create an Aggregate Operator for final aggregations. val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final)) - val finalAggregateAttributes = - finalAggregateExpressions.map { - expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)._2 - } + // The attributes of the final aggregation buffer, which is presented as input to the result + // projection: + val finalAggregateAttributes = finalAggregateExpressions.map { + expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) + } val finalAggregate = if (usesTungstenAggregate) { - val rewrittenResultExpressions = resultExpressions.map { expr => - expr.transformDown { - case agg: AggregateExpression2 => - // aggregateFunctionMap contains unique aggregate functions. - val aggregateFunction = - aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._1 - aggregateFunction.asInstanceOf[DeclarativeAggregate].evaluateExpression - case expression => - // We do not rely on the equality check at here since attributes may - // different cosmetically. Instead, we use semanticEquals. - groupExpressionMap.collectFirst { - case (expr, ne) if expr semanticEquals expression => ne.toAttribute - }.getOrElse(expression) - }.asInstanceOf[NamedExpression] - } - TungstenAggregate( - requiredChildDistributionExpressions = Some(namedGroupingAttributes), - groupingExpressions = namedGroupingAttributes, + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes, nonCompleteAggregateExpressions = finalAggregateExpressions, + nonCompleteAggregateAttributes = finalAggregateAttributes, completeAggregateExpressions = Nil, - resultExpressions = rewrittenResultExpressions, + completeAggregateAttributes = Nil, + resultExpressions = resultExpressions, child = partialAggregate) } else { - val rewrittenResultExpressions = resultExpressions.map { expr => - expr.transformDown { - case agg: AggregateExpression2 => - aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._2 - case expression => - // We do not rely on the equality check at here since attributes may - // different cosmetically. Instead, we use semanticEquals. - groupExpressionMap.collectFirst { - case (expr, ne) if expr semanticEquals expression => ne.toAttribute - }.getOrElse(expression) - }.asInstanceOf[NamedExpression] - } - SortBasedAggregate( - requiredChildDistributionExpressions = Some(namedGroupingAttributes), - groupingExpressions = namedGroupingAttributes, + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes, nonCompleteAggregateExpressions = finalAggregateExpressions, nonCompleteAggregateAttributes = finalAggregateAttributes, completeAggregateExpressions = Nil, completeAggregateAttributes = Nil, - initialInputBufferOffset = namedGroupingAttributes.length, - resultExpressions = rewrittenResultExpressions, + initialInputBufferOffset = groupingExpressions.length, + resultExpressions = resultExpressions, child = partialAggregate) } @@ -204,10 +144,10 @@ object Utils { } def planAggregateWithOneDistinct( - groupingExpressions: Seq[Expression], + groupingExpressions: Seq[NamedExpression], functionsWithDistinct: Seq[AggregateExpression2], functionsWithoutDistinct: Seq[AggregateExpression2], - aggregateFunctionMap: Map[(AggregateFunction2, Boolean), (AggregateFunction2, Attribute)], + aggregateFunctionToAttribute: Map[(AggregateFunction2, Boolean), Attribute], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { @@ -221,20 +161,7 @@ object Utils { aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) // 1. Create an Aggregate Operator for partial aggregations. - // The grouping expressions are original groupingExpressions and - // distinct columns. For example, for avg(distinct value) ... group by key - // the grouping expressions of this Aggregate Operator will be [key, value]. - val namedGroupingExpressions = groupingExpressions.map { - case ne: NamedExpression => ne -> ne - // If the expression is not a NamedExpressions, we add an alias. - // So, when we generate the result of the operator, the Aggregate Operator - // can directly get the Seq of attributes representing the grouping expressions. - case other => - val withAlias = Alias(other, other.toString)() - other -> withAlias - } - val groupExpressionMap = namedGroupingExpressions.toMap - val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute) + val groupingAttributes = groupingExpressions.map(_.toAttribute) // It is safe to call head at here since functionsWithDistinct has at least one // AggregateExpression2. @@ -253,22 +180,27 @@ object Utils { val partialAggregateAttributes = partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) val partialAggregateGroupingExpressions = - (namedGroupingExpressions ++ namedDistinctColumnExpressions).map(_._2) + groupingExpressions ++ namedDistinctColumnExpressions.map(_._2) val partialAggregateResult = - namedGroupingAttributes ++ + groupingAttributes ++ distinctColumnAttributes ++ partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) val partialAggregate = if (usesTungstenAggregate) { TungstenAggregate( - requiredChildDistributionExpressions = None: Option[Seq[Expression]], + requiredChildDistributionExpressions = None, + // The grouping expressions are original groupingExpressions and + // distinct columns. For example, for avg(distinct value) ... group by key + // the grouping expressions of this Aggregate Operator will be [key, value]. groupingExpressions = partialAggregateGroupingExpressions, nonCompleteAggregateExpressions = partialAggregateExpressions, + nonCompleteAggregateAttributes = partialAggregateAttributes, completeAggregateExpressions = Nil, + completeAggregateAttributes = Nil, resultExpressions = partialAggregateResult, child = child) } else { SortBasedAggregate( - requiredChildDistributionExpressions = None: Option[Seq[Expression]], + requiredChildDistributionExpressions = None, groupingExpressions = partialAggregateGroupingExpressions, nonCompleteAggregateExpressions = partialAggregateExpressions, nonCompleteAggregateAttributes = partialAggregateAttributes, @@ -284,41 +216,40 @@ object Utils { val partialMergeAggregateAttributes = partialMergeAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) val partialMergeAggregateResult = - namedGroupingAttributes ++ + groupingAttributes ++ distinctColumnAttributes ++ partialMergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) val partialMergeAggregate = if (usesTungstenAggregate) { TungstenAggregate( - requiredChildDistributionExpressions = Some(namedGroupingAttributes), - groupingExpressions = namedGroupingAttributes ++ distinctColumnAttributes, + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes ++ distinctColumnAttributes, nonCompleteAggregateExpressions = partialMergeAggregateExpressions, + nonCompleteAggregateAttributes = partialMergeAggregateAttributes, completeAggregateExpressions = Nil, + completeAggregateAttributes = Nil, resultExpressions = partialMergeAggregateResult, child = partialAggregate) } else { SortBasedAggregate( - requiredChildDistributionExpressions = Some(namedGroupingAttributes), - groupingExpressions = namedGroupingAttributes ++ distinctColumnAttributes, + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes ++ distinctColumnAttributes, nonCompleteAggregateExpressions = partialMergeAggregateExpressions, nonCompleteAggregateAttributes = partialMergeAggregateAttributes, completeAggregateExpressions = Nil, completeAggregateAttributes = Nil, - initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length, + initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length, resultExpressions = partialMergeAggregateResult, child = partialAggregate) } // 3. Create an Aggregate Operator for partial merge aggregations. val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final)) - val finalAggregateAttributes = - finalAggregateExpressions.map { - expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)._2 - } - // Create a map to store those rewritten aggregate functions. We always need to use - // both function and its corresponding isDistinct flag as the key because function itself - // does not knows if it is has distinct keyword or now. - val rewrittenAggregateFunctions = - mutable.Map.empty[(AggregateFunction2, Boolean), AggregateFunction2] + // The attributes of the final aggregation buffer, which is presented as input to the result + // projection: + val finalAggregateAttributes = finalAggregateExpressions.map { + expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) + } + val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map { // Children of an AggregateFunction with DISTINCT keyword has already // been evaluated. At here, we need to replace original children @@ -328,9 +259,6 @@ object Utils { case expr if distinctColumnExpressionMap.contains(expr) => distinctColumnExpressionMap(expr).toAttribute }.asInstanceOf[AggregateFunction2] - // Because we have rewritten the aggregate function, we use rewrittenAggregateFunctions - // to track the old version and the new version of this function. - rewrittenAggregateFunctions += (aggregateFunction, true) -> rewrittenAggregateFunction // We rewrite the aggregate function to a non-distinct aggregation because // its input will have distinct arguments. // We just keep the isDistinct setting to true, so when users look at the query plan, @@ -338,66 +266,30 @@ object Utils { val rewrittenAggregateExpression = AggregateExpression2(rewrittenAggregateFunction, Complete, true) - val aggregateFunctionAttribute = - aggregateFunctionMap(agg.aggregateFunction, true)._2 - (rewrittenAggregateExpression -> aggregateFunctionAttribute) + val aggregateFunctionAttribute = aggregateFunctionToAttribute(agg.aggregateFunction, true) + (rewrittenAggregateExpression, aggregateFunctionAttribute) }.unzip val finalAndCompleteAggregate = if (usesTungstenAggregate) { - val rewrittenResultExpressions = resultExpressions.map { expr => - expr.transform { - case agg: AggregateExpression2 => - val function = agg.aggregateFunction - val isDistinct = agg.isDistinct - val aggregateFunction = - if (rewrittenAggregateFunctions.contains(function, isDistinct)) { - // If this function has been rewritten, we get the rewritten version from - // rewrittenAggregateFunctions. - rewrittenAggregateFunctions(function, isDistinct) - } else { - // Oterwise, we get it from aggregateFunctionMap, which contains unique - // aggregate functions that have not been rewritten. - aggregateFunctionMap(function, isDistinct)._1 - } - aggregateFunction.asInstanceOf[DeclarativeAggregate].evaluateExpression - case expression => - // We do not rely on the equality check at here since attributes may - // different cosmetically. Instead, we use semanticEquals. - groupExpressionMap.collectFirst { - case (expr, ne) if expr semanticEquals expression => ne.toAttribute - }.getOrElse(expression) - }.asInstanceOf[NamedExpression] - } - TungstenAggregate( - requiredChildDistributionExpressions = Some(namedGroupingAttributes), - groupingExpressions = namedGroupingAttributes, + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes, nonCompleteAggregateExpressions = finalAggregateExpressions, + nonCompleteAggregateAttributes = finalAggregateAttributes, completeAggregateExpressions = completeAggregateExpressions, - resultExpressions = rewrittenResultExpressions, + completeAggregateAttributes = completeAggregateAttributes, + resultExpressions = resultExpressions, child = partialMergeAggregate) } else { - val rewrittenResultExpressions = resultExpressions.map { expr => - expr.transform { - case agg: AggregateExpression2 => - aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._2 - case expression => - // We do not rely on the equality check at here since attributes may - // different cosmetically. Instead, we use semanticEquals. - groupExpressionMap.collectFirst { - case (expr, ne) if expr semanticEquals expression => ne.toAttribute - }.getOrElse(expression) - }.asInstanceOf[NamedExpression] - } SortBasedAggregate( - requiredChildDistributionExpressions = Some(namedGroupingAttributes), - groupingExpressions = namedGroupingAttributes, + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes, nonCompleteAggregateExpressions = finalAggregateExpressions, nonCompleteAggregateAttributes = finalAggregateAttributes, completeAggregateExpressions = completeAggregateExpressions, completeAggregateAttributes = completeAggregateAttributes, - initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length, - resultExpressions = rewrittenResultExpressions, + initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length, + resultExpressions = resultExpressions, child = partialMergeAggregate) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala index 7ca677a6c72ad..ed974b3a53d41 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala @@ -38,7 +38,7 @@ class TungstenAggregationIteratorSuite extends SparkFunSuite with SharedSQLConte () => new InterpretedMutableProjection(expr, schema) } val dummyAccum = SQLMetrics.createLongMetric(sparkContext, "dummy") - iter = new TungstenAggregationIterator(Seq.empty, Seq.empty, Seq.empty, + iter = new TungstenAggregationIterator(Seq.empty, Seq.empty, Seq.empty, Seq.empty, Seq.empty, Seq.empty, newMutableProjection, Seq.empty, None, dummyAccum, dummyAccum) val numPages = iter.getHashMap.getNumDataPages assert(numPages === 1) From 02149ff08eed3745086589a047adbce9a580389f Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 8 Oct 2015 16:18:35 -0700 Subject: [PATCH 005/139] [SPARK-8848] [SQL] Refactors Parquet write path to follow parquet-format This PR refactors Parquet write path to follow parquet-format spec. It's a successor of PR #7679, but with less non-essential changes. Major changes include: 1. Replaces `RowWriteSupport` and `MutableRowWriteSupport` with `CatalystWriteSupport` - Writes Parquet data using standard layout defined in parquet-format Specifically, we are now writing ... - ... arrays and maps in standard 3-level structure with proper annotations and field names - ... decimals as `INT32` and `INT64` whenever possible, and taking `FIXED_LEN_BYTE_ARRAY` as the final fallback - Supports legacy mode which is compatible with Spark 1.4 and prior versions The legacy mode is by default off, and can be turned on by flipping SQL option `spark.sql.parquet.writeLegacyFormat` to `true`. - Eliminates per value data type dispatching costs via prebuilt composed writer functions 1. Cleans up the last pieces of old Parquet support code As pointed out by rxin previously, we probably want to rename all those `Catalyst*` Parquet classes to `Parquet*` for clarity. But I'd like to do this in a follow-up PR to minimize code review noises in this one. Author: Cheng Lian Closes #8988 from liancheng/spark-8848/standard-parquet-write-path. --- .../org/apache/spark/sql/types/Decimal.scala | 4 +- .../scala/org/apache/spark/sql/SQLConf.scala | 5 +- .../parquet/CatalystReadSupport.scala | 35 +- .../parquet/CatalystRowConverter.scala | 59 +-- .../parquet/CatalystSchemaConverter.scala | 46 +- .../parquet/CatalystWriteSupport.scala | 436 ++++++++++++++++++ .../DirectParquetOutputCommitter.scala | 2 +- .../parquet/ParquetConverter.scala | 39 -- .../datasources/parquet/ParquetFilters.scala | 36 -- .../datasources/parquet/ParquetRelation.scala | 42 +- .../parquet/ParquetTableSupport.scala | 321 ------------- .../parquet/ParquetTypesConverter.scala | 160 ------- .../spark/sql/UserDefinedTypeSuite.scala | 32 +- .../datasources/parquet/ParquetIOSuite.scala | 63 +-- .../parquet/ParquetQuerySuite.scala | 46 +- .../parquet/ParquetSchemaSuite.scala | 4 +- .../datasources/parquet/ParquetTest.scala | 44 +- .../sql/hive/HiveMetastoreCatalogSuite.scala | 28 +- 18 files changed, 709 insertions(+), 693 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetConverter.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTableSupport.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypesConverter.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 909b8e31f2458..c11dab35cdf6f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -108,7 +108,9 @@ final class Decimal extends Ordered[Decimal] with Serializable { */ def set(decimal: BigDecimal, precision: Int, scale: Int): Decimal = { this.decimalVal = decimal.setScale(scale, ROUNDING_MODE) - require(decimalVal.precision <= precision, "Overflowed precision") + require( + decimalVal.precision <= precision, + s"Decimal precision ${decimalVal.precision} exceeds max precision $precision") this.longVal = 0L this._precision = precision this._scale = scale diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 8f0f8910b36ab..47397c4be3cb6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -292,10 +292,9 @@ private[spark] object SQLConf { val PARQUET_WRITE_LEGACY_FORMAT = booleanConf( key = "spark.sql.parquet.writeLegacyFormat", - defaultValue = Some(true), + defaultValue = Some(false), doc = "Whether to follow Parquet's format specification when converting Parquet schema to " + - "Spark SQL schema and vice versa.", - isPublic = false) + "Spark SQL schema and vice versa.") val PARQUET_OUTPUT_COMMITTER_CLASS = stringConf( key = "spark.sql.parquet.output.committer.class", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala index 5325698034095..a958373eb769d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala @@ -95,7 +95,9 @@ private[parquet] class CatalystReadSupport extends ReadSupport[InternalRow] with """.stripMargin } - new CatalystRecordMaterializer(parquetRequestedSchema, catalystRequestedSchema) + new CatalystRecordMaterializer( + parquetRequestedSchema, + CatalystReadSupport.expandUDT(catalystRequestedSchema)) } } @@ -110,7 +112,10 @@ private[parquet] object CatalystReadSupport { */ def clipParquetSchema(parquetSchema: MessageType, catalystSchema: StructType): MessageType = { val clippedParquetFields = clipParquetGroupFields(parquetSchema.asGroupType(), catalystSchema) - Types.buildMessage().addFields(clippedParquetFields: _*).named("root") + Types + .buildMessage() + .addFields(clippedParquetFields: _*) + .named(CatalystSchemaConverter.SPARK_PARQUET_SCHEMA_NAME) } private def clipParquetType(parquetType: Type, catalystType: DataType): Type = { @@ -271,4 +276,30 @@ private[parquet] object CatalystReadSupport { .getOrElse(toParquet.convertField(f)) } } + + def expandUDT(schema: StructType): StructType = { + def expand(dataType: DataType): DataType = { + dataType match { + case t: ArrayType => + t.copy(elementType = expand(t.elementType)) + + case t: MapType => + t.copy( + keyType = expand(t.keyType), + valueType = expand(t.valueType)) + + case t: StructType => + val expandedFields = t.fields.map(f => f.copy(dataType = expand(f.dataType))) + t.copy(fields = expandedFields) + + case t: UserDefinedType[_] => + t.sqlType + + case t => + t + } + } + + expand(schema).asInstanceOf[StructType] + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala index 050d3610a6413..247d35363b862 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala @@ -27,7 +27,6 @@ import org.apache.parquet.column.Dictionary import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveConverter} import org.apache.parquet.schema.OriginalType.{INT_32, LIST, UTF8} import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.DOUBLE -import org.apache.parquet.schema.Type.Repetition import org.apache.parquet.schema.{GroupType, MessageType, PrimitiveType, Type} import org.apache.spark.Logging @@ -114,7 +113,8 @@ private[parquet] class CatalystPrimitiveConverter(val updater: ParentContainerUp * any "parent" container. * * @param parquetType Parquet schema of Parquet records - * @param catalystType Spark SQL schema that corresponds to the Parquet record type + * @param catalystType Spark SQL schema that corresponds to the Parquet record type. User-defined + * types should have been expanded. * @param updater An updater which propagates converted field values to the parent container */ private[parquet] class CatalystRowConverter( @@ -133,6 +133,12 @@ private[parquet] class CatalystRowConverter( |${catalystType.prettyJson} """.stripMargin) + assert( + !catalystType.existsRecursively(_.isInstanceOf[UserDefinedType[_]]), + s"""User-defined types in Catalyst schema should have already been expanded: + |${catalystType.prettyJson} + """.stripMargin) + logDebug( s"""Building row converter for the following schema: | @@ -268,13 +274,6 @@ private[parquet] class CatalystRowConverter( override def set(value: Any): Unit = updater.set(value.asInstanceOf[InternalRow].copy()) }) - case t: UserDefinedType[_] => - val catalystTypeForUDT = t.sqlType - val nullable = parquetType.isRepetition(Repetition.OPTIONAL) - val field = StructField("udt", catalystTypeForUDT, nullable) - val parquetTypeForUDT = new CatalystSchemaConverter().convertField(field) - newConverter(parquetTypeForUDT, catalystTypeForUDT, updater) - case _ => throw new RuntimeException( s"Unable to create Parquet converter for data type ${catalystType.json}") @@ -340,30 +339,36 @@ private[parquet] class CatalystRowConverter( val scale = decimalType.scale if (precision <= CatalystSchemaConverter.MAX_PRECISION_FOR_INT64) { - // Constructs a `Decimal` with an unscaled `Long` value if possible. The underlying - // `ByteBuffer` implementation is guaranteed to be `HeapByteBuffer`, so here we are using - // `Binary.toByteBuffer.array()` to steal the underlying byte array without copying it. - val buffer = value.toByteBuffer - val bytes = buffer.array() - val start = buffer.position() - val end = buffer.limit() - - var unscaled = 0L - var i = start - - while (i < end) { - unscaled = (unscaled << 8) | (bytes(i) & 0xff) - i += 1 - } - - val bits = 8 * (end - start) - unscaled = (unscaled << (64 - bits)) >> (64 - bits) + // Constructs a `Decimal` with an unscaled `Long` value if possible. + val unscaled = binaryToUnscaledLong(value) Decimal(unscaled, precision, scale) } else { // Otherwise, resorts to an unscaled `BigInteger` instead. Decimal(new BigDecimal(new BigInteger(value.getBytes), scale), precision, scale) } } + + private def binaryToUnscaledLong(binary: Binary): Long = { + // The underlying `ByteBuffer` implementation is guaranteed to be `HeapByteBuffer`, so here + // we are using `Binary.toByteBuffer.array()` to steal the underlying byte array without + // copying it. + val buffer = binary.toByteBuffer + val bytes = buffer.array() + val start = buffer.position() + val end = buffer.limit() + + var unscaled = 0L + var i = start + + while (i < end) { + unscaled = (unscaled << 8) | (bytes(i) & 0xff) + i += 1 + } + + val bits = 8 * (end - start) + unscaled = (unscaled << (64 - bits)) >> (64 - bits) + unscaled + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala index 6904fc736c106..7f3394c20ed3d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala @@ -121,7 +121,7 @@ private[parquet] class CatalystSchemaConverter( val precision = field.getDecimalMetadata.getPrecision val scale = field.getDecimalMetadata.getScale - CatalystSchemaConverter.analysisRequire( + CatalystSchemaConverter.checkConversionRequirement( maxPrecision == -1 || 1 <= precision && precision <= maxPrecision, s"Invalid decimal precision: $typeName cannot store $precision digits (max $maxPrecision)") @@ -155,7 +155,7 @@ private[parquet] class CatalystSchemaConverter( } case INT96 => - CatalystSchemaConverter.analysisRequire( + CatalystSchemaConverter.checkConversionRequirement( assumeInt96IsTimestamp, "INT96 is not supported unless it's interpreted as timestamp. " + s"Please try to set ${SQLConf.PARQUET_INT96_AS_TIMESTAMP.key} to true.") @@ -197,11 +197,11 @@ private[parquet] class CatalystSchemaConverter( // // See: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#lists case LIST => - CatalystSchemaConverter.analysisRequire( + CatalystSchemaConverter.checkConversionRequirement( field.getFieldCount == 1, s"Invalid list type $field") val repeatedType = field.getType(0) - CatalystSchemaConverter.analysisRequire( + CatalystSchemaConverter.checkConversionRequirement( repeatedType.isRepetition(REPEATED), s"Invalid list type $field") if (isElementType(repeatedType, field.getName)) { @@ -217,17 +217,17 @@ private[parquet] class CatalystSchemaConverter( // See: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#backward-compatibility-rules-1 // scalastyle:on case MAP | MAP_KEY_VALUE => - CatalystSchemaConverter.analysisRequire( + CatalystSchemaConverter.checkConversionRequirement( field.getFieldCount == 1 && !field.getType(0).isPrimitive, s"Invalid map type: $field") val keyValueType = field.getType(0).asGroupType() - CatalystSchemaConverter.analysisRequire( + CatalystSchemaConverter.checkConversionRequirement( keyValueType.isRepetition(REPEATED) && keyValueType.getFieldCount == 2, s"Invalid map type: $field") val keyType = keyValueType.getType(0) - CatalystSchemaConverter.analysisRequire( + CatalystSchemaConverter.checkConversionRequirement( keyType.isPrimitive, s"Map key type is expected to be a primitive type, but found: $keyType") @@ -299,7 +299,10 @@ private[parquet] class CatalystSchemaConverter( * Converts a Spark SQL [[StructType]] to a Parquet [[MessageType]]. */ def convert(catalystSchema: StructType): MessageType = { - Types.buildMessage().addFields(catalystSchema.map(convertField): _*).named("root") + Types + .buildMessage() + .addFields(catalystSchema.map(convertField): _*) + .named(CatalystSchemaConverter.SPARK_PARQUET_SCHEMA_NAME) } /** @@ -347,10 +350,10 @@ private[parquet] class CatalystSchemaConverter( // NOTE: Spark SQL TimestampType is NOT a well defined type in Parquet format spec. // // As stated in PARQUET-323, Parquet `INT96` was originally introduced to represent nanosecond - // timestamp in Impala for some historical reasons, it's not recommended to be used for any - // other types and will probably be deprecated in future Parquet format spec. That's the - // reason why Parquet format spec only defines `TIMESTAMP_MILLIS` and `TIMESTAMP_MICROS` which - // are both logical types annotating `INT64`. + // timestamp in Impala for some historical reasons. It's not recommended to be used for any + // other types and will probably be deprecated in some future version of parquet-format spec. + // That's the reason why parquet-format spec only defines `TIMESTAMP_MILLIS` and + // `TIMESTAMP_MICROS` which are both logical types annotating `INT64`. // // Originally, Spark SQL uses the same nanosecond timestamp type as Impala and Hive. Starting // from Spark 1.5.0, we resort to a timestamp type with 100 ns precision so that we can store @@ -361,7 +364,7 @@ private[parquet] class CatalystSchemaConverter( // currently not implemented yet because parquet-mr 1.7.0 (the version we're currently using) // hasn't implemented `TIMESTAMP_MICROS` yet. // - // TODO Implements `TIMESTAMP_MICROS` once parquet-mr has that. + // TODO Converts `TIMESTAMP_MICROS` once parquet-mr implements that. case TimestampType => Types.primitive(INT96, repetition).named(field.name) @@ -523,11 +526,12 @@ private[parquet] class CatalystSchemaConverter( } } - private[parquet] object CatalystSchemaConverter { + val SPARK_PARQUET_SCHEMA_NAME = "spark_schema" + def checkFieldName(name: String): Unit = { // ,;{}()\n\t= and space are special characters in Parquet schema - analysisRequire( + checkConversionRequirement( !name.matches(".*[ ,;{}()\n\t=].*"), s"""Attribute name "$name" contains invalid character(s) among " ,;{}()\\n\\t=". |Please use alias to rename it. @@ -539,7 +543,7 @@ private[parquet] object CatalystSchemaConverter { schema } - def analysisRequire(f: => Boolean, message: String): Unit = { + def checkConversionRequirement(f: => Boolean, message: String): Unit = { if (!f) { throw new AnalysisException(message) } @@ -553,16 +557,8 @@ private[parquet] object CatalystSchemaConverter { numBytes } - private val MIN_BYTES_FOR_PRECISION = Array.tabulate[Int](39)(computeMinBytesForPrecision) - // Returns the minimum number of bytes needed to store a decimal with a given `precision`. - def minBytesForPrecision(precision : Int) : Int = { - if (precision < MIN_BYTES_FOR_PRECISION.length) { - MIN_BYTES_FOR_PRECISION(precision) - } else { - computeMinBytesForPrecision(precision) - } - } + val minBytesForPrecision = Array.tabulate[Int](39)(computeMinBytesForPrecision) val MAX_PRECISION_FOR_INT32 = maxPrecisionForBytes(4) /* 9 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala new file mode 100644 index 0000000000000..483363d2c1a21 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala @@ -0,0 +1,436 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import java.nio.{ByteBuffer, ByteOrder} +import java.util + +import scala.collection.JavaConverters.mapAsJavaMapConverter + +import org.apache.hadoop.conf.Configuration +import org.apache.parquet.column.ParquetProperties +import org.apache.parquet.hadoop.ParquetOutputFormat +import org.apache.parquet.hadoop.api.WriteSupport +import org.apache.parquet.hadoop.api.WriteSupport.WriteContext +import org.apache.parquet.io.api.{Binary, RecordConsumer} + +import org.apache.spark.Logging +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.SpecializedGetters +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.{MAX_PRECISION_FOR_INT32, MAX_PRECISION_FOR_INT64, minBytesForPrecision} +import org.apache.spark.sql.types._ + +/** + * A Parquet [[WriteSupport]] implementation that writes Catalyst [[InternalRow]]s as Parquet + * messages. This class can write Parquet data in two modes: + * + * - Standard mode: Parquet data are written in standard format defined in parquet-format spec. + * - Legacy mode: Parquet data are written in legacy format compatible with Spark 1.4 and prior. + * + * This behavior can be controlled by SQL option `spark.sql.parquet.writeLegacyFormat`. The value + * of this option is propagated to this class by the `init()` method and its Hadoop configuration + * argument. + */ +private[parquet] class CatalystWriteSupport extends WriteSupport[InternalRow] with Logging { + // A `ValueWriter` is responsible for writing a field of an `InternalRow` to the record consumer. + // Here we are using `SpecializedGetters` rather than `InternalRow` so that we can directly access + // data in `ArrayData` without the help of `SpecificMutableRow`. + private type ValueWriter = (SpecializedGetters, Int) => Unit + + // Schema of the `InternalRow`s to be written + private var schema: StructType = _ + + // `ValueWriter`s for all fields of the schema + private var rootFieldWriters: Seq[ValueWriter] = _ + + // The Parquet `RecordConsumer` to which all `InternalRow`s are written + private var recordConsumer: RecordConsumer = _ + + // Whether to write data in legacy Parquet format compatible with Spark 1.4 and prior versions + private var writeLegacyParquetFormat: Boolean = _ + + // Reusable byte array used to write timestamps as Parquet INT96 values + private val timestampBuffer = new Array[Byte](12) + + // Reusable byte array used to write decimal values + private val decimalBuffer = new Array[Byte](minBytesForPrecision(DecimalType.MAX_PRECISION)) + + override def init(configuration: Configuration): WriteContext = { + val schemaString = configuration.get(CatalystWriteSupport.SPARK_ROW_SCHEMA) + this.schema = StructType.fromString(schemaString) + this.writeLegacyParquetFormat = { + // `SQLConf.PARQUET_WRITE_LEGACY_FORMAT` should always be explicitly set in ParquetRelation + assert(configuration.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key) != null) + configuration.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key).toBoolean + } + this.rootFieldWriters = schema.map(_.dataType).map(makeWriter) + + val messageType = new CatalystSchemaConverter(configuration).convert(schema) + val metadata = Map(CatalystReadSupport.SPARK_METADATA_KEY -> schemaString).asJava + + logInfo( + s"""Initialized Parquet WriteSupport with Catalyst schema: + |${schema.prettyJson} + |and corresponding Parquet message type: + |$messageType + """.stripMargin) + + new WriteContext(messageType, metadata) + } + + override def prepareForWrite(recordConsumer: RecordConsumer): Unit = { + this.recordConsumer = recordConsumer + } + + override def write(row: InternalRow): Unit = { + consumeMessage { + writeFields(row, schema, rootFieldWriters) + } + } + + private def writeFields( + row: InternalRow, schema: StructType, fieldWriters: Seq[ValueWriter]): Unit = { + var i = 0 + while (i < row.numFields) { + if (!row.isNullAt(i)) { + consumeField(schema(i).name, i) { + fieldWriters(i).apply(row, i) + } + } + i += 1 + } + } + + private def makeWriter(dataType: DataType): ValueWriter = { + dataType match { + case BooleanType => + (row: SpecializedGetters, ordinal: Int) => + recordConsumer.addBoolean(row.getBoolean(ordinal)) + + case ByteType => + (row: SpecializedGetters, ordinal: Int) => + recordConsumer.addInteger(row.getByte(ordinal)) + + case ShortType => + (row: SpecializedGetters, ordinal: Int) => + recordConsumer.addInteger(row.getShort(ordinal)) + + case IntegerType | DateType => + (row: SpecializedGetters, ordinal: Int) => + recordConsumer.addInteger(row.getInt(ordinal)) + + case LongType => + (row: SpecializedGetters, ordinal: Int) => + recordConsumer.addLong(row.getLong(ordinal)) + + case FloatType => + (row: SpecializedGetters, ordinal: Int) => + recordConsumer.addFloat(row.getFloat(ordinal)) + + case DoubleType => + (row: SpecializedGetters, ordinal: Int) => + recordConsumer.addDouble(row.getDouble(ordinal)) + + case StringType => + (row: SpecializedGetters, ordinal: Int) => + recordConsumer.addBinary(Binary.fromByteArray(row.getUTF8String(ordinal).getBytes)) + + case TimestampType => + (row: SpecializedGetters, ordinal: Int) => { + // TODO Writes `TimestampType` values as `TIMESTAMP_MICROS` once parquet-mr implements it + // Currently we only support timestamps stored as INT96, which is compatible with Hive + // and Impala. However, INT96 is to be deprecated. We plan to support `TIMESTAMP_MICROS` + // defined in the parquet-format spec. But up until writing, the most recent parquet-mr + // version (1.8.1) hasn't implemented it yet. + + // NOTE: Starting from Spark 1.5, Spark SQL `TimestampType` only has microsecond + // precision. Nanosecond parts of timestamp values read from INT96 are simply stripped. + val (julianDay, timeOfDayNanos) = DateTimeUtils.toJulianDay(row.getLong(ordinal)) + val buf = ByteBuffer.wrap(timestampBuffer) + buf.order(ByteOrder.LITTLE_ENDIAN).putLong(timeOfDayNanos).putInt(julianDay) + recordConsumer.addBinary(Binary.fromByteArray(timestampBuffer)) + } + + case BinaryType => + (row: SpecializedGetters, ordinal: Int) => + recordConsumer.addBinary(Binary.fromByteArray(row.getBinary(ordinal))) + + case DecimalType.Fixed(precision, scale) => + makeDecimalWriter(precision, scale) + + case t: StructType => + val fieldWriters = t.map(_.dataType).map(makeWriter) + (row: SpecializedGetters, ordinal: Int) => + consumeGroup { + writeFields(row.getStruct(ordinal, t.length), t, fieldWriters) + } + + case t: ArrayType => makeArrayWriter(t) + + case t: MapType => makeMapWriter(t) + + case t: UserDefinedType[_] => makeWriter(t.sqlType) + + // TODO Adds IntervalType support + case _ => sys.error(s"Unsupported data type $dataType.") + } + } + + private def makeDecimalWriter(precision: Int, scale: Int): ValueWriter = { + assert( + precision <= DecimalType.MAX_PRECISION, + s"Decimal precision $precision exceeds max precision ${DecimalType.MAX_PRECISION}") + + val numBytes = minBytesForPrecision(precision) + + val int32Writer = + (row: SpecializedGetters, ordinal: Int) => { + val unscaledLong = row.getDecimal(ordinal, precision, scale).toUnscaledLong + recordConsumer.addInteger(unscaledLong.toInt) + } + + val int64Writer = + (row: SpecializedGetters, ordinal: Int) => { + val unscaledLong = row.getDecimal(ordinal, precision, scale).toUnscaledLong + recordConsumer.addLong(unscaledLong) + } + + val binaryWriterUsingUnscaledLong = + (row: SpecializedGetters, ordinal: Int) => { + // When the precision is low enough (<= 18) to squeeze the decimal value into a `Long`, we + // can build a fixed-length byte array with length `numBytes` using the unscaled `Long` + // value and the `decimalBuffer` for better performance. + val unscaled = row.getDecimal(ordinal, precision, scale).toUnscaledLong + var i = 0 + var shift = 8 * (numBytes - 1) + + while (i < numBytes) { + decimalBuffer(i) = (unscaled >> shift).toByte + i += 1 + shift -= 8 + } + + recordConsumer.addBinary(Binary.fromByteArray(decimalBuffer, 0, numBytes)) + } + + val binaryWriterUsingUnscaledBytes = + (row: SpecializedGetters, ordinal: Int) => { + val decimal = row.getDecimal(ordinal, precision, scale) + val bytes = decimal.toJavaBigDecimal.unscaledValue().toByteArray + val fixedLengthBytes = if (bytes.length == numBytes) { + // If the length of the underlying byte array of the unscaled `BigInteger` happens to be + // `numBytes`, just reuse it, so that we don't bother copying it to `decimalBuffer`. + bytes + } else { + // Otherwise, the length must be less than `numBytes`. In this case we copy contents of + // the underlying bytes with padding sign bytes to `decimalBuffer` to form the result + // fixed-length byte array. + val signByte = if (bytes.head < 0) -1: Byte else 0: Byte + util.Arrays.fill(decimalBuffer, 0, numBytes - bytes.length, signByte) + System.arraycopy(bytes, 0, decimalBuffer, numBytes - bytes.length, bytes.length) + decimalBuffer + } + + recordConsumer.addBinary(Binary.fromByteArray(fixedLengthBytes, 0, numBytes)) + } + + writeLegacyParquetFormat match { + // Standard mode, 1 <= precision <= 9, writes as INT32 + case false if precision <= MAX_PRECISION_FOR_INT32 => int32Writer + + // Standard mode, 10 <= precision <= 18, writes as INT64 + case false if precision <= MAX_PRECISION_FOR_INT64 => int64Writer + + // Legacy mode, 1 <= precision <= 18, writes as FIXED_LEN_BYTE_ARRAY + case true if precision <= MAX_PRECISION_FOR_INT64 => binaryWriterUsingUnscaledLong + + // Either standard or legacy mode, 19 <= precision <= 38, writes as FIXED_LEN_BYTE_ARRAY + case _ => binaryWriterUsingUnscaledBytes + } + } + + def makeArrayWriter(arrayType: ArrayType): ValueWriter = { + val elementWriter = makeWriter(arrayType.elementType) + + def threeLevelArrayWriter(repeatedGroupName: String, elementFieldName: String): ValueWriter = + (row: SpecializedGetters, ordinal: Int) => { + val array = row.getArray(ordinal) + consumeGroup { + // Only creates the repeated field if the array is non-empty. + if (array.numElements() > 0) { + consumeField(repeatedGroupName, 0) { + var i = 0 + while (i < array.numElements()) { + consumeGroup { + // Only creates the element field if the current array element is not null. + if (!array.isNullAt(i)) { + consumeField(elementFieldName, 0) { + elementWriter.apply(array, i) + } + } + } + i += 1 + } + } + } + } + } + + def twoLevelArrayWriter(repeatedFieldName: String): ValueWriter = + (row: SpecializedGetters, ordinal: Int) => { + val array = row.getArray(ordinal) + consumeGroup { + // Only creates the repeated field if the array is non-empty. + if (array.numElements() > 0) { + consumeField(repeatedFieldName, 0) { + var i = 0 + while (i < array.numElements()) { + elementWriter.apply(array, i) + i += 1 + } + } + } + } + } + + (writeLegacyParquetFormat, arrayType.containsNull) match { + case (legacyMode @ false, _) => + // Standard mode: + // + // group (LIST) { + // repeated group list { + // ^~~~ repeatedGroupName + // element; + // ^~~~~~~ elementFieldName + // } + // } + threeLevelArrayWriter(repeatedGroupName = "list", elementFieldName = "element") + + case (legacyMode @ true, nullableElements @ true) => + // Legacy mode, with nullable elements: + // + // group (LIST) { + // optional group bag { + // ^~~ repeatedGroupName + // repeated array; + // ^~~~~ elementFieldName + // } + // } + threeLevelArrayWriter(repeatedGroupName = "bag", elementFieldName = "array") + + case (legacyMode @ true, nullableElements @ false) => + // Legacy mode, with non-nullable elements: + // + // group (LIST) { + // repeated array; + // ^~~~~ repeatedFieldName + // } + twoLevelArrayWriter(repeatedFieldName = "array") + } + } + + private def makeMapWriter(mapType: MapType): ValueWriter = { + val keyWriter = makeWriter(mapType.keyType) + val valueWriter = makeWriter(mapType.valueType) + val repeatedGroupName = if (writeLegacyParquetFormat) { + // Legacy mode: + // + // group (MAP) { + // repeated group map (MAP_KEY_VALUE) { + // ^~~ repeatedGroupName + // required key; + // value; + // } + // } + "map" + } else { + // Standard mode: + // + // group (MAP) { + // repeated group key_value { + // ^~~~~~~~~ repeatedGroupName + // required key; + // value; + // } + // } + "key_value" + } + + (row: SpecializedGetters, ordinal: Int) => { + val map = row.getMap(ordinal) + val keyArray = map.keyArray() + val valueArray = map.valueArray() + + consumeGroup { + // Only creates the repeated field if the map is non-empty. + if (map.numElements() > 0) { + consumeField(repeatedGroupName, 0) { + var i = 0 + while (i < map.numElements()) { + consumeGroup { + consumeField("key", 0) { + keyWriter.apply(keyArray, i) + } + + // Only creates the "value" field if the value if non-empty + if (!map.valueArray().isNullAt(i)) { + consumeField("value", 1) { + valueWriter.apply(valueArray, i) + } + } + } + i += 1 + } + } + } + } + } + } + + private def consumeMessage(f: => Unit): Unit = { + recordConsumer.startMessage() + f + recordConsumer.endMessage() + } + + private def consumeGroup(f: => Unit): Unit = { + recordConsumer.startGroup() + f + recordConsumer.endGroup() + } + + private def consumeField(field: String, index: Int)(f: => Unit): Unit = { + recordConsumer.startField(field, index) + f + recordConsumer.endField(field, index) + } +} + +private[parquet] object CatalystWriteSupport { + val SPARK_ROW_SCHEMA: String = "org.apache.spark.sql.parquet.row.attributes" + + def setSchema(schema: StructType, configuration: Configuration): Unit = { + schema.map(_.name).foreach(CatalystSchemaConverter.checkFieldName) + configuration.set(SPARK_ROW_SCHEMA, schema.json) + configuration.set( + ParquetOutputFormat.WRITER_VERSION, + ParquetProperties.WriterVersion.PARQUET_1_0.toString) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala index de1fd0166ac5a..300e8677b312f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala @@ -39,7 +39,7 @@ import org.apache.parquet.hadoop.{ParquetFileReader, ParquetFileWriter, ParquetO * * NEVER use [[DirectParquetOutputCommitter]] when appending data, because currently there's * no safe way undo a failed appending job (that's why both `abortTask()` and `abortJob()` are - * left * empty). + * left empty). */ private[parquet] class DirectParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) extends ParquetOutputCommitter(outputPath, context) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetConverter.scala deleted file mode 100644 index ccd7ebf319af9..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetConverter.scala +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources.parquet - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.types.{MapData, ArrayData} - -// TODO Removes this while fixing SPARK-8848 -private[sql] object CatalystConverter { - // This is mostly Parquet convention (see, e.g., `ConversionPatterns`). - // Note that "array" for the array elements is chosen by ParquetAvro. - // Using a different value will result in Parquet silently dropping columns. - val ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME = "bag" - val ARRAY_ELEMENTS_SCHEMA_NAME = "array" - - val MAP_KEY_SCHEMA_NAME = "key" - val MAP_VALUE_SCHEMA_NAME = "value" - val MAP_SCHEMA_NAME = "map" - - // TODO: consider using Array[T] for arrays to avoid boxing of primitive types - type ArrayScalaType = ArrayData - type StructScalaType = InternalRow - type MapScalaType = MapData -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index c6b3fe7900da8..78040d99fb0a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -18,24 +18,17 @@ package org.apache.spark.sql.execution.datasources.parquet import java.io.Serializable -import java.nio.ByteBuffer -import com.google.common.io.BaseEncoding -import org.apache.hadoop.conf.Configuration import org.apache.parquet.filter2.predicate.FilterApi._ import org.apache.parquet.filter2.predicate._ import org.apache.parquet.io.api.Binary import org.apache.parquet.schema.OriginalType import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName -import org.apache.spark.SparkEnv -import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.sources import org.apache.spark.sql.types._ private[sql] object ParquetFilters { - val PARQUET_FILTER_DATA = "org.apache.spark.sql.parquet.row.filter" - case class SetInFilter[T <: Comparable[T]]( valueSet: Set[T]) extends UserDefinedPredicate[T] with Serializable { @@ -282,33 +275,4 @@ private[sql] object ParquetFilters { addMethod.setAccessible(true) addMethod.invoke(null, classOf[Binary], enumTypeDescriptor) } - - /** - * Note: Inside the Hadoop API we only have access to `Configuration`, not to - * [[org.apache.spark.SparkContext]], so we cannot use broadcasts to convey - * the actual filter predicate. - */ - def serializeFilterExpressions(filters: Seq[Expression], conf: Configuration): Unit = { - if (filters.nonEmpty) { - val serialized: Array[Byte] = - SparkEnv.get.closureSerializer.newInstance().serialize(filters).array() - val encoded: String = BaseEncoding.base64().encode(serialized) - conf.set(PARQUET_FILTER_DATA, encoded) - } - } - - /** - * Note: Inside the Hadoop API we only have access to `Configuration`, not to - * [[org.apache.spark.SparkContext]], so we cannot use broadcasts to convey - * the actual filter predicate. - */ - def deserializeFilterExpressions(conf: Configuration): Seq[Expression] = { - val data = conf.get(PARQUET_FILTER_DATA) - if (data != null) { - val decoded: Array[Byte] = BaseEncoding.base64().decode(data) - SparkEnv.get.closureSerializer.newInstance().deserialize(ByteBuffer.wrap(decoded)) - } else { - Seq() - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index 8a9c0e733a9a1..77d851ca486b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -218,8 +218,8 @@ private[sql] class ParquetRelation( } // SPARK-9849 DirectParquetOutputCommitter qualified name should be backward compatible - val committerClassname = conf.get(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key) - if (committerClassname == "org.apache.spark.sql.parquet.DirectParquetOutputCommitter") { + val committerClassName = conf.get(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key) + if (committerClassName == "org.apache.spark.sql.parquet.DirectParquetOutputCommitter") { conf.set(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key, classOf[DirectParquetOutputCommitter].getCanonicalName) } @@ -248,18 +248,22 @@ private[sql] class ParquetRelation( // bundled with `ParquetOutputFormat[Row]`. job.setOutputFormatClass(classOf[ParquetOutputFormat[Row]]) - // TODO There's no need to use two kinds of WriteSupport - // We should unify them. `SpecificMutableRow` can process both atomic (primitive) types and - // complex types. - val writeSupportClass = - if (dataSchema.map(_.dataType).forall(ParquetTypesConverter.isPrimitiveType)) { - classOf[MutableRowWriteSupport] - } else { - classOf[RowWriteSupport] - } + ParquetOutputFormat.setWriteSupportClass(job, classOf[CatalystWriteSupport]) + CatalystWriteSupport.setSchema(dataSchema, conf) + + // Sets flags for `CatalystSchemaConverter` (which converts Catalyst schema to Parquet schema) + // and `CatalystWriteSupport` (writing actual rows to Parquet files). + conf.set( + SQLConf.PARQUET_BINARY_AS_STRING.key, + sqlContext.conf.isParquetBinaryAsString.toString) - ParquetOutputFormat.setWriteSupportClass(job, writeSupportClass) - RowWriteSupport.setSchema(dataSchema.toAttributes, conf) + conf.set( + SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, + sqlContext.conf.isParquetINT96AsTimestamp.toString) + + conf.set( + SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key, + sqlContext.conf.writeLegacyParquetFormat.toString) // Sets compression scheme conf.set( @@ -287,7 +291,6 @@ private[sql] class ParquetRelation( val parquetFilterPushDown = sqlContext.conf.parquetFilterPushDown val assumeBinaryIsString = sqlContext.conf.isParquetBinaryAsString val assumeInt96IsTimestamp = sqlContext.conf.isParquetINT96AsTimestamp - val writeLegacyParquetFormat = sqlContext.conf.writeLegacyParquetFormat // Parquet row group size. We will use this value as the value for // mapreduce.input.fileinputformat.split.minsize and mapred.min.split.size if the value @@ -304,8 +307,7 @@ private[sql] class ParquetRelation( useMetadataCache, parquetFilterPushDown, assumeBinaryIsString, - assumeInt96IsTimestamp, - writeLegacyParquetFormat) _ + assumeInt96IsTimestamp) _ // Create the function to set input paths at the driver side. val setInputPaths = @@ -530,8 +532,7 @@ private[sql] object ParquetRelation extends Logging { useMetadataCache: Boolean, parquetFilterPushDown: Boolean, assumeBinaryIsString: Boolean, - assumeInt96IsTimestamp: Boolean, - writeLegacyParquetFormat: Boolean)(job: Job): Unit = { + assumeInt96IsTimestamp: Boolean)(job: Job): Unit = { val conf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) conf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[CatalystReadSupport].getName) @@ -552,16 +553,15 @@ private[sql] object ParquetRelation extends Logging { }) conf.set( - RowWriteSupport.SPARK_ROW_SCHEMA, + CatalystWriteSupport.SPARK_ROW_SCHEMA, CatalystSchemaConverter.checkFieldNames(dataSchema).json) // Tell FilteringParquetRowInputFormat whether it's okay to cache Parquet and FS metadata conf.setBoolean(SQLConf.PARQUET_CACHE_METADATA.key, useMetadataCache) - // Sets flags for Parquet schema conversion + // Sets flags for `CatalystSchemaConverter` conf.setBoolean(SQLConf.PARQUET_BINARY_AS_STRING.key, assumeBinaryIsString) conf.setBoolean(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, assumeInt96IsTimestamp) - conf.setBoolean(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key, writeLegacyParquetFormat) overrideMinSplitSize(parquetBlockSize, conf) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTableSupport.scala deleted file mode 100644 index ed89aa27aa1f0..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTableSupport.scala +++ /dev/null @@ -1,321 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources.parquet - -import java.math.BigInteger -import java.nio.{ByteBuffer, ByteOrder} -import java.util.{HashMap => JHashMap} - -import org.apache.hadoop.conf.Configuration -import org.apache.parquet.column.ParquetProperties -import org.apache.parquet.hadoop.ParquetOutputFormat -import org.apache.parquet.hadoop.api.WriteSupport -import org.apache.parquet.io.api._ - -import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String - -/** - * A `parquet.hadoop.api.WriteSupport` for Row objects. - */ -private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Logging { - - private[parquet] var writer: RecordConsumer = null - private[parquet] var attributes: Array[Attribute] = null - - override def init(configuration: Configuration): WriteSupport.WriteContext = { - val origAttributesStr: String = configuration.get(RowWriteSupport.SPARK_ROW_SCHEMA) - val metadata = new JHashMap[String, String]() - metadata.put(CatalystReadSupport.SPARK_METADATA_KEY, origAttributesStr) - - if (attributes == null) { - attributes = ParquetTypesConverter.convertFromString(origAttributesStr).toArray - } - - log.debug(s"write support initialized for requested schema $attributes") - new WriteSupport.WriteContext(ParquetTypesConverter.convertFromAttributes(attributes), metadata) - } - - override def prepareForWrite(recordConsumer: RecordConsumer): Unit = { - writer = recordConsumer - log.debug(s"preparing for write with schema $attributes") - } - - override def write(record: InternalRow): Unit = { - val attributesSize = attributes.size - if (attributesSize > record.numFields) { - throw new IndexOutOfBoundsException("Trying to write more fields than contained in row " + - s"($attributesSize > ${record.numFields})") - } - - var index = 0 - writer.startMessage() - while(index < attributesSize) { - // null values indicate optional fields but we do not check currently - if (!record.isNullAt(index)) { - writer.startField(attributes(index).name, index) - writeValue(attributes(index).dataType, record.get(index, attributes(index).dataType)) - writer.endField(attributes(index).name, index) - } - index = index + 1 - } - writer.endMessage() - } - - private[parquet] def writeValue(schema: DataType, value: Any): Unit = { - if (value != null) { - schema match { - case t: UserDefinedType[_] => writeValue(t.sqlType, value) - case t @ ArrayType(_, _) => writeArray( - t, - value.asInstanceOf[CatalystConverter.ArrayScalaType]) - case t @ MapType(_, _, _) => writeMap( - t, - value.asInstanceOf[CatalystConverter.MapScalaType]) - case t @ StructType(_) => writeStruct( - t, - value.asInstanceOf[CatalystConverter.StructScalaType]) - case _ => writePrimitive(schema.asInstanceOf[AtomicType], value) - } - } - } - - private[parquet] def writePrimitive(schema: DataType, value: Any): Unit = { - if (value != null) { - schema match { - case BooleanType => writer.addBoolean(value.asInstanceOf[Boolean]) - case ByteType => writer.addInteger(value.asInstanceOf[Byte]) - case ShortType => writer.addInteger(value.asInstanceOf[Short]) - case IntegerType | DateType => writer.addInteger(value.asInstanceOf[Int]) - case LongType => writer.addLong(value.asInstanceOf[Long]) - case TimestampType => writeTimestamp(value.asInstanceOf[Long]) - case FloatType => writer.addFloat(value.asInstanceOf[Float]) - case DoubleType => writer.addDouble(value.asInstanceOf[Double]) - case StringType => writer.addBinary( - Binary.fromByteArray(value.asInstanceOf[UTF8String].getBytes)) - case BinaryType => writer.addBinary( - Binary.fromByteArray(value.asInstanceOf[Array[Byte]])) - case DecimalType.Fixed(precision, _) => - writeDecimal(value.asInstanceOf[Decimal], precision) - case _ => sys.error(s"Do not know how to writer $schema to consumer") - } - } - } - - private[parquet] def writeStruct( - schema: StructType, - struct: CatalystConverter.StructScalaType): Unit = { - if (struct != null) { - val fields = schema.fields.toArray - writer.startGroup() - var i = 0 - while(i < fields.length) { - if (!struct.isNullAt(i)) { - writer.startField(fields(i).name, i) - writeValue(fields(i).dataType, struct.get(i, fields(i).dataType)) - writer.endField(fields(i).name, i) - } - i = i + 1 - } - writer.endGroup() - } - } - - private[parquet] def writeArray( - schema: ArrayType, - array: CatalystConverter.ArrayScalaType): Unit = { - val elementType = schema.elementType - writer.startGroup() - if (array.numElements() > 0) { - if (schema.containsNull) { - writer.startField(CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME, 0) - var i = 0 - while (i < array.numElements()) { - writer.startGroup() - if (!array.isNullAt(i)) { - writer.startField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0) - writeValue(elementType, array.get(i, elementType)) - writer.endField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0) - } - writer.endGroup() - i = i + 1 - } - writer.endField(CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME, 0) - } else { - writer.startField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0) - var i = 0 - while (i < array.numElements()) { - writeValue(elementType, array.get(i, elementType)) - i = i + 1 - } - writer.endField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0) - } - } - writer.endGroup() - } - - private[parquet] def writeMap( - schema: MapType, - map: CatalystConverter.MapScalaType): Unit = { - writer.startGroup() - val length = map.numElements() - if (length > 0) { - writer.startField(CatalystConverter.MAP_SCHEMA_NAME, 0) - map.foreach(schema.keyType, schema.valueType, (key, value) => { - writer.startGroup() - writer.startField(CatalystConverter.MAP_KEY_SCHEMA_NAME, 0) - writeValue(schema.keyType, key) - writer.endField(CatalystConverter.MAP_KEY_SCHEMA_NAME, 0) - if (value != null) { - writer.startField(CatalystConverter.MAP_VALUE_SCHEMA_NAME, 1) - writeValue(schema.valueType, value) - writer.endField(CatalystConverter.MAP_VALUE_SCHEMA_NAME, 1) - } - writer.endGroup() - }) - writer.endField(CatalystConverter.MAP_SCHEMA_NAME, 0) - } - writer.endGroup() - } - - // Scratch array used to write decimals as fixed-length byte array - private[this] var reusableDecimalBytes = new Array[Byte](16) - - private[parquet] def writeDecimal(decimal: Decimal, precision: Int): Unit = { - val numBytes = CatalystSchemaConverter.minBytesForPrecision(precision) - - def longToBinary(unscaled: Long): Binary = { - var i = 0 - var shift = 8 * (numBytes - 1) - while (i < numBytes) { - reusableDecimalBytes(i) = (unscaled >> shift).toByte - i += 1 - shift -= 8 - } - Binary.fromByteArray(reusableDecimalBytes, 0, numBytes) - } - - def bigIntegerToBinary(unscaled: BigInteger): Binary = { - unscaled.toByteArray match { - case bytes if bytes.length == numBytes => - Binary.fromByteArray(bytes) - - case bytes if bytes.length <= reusableDecimalBytes.length => - val signedByte = (if (bytes.head < 0) -1 else 0).toByte - java.util.Arrays.fill(reusableDecimalBytes, 0, numBytes - bytes.length, signedByte) - System.arraycopy(bytes, 0, reusableDecimalBytes, numBytes - bytes.length, bytes.length) - Binary.fromByteArray(reusableDecimalBytes, 0, numBytes) - - case bytes => - reusableDecimalBytes = new Array[Byte](bytes.length) - bigIntegerToBinary(unscaled) - } - } - - val binary = if (numBytes <= 8) { - longToBinary(decimal.toUnscaledLong) - } else { - bigIntegerToBinary(decimal.toJavaBigDecimal.unscaledValue()) - } - - writer.addBinary(binary) - } - - // array used to write Timestamp as Int96 (fixed-length binary) - private[this] val int96buf = new Array[Byte](12) - - private[parquet] def writeTimestamp(ts: Long): Unit = { - val (julianDay, timeOfDayNanos) = DateTimeUtils.toJulianDay(ts) - val buf = ByteBuffer.wrap(int96buf) - buf.order(ByteOrder.LITTLE_ENDIAN) - buf.putLong(timeOfDayNanos) - buf.putInt(julianDay) - writer.addBinary(Binary.fromByteArray(int96buf)) - } -} - -// Optimized for non-nested rows -private[parquet] class MutableRowWriteSupport extends RowWriteSupport { - override def write(record: InternalRow): Unit = { - val attributesSize = attributes.size - if (attributesSize > record.numFields) { - throw new IndexOutOfBoundsException("Trying to write more fields than contained in row " + - s"($attributesSize > ${record.numFields})") - } - - var index = 0 - writer.startMessage() - while(index < attributesSize) { - // null values indicate optional fields but we do not check currently - if (!record.isNullAt(index) && !record.isNullAt(index)) { - writer.startField(attributes(index).name, index) - consumeType(attributes(index).dataType, record, index) - writer.endField(attributes(index).name, index) - } - index = index + 1 - } - writer.endMessage() - } - - private def consumeType( - ctype: DataType, - record: InternalRow, - index: Int): Unit = { - ctype match { - case BooleanType => writer.addBoolean(record.getBoolean(index)) - case ByteType => writer.addInteger(record.getByte(index)) - case ShortType => writer.addInteger(record.getShort(index)) - case IntegerType | DateType => writer.addInteger(record.getInt(index)) - case LongType => writer.addLong(record.getLong(index)) - case TimestampType => writeTimestamp(record.getLong(index)) - case FloatType => writer.addFloat(record.getFloat(index)) - case DoubleType => writer.addDouble(record.getDouble(index)) - case StringType => - writer.addBinary(Binary.fromByteArray(record.getUTF8String(index).getBytes)) - case BinaryType => - writer.addBinary(Binary.fromByteArray(record.getBinary(index))) - case DecimalType.Fixed(precision, scale) => - writeDecimal(record.getDecimal(index, precision, scale), precision) - case _ => sys.error(s"Unsupported datatype $ctype, cannot write to consumer") - } - } -} - -private[parquet] object RowWriteSupport { - val SPARK_ROW_SCHEMA: String = "org.apache.spark.sql.parquet.row.attributes" - - def getSchema(configuration: Configuration): Seq[Attribute] = { - val schemaString = configuration.get(RowWriteSupport.SPARK_ROW_SCHEMA) - if (schemaString == null) { - throw new RuntimeException("Missing schema!") - } - ParquetTypesConverter.convertFromString(schemaString) - } - - def setSchema(schema: Seq[Attribute], configuration: Configuration) { - val encoded = ParquetTypesConverter.convertToString(schema) - configuration.set(SPARK_ROW_SCHEMA, encoded) - configuration.set( - ParquetOutputFormat.WRITER_VERSION, - ParquetProperties.WriterVersion.PARQUET_1_0.toString) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypesConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypesConverter.scala deleted file mode 100644 index b647bb6116afa..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypesConverter.scala +++ /dev/null @@ -1,160 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources.parquet - -import java.io.IOException -import java.util.{Collections, Arrays} - -import scala.util.Try - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.hadoop.mapreduce.Job -import org.apache.parquet.format.converter.ParquetMetadataConverter -import org.apache.parquet.hadoop.metadata.{FileMetaData, ParquetMetadata} -import org.apache.parquet.hadoop.util.ContextUtil -import org.apache.parquet.hadoop.{Footer, ParquetFileReader, ParquetFileWriter} -import org.apache.parquet.schema.MessageType - -import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.types._ - - -private[parquet] object ParquetTypesConverter extends Logging { - def isPrimitiveType(ctype: DataType): Boolean = ctype match { - case _: NumericType | BooleanType | DateType | TimestampType | StringType | BinaryType => true - case _ => false - } - - /** - * Compute the FIXED_LEN_BYTE_ARRAY length needed to represent a given DECIMAL precision. - */ - private[parquet] val BYTES_FOR_PRECISION = Array.tabulate[Int](38) { precision => - var length = 1 - while (math.pow(2.0, 8 * length - 1) < math.pow(10.0, precision)) { - length += 1 - } - length - } - - def convertFromAttributes(attributes: Seq[Attribute]): MessageType = { - val converter = new CatalystSchemaConverter() - converter.convert(StructType.fromAttributes(attributes)) - } - - def convertFromString(string: String): Seq[Attribute] = { - Try(DataType.fromJson(string)).getOrElse(DataType.fromCaseClassString(string)) match { - case s: StructType => s.toAttributes - case other => sys.error(s"Can convert $string to row") - } - } - - def convertToString(schema: Seq[Attribute]): String = { - schema.map(_.name).foreach(CatalystSchemaConverter.checkFieldName) - StructType.fromAttributes(schema).json - } - - def writeMetaData(attributes: Seq[Attribute], origPath: Path, conf: Configuration): Unit = { - if (origPath == null) { - throw new IllegalArgumentException("Unable to write Parquet metadata: path is null") - } - val fs = origPath.getFileSystem(conf) - if (fs == null) { - throw new IllegalArgumentException( - s"Unable to write Parquet metadata: path $origPath is incorrectly formatted") - } - val path = origPath.makeQualified(fs) - if (fs.exists(path) && !fs.getFileStatus(path).isDir) { - throw new IllegalArgumentException(s"Expected to write to directory $path but found file") - } - val metadataPath = new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE) - if (fs.exists(metadataPath)) { - try { - fs.delete(metadataPath, true) - } catch { - case e: IOException => - throw new IOException(s"Unable to delete previous PARQUET_METADATA_FILE at $metadataPath") - } - } - val extraMetadata = new java.util.HashMap[String, String]() - extraMetadata.put( - CatalystReadSupport.SPARK_METADATA_KEY, - ParquetTypesConverter.convertToString(attributes)) - // TODO: add extra data, e.g., table name, date, etc.? - - val parquetSchema: MessageType = ParquetTypesConverter.convertFromAttributes(attributes) - val metaData: FileMetaData = new FileMetaData( - parquetSchema, - extraMetadata, - "Spark") - - ParquetFileWriter.writeMetadataFile( - conf, - path, - Arrays.asList(new Footer(path, new ParquetMetadata(metaData, Collections.emptyList())))) - } - - /** - * Try to read Parquet metadata at the given Path. We first see if there is a summary file - * in the parent directory. If so, this is used. Else we read the actual footer at the given - * location. - * @param origPath The path at which we expect one (or more) Parquet files. - * @param configuration The Hadoop configuration to use. - * @return The `ParquetMetadata` containing among other things the schema. - */ - def readMetaData(origPath: Path, configuration: Option[Configuration]): ParquetMetadata = { - if (origPath == null) { - throw new IllegalArgumentException("Unable to read Parquet metadata: path is null") - } - val job = new Job() - val conf = { - // scalastyle:off jobcontext - configuration.getOrElse(ContextUtil.getConfiguration(job)) - // scalastyle:on jobcontext - } - val fs: FileSystem = origPath.getFileSystem(conf) - if (fs == null) { - throw new IllegalArgumentException(s"Incorrectly formatted Parquet metadata path $origPath") - } - val path = origPath.makeQualified(fs) - - val children = - fs - .globStatus(path) - .flatMap { status => if (status.isDir) fs.listStatus(status.getPath) else List(status) } - .filterNot { status => - val name = status.getPath.getName - (name(0) == '.' || name(0) == '_') && name != ParquetFileWriter.PARQUET_METADATA_FILE - } - - // NOTE (lian): Parquet "_metadata" file can be very slow if the file consists of lots of row - // groups. Since Parquet schema is replicated among all row groups, we only need to touch a - // single row group to read schema related metadata. Notice that we are making assumptions that - // all data in a single Parquet file have the same schema, which is normally true. - children - // Try any non-"_metadata" file first... - .find(_.getPath.getName != ParquetFileWriter.PARQUET_METADATA_FILE) - // ... and fallback to "_metadata" if no such file exists (which implies the Parquet file is - // empty, thus normally the "_metadata" file is expected to be fairly small). - .orElse(children.find(_.getPath.getName == ParquetFileWriter.PARQUET_METADATA_FILE)) - .map(ParquetFileReader.readFooter(conf, _, ParquetMetadataConverter.NO_FILTER)) - .getOrElse( - throw new IllegalArgumentException(s"Could not find Parquet metadata at path $path")) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 7992fd59ff4ba..d17671d48a2fc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -24,6 +24,7 @@ import com.clearspring.analytics.stream.cardinality.HyperLogLog import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT} +import org.apache.spark.sql.execution.datasources.parquet.ParquetTest import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -68,7 +69,7 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { private[spark] override def asNullable: MyDenseVectorUDT = this } -class UserDefinedTypeSuite extends QueryTest with SharedSQLContext { +class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetTest { import testImplicits._ private lazy val pointsRDD = Seq( @@ -98,17 +99,28 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext { Seq(Row(true), Row(true))) } - - test("UDTs with Parquet") { - val tempDir = Utils.createTempDir() - tempDir.delete() - pointsRDD.write.parquet(tempDir.getCanonicalPath) + testStandardAndLegacyModes("UDTs with Parquet") { + withTempPath { dir => + val path = dir.getCanonicalPath + pointsRDD.write.parquet(path) + checkAnswer( + sqlContext.read.parquet(path), + Seq( + Row(1.0, new MyDenseVector(Array(0.1, 1.0))), + Row(0.0, new MyDenseVector(Array(0.2, 2.0))))) + } } - test("Repartition UDTs with Parquet") { - val tempDir = Utils.createTempDir() - tempDir.delete() - pointsRDD.repartition(1).write.parquet(tempDir.getCanonicalPath) + testStandardAndLegacyModes("Repartition UDTs with Parquet") { + withTempPath { dir => + val path = dir.getCanonicalPath + pointsRDD.repartition(1).write.parquet(path) + checkAnswer( + sqlContext.read.parquet(path), + Seq( + Row(1.0, new MyDenseVector(Array(0.1, 1.0))), + Row(0.0, new MyDenseVector(Array(0.2, 2.0))))) + } } // Tests to make sure that all operators correctly convert types on the way out. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index cd552e83372f1..599cf948e76a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -28,10 +28,10 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} import org.apache.parquet.example.data.simple.SimpleGroup import org.apache.parquet.example.data.{Group, GroupWriter} +import org.apache.parquet.hadoop._ import org.apache.parquet.hadoop.api.WriteSupport import org.apache.parquet.hadoop.api.WriteSupport.WriteContext -import org.apache.parquet.hadoop.metadata.{BlockMetaData, CompressionCodecName, FileMetaData, ParquetMetadata} -import org.apache.parquet.hadoop.{Footer, ParquetFileWriter, ParquetOutputCommitter, ParquetWriter} +import org.apache.parquet.hadoop.metadata.{CompressionCodecName, FileMetaData, ParquetMetadata} import org.apache.parquet.io.api.RecordConsumer import org.apache.parquet.schema.{MessageType, MessageTypeParser} @@ -99,16 +99,18 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING.key -> "true")(checkParquetFile(data)) } - test("fixed-length decimals") { - def makeDecimalRDD(decimal: DecimalType): DataFrame = - sparkContext - .parallelize(0 to 1000) - .map(i => Tuple1(i / 100.0)) - .toDF() - // Parquet doesn't allow column names with spaces, have to add an alias here - .select($"_1" cast decimal as "dec") + testStandardAndLegacyModes("fixed-length decimals") { + def makeDecimalRDD(decimal: DecimalType): DataFrame = { + sqlContext + .range(1000) + // Parquet doesn't allow column names with spaces, have to add an alias here. + // Minus 500 here so that negative decimals are also tested. + .select((('id - 500) / 100.0) cast decimal as 'dec) + .coalesce(1) + } - for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17), (19, 0), (38, 37))) { + val combinations = Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17), (19, 0), (38, 37)) + for ((precision, scale) <- combinations) { withTempPath { dir => val data = makeDecimalRDD(DecimalType(precision, scale)) data.write.parquet(dir.getCanonicalPath) @@ -132,22 +134,22 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } } - test("map") { + testStandardAndLegacyModes("map") { val data = (1 to 4).map(i => Tuple1(Map(i -> s"val_$i"))) checkParquetFile(data) } - test("array") { + testStandardAndLegacyModes("array") { val data = (1 to 4).map(i => Tuple1(Seq(i, i + 1))) checkParquetFile(data) } - test("array and double") { + testStandardAndLegacyModes("array and double") { val data = (1 to 4).map(i => (i.toDouble, Seq(i.toDouble, (i + 1).toDouble))) checkParquetFile(data) } - test("struct") { + testStandardAndLegacyModes("struct") { val data = (1 to 4).map(i => Tuple1((i, s"val_$i"))) withParquetDataFrame(data) { df => // Structs are converted to `Row`s @@ -157,7 +159,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } } - test("nested struct with array of array as field") { + testStandardAndLegacyModes("nested struct with array of array as field") { val data = (1 to 4).map(i => Tuple1((i, Seq(Seq(s"val_$i"))))) withParquetDataFrame(data) { df => // Structs are converted to `Row`s @@ -167,7 +169,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } } - test("nested map with struct as value type") { + testStandardAndLegacyModes("nested map with struct as value type") { val data = (1 to 4).map(i => Tuple1(Map(i -> (i, s"val_$i")))) withParquetDataFrame(data) { df => checkAnswer(df, data.map { case Tuple1(m) => @@ -205,14 +207,14 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } test("compression codec") { - def compressionCodecFor(path: String): String = { - val codecs = ParquetTypesConverter - .readMetaData(new Path(path), Some(hadoopConfiguration)).getBlocks.asScala - .flatMap(_.getColumns.asScala) - .map(_.getCodec.name()) - .distinct - - assert(codecs.size === 1) + def compressionCodecFor(path: String, codecName: String): String = { + val codecs = for { + footer <- readAllFootersWithoutSummaryFiles(new Path(path), hadoopConfiguration) + block <- footer.getParquetMetadata.getBlocks.asScala + column <- block.getColumns.asScala + } yield column.getCodec.name() + + assert(codecs.distinct === Seq(codecName)) codecs.head } @@ -222,7 +224,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> codec.name()) { withParquetFile(data) { path => assertResult(sqlContext.conf.parquetCompressionCodec.toUpperCase) { - compressionCodecFor(path) + compressionCodecFor(path, codec.name()) } } } @@ -278,15 +280,14 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { withTempPath { file => val path = new Path(file.toURI.toString) val fs = FileSystem.getLocal(hadoopConfiguration) - val attributes = ScalaReflection.attributesFor[(Int, String)] - ParquetTypesConverter.writeMetaData(attributes, path, hadoopConfiguration) + val schema = StructType.fromAttributes(ScalaReflection.attributesFor[(Int, String)]) + writeMetadata(schema, path, hadoopConfiguration) assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_COMMON_METADATA_FILE))) assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE))) - val metaData = ParquetTypesConverter.readMetaData(path, Some(hadoopConfiguration)) - val actualSchema = metaData.getFileMetaData.getSchema - val expectedSchema = ParquetTypesConverter.convertFromAttributes(attributes) + val expectedSchema = new CatalystSchemaConverter().convert(schema) + val actualSchema = readFooter(path, hadoopConfiguration).getFileMetaData.getSchema actualSchema.checkContains(expectedSchema) expectedSchema.checkContains(actualSchema) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 1c1cfa34ad04b..cc02ef81c9f8b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -484,7 +484,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext } } - test("SPARK-10301 requested schema clipping - UDT") { + testStandardAndLegacyModes("SPARK-10301 requested schema clipping - UDT") { withTempPath { dir => val path = dir.getCanonicalPath @@ -517,6 +517,50 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext Row(Row(NestedStruct(1, 2L, 3.5D)))) } } + + test("expand UDT in StructType") { + val schema = new StructType().add("n", new NestedStructUDT, nullable = true) + val expected = new StructType().add("n", new NestedStructUDT().sqlType, nullable = true) + assert(CatalystReadSupport.expandUDT(schema) === expected) + } + + test("expand UDT in ArrayType") { + val schema = new StructType().add( + "n", + ArrayType( + elementType = new NestedStructUDT, + containsNull = false), + nullable = true) + + val expected = new StructType().add( + "n", + ArrayType( + elementType = new NestedStructUDT().sqlType, + containsNull = false), + nullable = true) + + assert(CatalystReadSupport.expandUDT(schema) === expected) + } + + test("expand UDT in MapType") { + val schema = new StructType().add( + "n", + MapType( + keyType = IntegerType, + valueType = new NestedStructUDT, + valueContainsNull = false), + nullable = true) + + val expected = new StructType().add( + "n", + MapType( + keyType = IntegerType, + valueType = new NestedStructUDT().sqlType, + valueContainsNull = false), + nullable = true) + + assert(CatalystReadSupport.expandUDT(schema) === expected) + } } object TestingUDT { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index f17fb36f25fe8..60fa81b1ab819 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -357,8 +357,8 @@ class ParquetSchemaSuite extends ParquetSchemaTest { val jsonString = """{"type":"struct","fields":[{"name":"c1","type":"integer","nullable":false,"metadata":{}},{"name":"c2","type":"binary","nullable":true,"metadata":{}}]}""" // scalastyle:on - val fromCaseClassString = ParquetTypesConverter.convertFromString(caseClassString) - val fromJson = ParquetTypesConverter.convertFromString(jsonString) + val fromCaseClassString = StructType.fromString(caseClassString) + val fromJson = StructType.fromString(jsonString) (fromCaseClassString, fromJson).zipped.foreach { (a, b) => assert(a.name == b.name) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala index 442fafb12f200..9840ad919e510 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala @@ -19,11 +19,19 @@ package org.apache.spark.sql.execution.datasources.parquet import java.io.File +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.parquet.format.converter.ParquetMetadataConverter +import org.apache.parquet.hadoop.metadata.{BlockMetaData, FileMetaData, ParquetMetadata} +import org.apache.parquet.hadoop.{Footer, ParquetFileReader, ParquetFileWriter} + import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.{DataFrame, SaveMode, SQLContext} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.{DataFrame, SQLConf, SaveMode} /** * A helper trait that provides convenient facilities for Parquet testing. @@ -97,4 +105,38 @@ private[sql] trait ParquetTest extends SQLTestUtils { assert(partDir.mkdirs(), s"Couldn't create directory $partDir") partDir } + + protected def writeMetadata( + schema: StructType, path: Path, configuration: Configuration): Unit = { + val parquetSchema = new CatalystSchemaConverter().convert(schema) + val extraMetadata = Map(CatalystReadSupport.SPARK_METADATA_KEY -> schema.json).asJava + val createdBy = s"Apache Spark ${org.apache.spark.SPARK_VERSION}" + val fileMetadata = new FileMetaData(parquetSchema, extraMetadata, createdBy) + val parquetMetadata = new ParquetMetadata(fileMetadata, Seq.empty[BlockMetaData].asJava) + val footer = new Footer(path, parquetMetadata) + ParquetFileWriter.writeMetadataFile(configuration, path, Seq(footer).asJava) + } + + protected def readAllFootersWithoutSummaryFiles( + path: Path, configuration: Configuration): Seq[Footer] = { + val fs = path.getFileSystem(configuration) + ParquetFileReader.readAllFootersInParallel(configuration, fs.getFileStatus(path)).asScala.toSeq + } + + protected def readFooter(path: Path, configuration: Configuration): ParquetMetadata = { + ParquetFileReader.readFooter( + configuration, + new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE), + ParquetMetadataConverter.NO_FILTER) + } + + protected def testStandardAndLegacyModes(testName: String)(f: => Unit): Unit = { + test(s"Standard mode - $testName") { + withSQLConf(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "false") { f } + } + + test(s"Legacy mode - $testName") { + withSQLConf(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "true") { f } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index 107457f79ec03..d63f3d3996523 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -20,11 +20,11 @@ package org.apache.spark.sql.hive import java.io.File import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.{QueryTest, Row, SaveMode} import org.apache.spark.sql.hive.client.{ExternalTable, ManagedTable} import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.{ExamplePointUDT, SQLTestUtils} import org.apache.spark.sql.types.{DecimalType, StringType, StructType} +import org.apache.spark.sql.{SQLConf, QueryTest, Row, SaveMode} class HiveMetastoreCatalogSuite extends SparkFunSuite with TestHiveSingleton { import hiveContext.implicits._ @@ -74,11 +74,13 @@ class DataSourceWithHiveMetastoreCatalogSuite ).foreach { case (provider, (inputFormat, outputFormat, serde)) => test(s"Persist non-partitioned $provider relation into metastore as managed table") { withTable("t") { - testDF - .write - .mode(SaveMode.Overwrite) - .format(provider) - .saveAsTable("t") + withSQLConf(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "true") { + testDF + .write + .mode(SaveMode.Overwrite) + .format(provider) + .saveAsTable("t") + } val hiveTable = catalog.client.getTable("default", "t") assert(hiveTable.inputFormat === Some(inputFormat)) @@ -102,12 +104,14 @@ class DataSourceWithHiveMetastoreCatalogSuite withTable("t") { val path = dir.getCanonicalFile - testDF - .write - .mode(SaveMode.Overwrite) - .format(provider) - .option("path", path.toString) - .saveAsTable("t") + withSQLConf(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "true") { + testDF + .write + .mode(SaveMode.Overwrite) + .format(provider) + .option("path", path.toString) + .saveAsTable("t") + } val hiveTable = catalog.client.getTable("default", "t") assert(hiveTable.inputFormat === Some(inputFormat)) From 84ea287178247c163226e835490c9c70b17d8d3b Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 8 Oct 2015 17:25:14 -0700 Subject: [PATCH 006/139] [SPARK-10914] UnsafeRow serialization breaks when two machines have different Oops size. UnsafeRow contains 3 pieces of information when pointing to some data in memory (an object, a base offset, and length). When the row is serialized with Java/Kryo serialization, the object layout in memory can change if two machines have different pointer width (Oops in JVM). To reproduce, launch Spark using MASTER=local-cluster[2,1,1024] bin/spark-shell --conf "spark.executor.extraJavaOptions=-XX:-UseCompressedOops" And then run the following scala> sql("select 1 xx").collect() Author: Reynold Xin Closes #9030 from rxin/SPARK-10914. --- .../sql/catalyst/expressions/UnsafeRow.java | 47 +++++++++++++++++-- .../org/apache/spark/sql/UnsafeRowSuite.scala | 29 +++++++++++- 2 files changed, 72 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index e8ac2999c2d29..5af7ed5d6eb6d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -17,8 +17,7 @@ package org.apache.spark.sql.catalyst.expressions; -import java.io.IOException; -import java.io.OutputStream; +import java.io.*; import java.math.BigDecimal; import java.math.BigInteger; import java.util.Arrays; @@ -26,6 +25,11 @@ import java.util.HashSet; import java.util.Set; +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.KryoSerializable; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; + import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; @@ -35,6 +39,7 @@ import org.apache.spark.unsafe.types.UTF8String; import static org.apache.spark.sql.types.DataTypes.*; +import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET; /** * An Unsafe implementation of Row which is backed by raw memory instead of Java objects. @@ -52,7 +57,7 @@ * * Instances of `UnsafeRow` act as pointers to row data stored in this format. */ -public final class UnsafeRow extends MutableRow { +public final class UnsafeRow extends MutableRow implements Externalizable, KryoSerializable { ////////////////////////////////////////////////////////////////////////////// // Static methods @@ -596,4 +601,40 @@ public boolean anyNull() { public void writeToMemory(Object target, long targetOffset) { Platform.copyMemory(baseObject, baseOffset, target, targetOffset, sizeInBytes); } + + @Override + public void writeExternal(ObjectOutput out) throws IOException { + byte[] bytes = getBytes(); + out.writeInt(bytes.length); + out.writeInt(this.numFields); + out.write(bytes); + } + + @Override + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + this.baseOffset = BYTE_ARRAY_OFFSET; + this.sizeInBytes = in.readInt(); + this.numFields = in.readInt(); + this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields); + this.baseObject = new byte[sizeInBytes]; + in.readFully((byte[]) baseObject); + } + + @Override + public void write(Kryo kryo, Output out) { + byte[] bytes = getBytes(); + out.writeInt(bytes.length); + out.writeInt(this.numFields); + out.write(bytes); + } + + @Override + public void read(Kryo kryo, Input in) { + this.baseOffset = BYTE_ARRAY_OFFSET; + this.sizeInBytes = in.readInt(); + this.numFields = in.readInt(); + this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields); + this.baseObject = new byte[sizeInBytes]; + in.read((byte[]) baseObject); + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala index 944d4e11348cf..7d1ee39d4b539 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala @@ -19,7 +19,8 @@ package org.apache.spark.sql import java.io.ByteArrayOutputStream -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.serializer.{KryoSerializer, JavaSerializer} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} import org.apache.spark.sql.types._ @@ -29,6 +30,32 @@ import org.apache.spark.unsafe.types.UTF8String class UnsafeRowSuite extends SparkFunSuite { + test("UnsafeRow Java serialization") { + // serializing an UnsafeRow pointing to a large buffer should only serialize the relevant data + val data = new Array[Byte](1024) + val row = new UnsafeRow + row.pointTo(data, 1, 16) + row.setLong(0, 19285) + + val ser = new JavaSerializer(new SparkConf).newInstance() + val row1 = ser.deserialize[UnsafeRow](ser.serialize(row)) + assert(row1.getLong(0) == 19285) + assert(row1.getBaseObject().asInstanceOf[Array[Byte]].length == 16) + } + + test("UnsafeRow Kryo serialization") { + // serializing an UnsafeRow pointing to a large buffer should only serialize the relevant data + val data = new Array[Byte](1024) + val row = new UnsafeRow + row.pointTo(data, 1, 16) + row.setLong(0, 19285) + + val ser = new KryoSerializer(new SparkConf).newInstance() + val row1 = ser.deserialize[UnsafeRow](ser.serialize(row)) + assert(row1.getLong(0) == 19285) + assert(row1.getBaseObject().asInstanceOf[Array[Byte]].length == 16) + } + test("bitset width calculation") { assert(UnsafeRow.calculateBitSetWidthInBytes(0) === 0) assert(UnsafeRow.calculateBitSetWidthInBytes(1) === 8) From 3390b400d04e40f767d8a51f1078fcccb4e64abd Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 8 Oct 2015 17:34:24 -0700 Subject: [PATCH 007/139] [SPARK-10810] [SPARK-10902] [SQL] Improve session management in SQL This PR improve the sessions management by replacing the thread-local based to one SQLContext per session approach, introduce separated temporary tables and UDFs/UDAFs for each session. A new session of SQLContext could be created by: 1) create an new SQLContext 2) call newSession() on existing SQLContext For HiveContext, in order to reduce the cost for each session, the classloader and Hive client are shared across multiple sessions (created by newSession). CacheManager is also shared by multiple sessions, so cache a table multiple times in different sessions will not cause multiple copies of in-memory cache. Added jars are still shared by all the sessions, because SparkContext does not support sessions. cc marmbrus yhuai rxin Author: Davies Liu Closes #8909 from davies/sessions. --- project/MimaExcludes.scala | 22 ++- .../catalyst/analysis/FunctionRegistry.scala | 28 ++- .../org/apache/spark/sql/SQLContext.scala | 164 ++++++++++-------- .../spark/sql/execution/CacheManager.scala | 14 +- .../apache/spark/sql/SQLContextSuite.scala | 59 ++++--- .../spark/sql/test/TestSQLContext.scala | 21 +-- .../SparkExecuteStatementOperation.scala | 76 ++------ .../thriftserver/SparkSQLSessionManager.scala | 9 +- .../server/SparkSQLOperationManager.scala | 5 +- .../sql/hive/thriftserver/CliSuite.scala | 8 +- .../HiveThriftServer2Suites.scala | 76 ++++---- .../apache/spark/sql/hive/HiveContext.scala | 155 +++++++++++------ .../org/apache/spark/sql/hive/HiveQl.scala | 28 +-- .../sql/hive/client/ClientInterface.scala | 9 + .../spark/sql/hive/client/ClientWrapper.scala | 85 +++++---- .../hive/client/IsolatedClientLoader.scala | 107 +++++++----- .../spark/sql/hive/execution/commands.scala | 27 +-- .../apache/spark/sql/hive/test/TestHive.scala | 27 +-- .../apache/spark/sql/hive/HiveQlSuite.scala | 13 +- .../spark/sql/hive/client/VersionsSuite.scala | 6 +- .../sql/hive/execution/HiveQuerySuite.scala | 32 ++++ .../sql/hive/execution/SQLQuerySuite.scala | 9 +- 22 files changed, 540 insertions(+), 440 deletions(-) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 2d4d146f51339..08e4a449cf762 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -79,7 +79,27 @@ object MimaExcludes { ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.ml.regression.LeastSquaresAggregator.add"), ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.ml.regression.LeastSquaresCostFun.this") + "org.apache.spark.ml.regression.LeastSquaresCostFun.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.SQLContext.clearLastInstantiatedContext"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.SQLContext.setLastInstantiatedContext"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.SQLContext$SQLSession"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.SQLContext.detachSession"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.SQLContext.tlSession"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.SQLContext.defaultSession"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.SQLContext.currentSession"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.SQLContext.openSession"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.SQLContext.setSession"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.SQLContext.createSession") ) case v if v.startsWith("1.5") => Seq( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index e6122d92b763c..ba77b70a378a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -51,23 +51,37 @@ class SimpleFunctionRegistry extends FunctionRegistry { private val functionBuilders = StringKeyHashMap[(ExpressionInfo, FunctionBuilder)](caseSensitive = false) - override def registerFunction(name: String, info: ExpressionInfo, builder: FunctionBuilder) - : Unit = { + override def registerFunction( + name: String, + info: ExpressionInfo, + builder: FunctionBuilder): Unit = synchronized { functionBuilders.put(name, (info, builder)) } override def lookupFunction(name: String, children: Seq[Expression]): Expression = { - val func = functionBuilders.get(name).map(_._2).getOrElse { - throw new AnalysisException(s"undefined function $name") + val func = synchronized { + functionBuilders.get(name).map(_._2).getOrElse { + throw new AnalysisException(s"undefined function $name") + } } func(children) } - override def listFunction(): Seq[String] = functionBuilders.iterator.map(_._1).toList.sorted + override def listFunction(): Seq[String] = synchronized { + functionBuilders.iterator.map(_._1).toList.sorted + } - override def lookupFunction(name: String): Option[ExpressionInfo] = { + override def lookupFunction(name: String): Option[ExpressionInfo] = synchronized { functionBuilders.get(name).map(_._1) } + + def copy(): SimpleFunctionRegistry = synchronized { + val registry = new SimpleFunctionRegistry + functionBuilders.iterator.foreach { case (name, (info, builder)) => + registry.registerFunction(name, info, builder) + } + registry + } } /** @@ -257,7 +271,7 @@ object FunctionRegistry { expression[InputFileName]("input_file_name") ) - val builtin: FunctionRegistry = { + val builtin: SimpleFunctionRegistry = { val fr = new SimpleFunctionRegistry expressions.foreach { case (name, (info, builder)) => fr.registerFunction(name, info, builder) } fr diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index cb0a3e361c97a..2bdfd82af0adb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -30,6 +30,7 @@ import org.apache.spark.SparkContext import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.rdd.RDD +import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} import org.apache.spark.sql.SQLConf.SQLConfEntry import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.errors.DialectException @@ -38,15 +39,12 @@ import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _} -import org.apache.spark.sql.execution.{Filter, _} -import org.apache.spark.sql.{execution => sparkexecution} -import org.apache.spark.sql.execution._ -import org.apache.spark.sql.sources._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab} import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types._ +import org.apache.spark.sql.{execution => sparkexecution} import org.apache.spark.util.Utils /** @@ -64,18 +62,30 @@ import org.apache.spark.util.Utils * * @since 1.0.0 */ -class SQLContext(@transient val sparkContext: SparkContext) - extends org.apache.spark.Logging - with Serializable { +class SQLContext private[sql]( + @transient val sparkContext: SparkContext, + @transient protected[sql] val cacheManager: CacheManager) + extends org.apache.spark.Logging with Serializable { self => + def this(sparkContext: SparkContext) = this(sparkContext, new CacheManager) def this(sparkContext: JavaSparkContext) = this(sparkContext.sc) + /** + * Returns a SQLContext as new session, with separated SQL configurations, temporary tables, + * registered functions, but sharing the same SparkContext and CacheManager. + * + * @since 1.6.0 + */ + def newSession(): SQLContext = { + new SQLContext(sparkContext, cacheManager) + } + /** * @return Spark SQL configuration */ - protected[sql] def conf = currentSession().conf + protected[sql] lazy val conf = new SQLConf // `listener` should be only used in the driver @transient private[sql] val listener = new SQLListener(this) @@ -142,13 +152,11 @@ class SQLContext(@transient val sparkContext: SparkContext) */ def getAllConfs: immutable.Map[String, String] = conf.getAllConfs - // TODO how to handle the temp table per user session? @transient protected[sql] lazy val catalog: Catalog = new SimpleCatalog(conf) - // TODO how to handle the temp function per user session? @transient - protected[sql] lazy val functionRegistry: FunctionRegistry = FunctionRegistry.builtin + protected[sql] lazy val functionRegistry: FunctionRegistry = FunctionRegistry.builtin.copy() @transient protected[sql] lazy val analyzer: Analyzer = @@ -198,20 +206,19 @@ class SQLContext(@transient val sparkContext: SparkContext) protected[sql] def executePlan(plan: LogicalPlan) = new sparkexecution.QueryExecution(this, plan) - @transient - protected[sql] val tlSession = new ThreadLocal[SQLSession]() { - override def initialValue: SQLSession = defaultSession - } - - @transient - protected[sql] val defaultSession = createSession() - protected[sql] def dialectClassName = if (conf.dialect == "sql") { classOf[DefaultParserDialect].getCanonicalName } else { conf.dialect } + /** + * Add a jar to SQLContext + */ + protected[sql] def addJar(path: String): Unit = { + sparkContext.addJar(path) + } + { // We extract spark sql settings from SparkContext's conf and put them to // Spark SQL's conf. @@ -236,9 +243,6 @@ class SQLContext(@transient val sparkContext: SparkContext) } } - @transient - protected[sql] val cacheManager = new CacheManager(this) - /** * :: Experimental :: * A collection of methods that are considered experimental, but can be used to hook into @@ -300,21 +304,25 @@ class SQLContext(@transient val sparkContext: SparkContext) * @group cachemgmt * @since 1.3.0 */ - def isCached(tableName: String): Boolean = cacheManager.isCached(tableName) + def isCached(tableName: String): Boolean = { + cacheManager.lookupCachedData(table(tableName)).nonEmpty + } /** * Caches the specified table in-memory. * @group cachemgmt * @since 1.3.0 */ - def cacheTable(tableName: String): Unit = cacheManager.cacheTable(tableName) + def cacheTable(tableName: String): Unit = { + cacheManager.cacheQuery(table(tableName), Some(tableName)) + } /** * Removes the specified table from the in-memory cache. * @group cachemgmt * @since 1.3.0 */ - def uncacheTable(tableName: String): Unit = cacheManager.uncacheTable(tableName) + def uncacheTable(tableName: String): Unit = cacheManager.uncacheQuery(table(tableName)) /** * Removes all cached tables from the in-memory cache. @@ -830,36 +838,6 @@ class SQLContext(@transient val sparkContext: SparkContext) ) } - protected[sql] def openSession(): SQLSession = { - detachSession() - val session = createSession() - tlSession.set(session) - - session - } - - protected[sql] def currentSession(): SQLSession = { - tlSession.get() - } - - protected[sql] def createSession(): SQLSession = { - new this.SQLSession() - } - - protected[sql] def detachSession(): Unit = { - tlSession.remove() - } - - protected[sql] def setSession(session: SQLSession): Unit = { - detachSession() - tlSession.set(session) - } - - protected[sql] class SQLSession { - // Note that this is a lazy val so we can override the default value in subclasses. - protected[sql] lazy val conf: SQLConf = new SQLConf - } - @deprecated("use org.apache.spark.sql.QueryExecution", "1.6.0") protected[sql] class QueryExecution(logical: LogicalPlan) extends sparkexecution.QueryExecution(this, logical) @@ -1196,46 +1174,90 @@ class SQLContext(@transient val sparkContext: SparkContext) // Register a succesfully instantiatd context to the singleton. This should be at the end of // the class definition so that the singleton is updated only if there is no exception in the // construction of the instance. - SQLContext.setLastInstantiatedContext(self) + sparkContext.addSparkListener(new SparkListener { + override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = { + SQLContext.clearInstantiatedContext(self) + } + }) + + SQLContext.setInstantiatedContext(self) } /** * This SQLContext object contains utility functions to create a singleton SQLContext instance, - * or to get the last created SQLContext instance. + * or to get the created SQLContext instance. + * + * It also provides utility functions to support preference for threads in multiple sessions + * scenario, setActive could set a SQLContext for current thread, which will be returned by + * getOrCreate instead of the global one. */ object SQLContext { - private val INSTANTIATION_LOCK = new Object() + /** + * The active SQLContext for the current thread. + */ + private val activeContext: InheritableThreadLocal[SQLContext] = + new InheritableThreadLocal[SQLContext] /** - * Reference to the last created SQLContext. + * Reference to the created SQLContext. */ - @transient private val lastInstantiatedContext = new AtomicReference[SQLContext]() + @transient private val instantiatedContext = new AtomicReference[SQLContext]() /** * Get the singleton SQLContext if it exists or create a new one using the given SparkContext. + * * This function can be used to create a singleton SQLContext object that can be shared across * the JVM. + * + * If there is an active SQLContext for current thread, it will be returned instead of the global + * one. + * + * @since 1.5.0 */ def getOrCreate(sparkContext: SparkContext): SQLContext = { - INSTANTIATION_LOCK.synchronized { - if (lastInstantiatedContext.get() == null) { + val ctx = activeContext.get() + if (ctx != null) { + return ctx + } + + synchronized { + val ctx = instantiatedContext.get() + if (ctx == null) { new SQLContext(sparkContext) + } else { + ctx } } - lastInstantiatedContext.get() } - private[sql] def clearLastInstantiatedContext(): Unit = { - INSTANTIATION_LOCK.synchronized { - lastInstantiatedContext.set(null) - } + private[sql] def clearInstantiatedContext(sqlContext: SQLContext): Unit = { + instantiatedContext.compareAndSet(sqlContext, null) } - private[sql] def setLastInstantiatedContext(sqlContext: SQLContext): Unit = { - INSTANTIATION_LOCK.synchronized { - lastInstantiatedContext.set(sqlContext) - } + private[sql] def setInstantiatedContext(sqlContext: SQLContext): Unit = { + instantiatedContext.compareAndSet(null, sqlContext) + } + + /** + * Changes the SQLContext that will be returned in this thread and its children when + * SQLContext.getOrCreate() is called. This can be used to ensure that a given thread receives + * a SQLContext with an isolated session, instead of the global (first created) context. + * + * @since 1.6.0 + */ + def setActive(sqlContext: SQLContext): Unit = { + activeContext.set(sqlContext) + } + + /** + * Clears the active SQLContext for current thread. Subsequent calls to getOrCreate will + * return the first created context instead of a thread-local override. + * + * @since 1.6.0 + */ + def clearActive(): Unit = { + activeContext.remove() } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index d3e5c378d037d..f85aeb1b02694 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -20,9 +20,9 @@ package org.apache.spark.sql.execution import java.util.concurrent.locks.ReentrantReadWriteLock import org.apache.spark.Logging +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.columnar.InMemoryRelation -import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.storage.StorageLevel import org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK @@ -37,7 +37,7 @@ private[sql] case class CachedData(plan: LogicalPlan, cachedRepresentation: InMe * * Internal to Spark SQL. */ -private[sql] class CacheManager(sqlContext: SQLContext) extends Logging { +private[sql] class CacheManager extends Logging { @transient private val cachedData = new scala.collection.mutable.ArrayBuffer[CachedData] @@ -45,15 +45,6 @@ private[sql] class CacheManager(sqlContext: SQLContext) extends Logging { @transient private val cacheLock = new ReentrantReadWriteLock - /** Returns true if the table is currently cached in-memory. */ - def isCached(tableName: String): Boolean = lookupCachedData(sqlContext.table(tableName)).nonEmpty - - /** Caches the specified table in-memory. */ - def cacheTable(tableName: String): Unit = cacheQuery(sqlContext.table(tableName), Some(tableName)) - - /** Removes the specified table from the in-memory cache. */ - def uncacheTable(tableName: String): Unit = uncacheQuery(sqlContext.table(tableName)) - /** Acquires a read lock on the cache for the duration of `f`. */ private def readLock[A](f: => A): A = { val lock = cacheLock.readLock() @@ -96,6 +87,7 @@ private[sql] class CacheManager(sqlContext: SQLContext) extends Logging { if (lookupCachedData(planToCache).nonEmpty) { logWarning("Asked to cache already cached data.") } else { + val sqlContext = query.sqlContext cachedData += CachedData( planToCache, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala index dd88ae3700ab9..1994dacfc4dfa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala @@ -17,33 +17,52 @@ package org.apache.spark.sql -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.{SharedSparkContext, SparkFunSuite} -class SQLContextSuite extends SparkFunSuite with SharedSQLContext { - - override def afterAll(): Unit = { - try { - SQLContext.setLastInstantiatedContext(sqlContext) - } finally { - super.afterAll() - } - } +class SQLContextSuite extends SparkFunSuite with SharedSparkContext{ test("getOrCreate instantiates SQLContext") { - SQLContext.clearLastInstantiatedContext() - val sqlContext = SQLContext.getOrCreate(sparkContext) + val sqlContext = SQLContext.getOrCreate(sc) assert(sqlContext != null, "SQLContext.getOrCreate returned null") - assert(SQLContext.getOrCreate(sparkContext).eq(sqlContext), + assert(SQLContext.getOrCreate(sc).eq(sqlContext), "SQLContext created by SQLContext.getOrCreate not returned by SQLContext.getOrCreate") } - test("getOrCreate gets last explicitly instantiated SQLContext") { - SQLContext.clearLastInstantiatedContext() - val sqlContext = new SQLContext(sparkContext) - assert(SQLContext.getOrCreate(sparkContext) != null, - "SQLContext.getOrCreate after explicitly created SQLContext returned null") - assert(SQLContext.getOrCreate(sparkContext).eq(sqlContext), + test("getOrCreate return the original SQLContext") { + val sqlContext = SQLContext.getOrCreate(sc) + val newSession = sqlContext.newSession() + assert(SQLContext.getOrCreate(sc).eq(sqlContext), "SQLContext.getOrCreate after explicitly created SQLContext did not return the context") + SQLContext.setActive(newSession) + assert(SQLContext.getOrCreate(sc).eq(newSession), + "SQLContext.getOrCreate after explicitly setActive() did not return the active context") + } + + test("Sessions of SQLContext") { + val sqlContext = SQLContext.getOrCreate(sc) + val session1 = sqlContext.newSession() + val session2 = sqlContext.newSession() + + // all have the default configurations + val key = SQLConf.SHUFFLE_PARTITIONS.key + assert(session1.getConf(key) === session2.getConf(key)) + session1.setConf(key, "1") + session2.setConf(key, "2") + assert(session1.getConf(key) === "1") + assert(session2.getConf(key) === "2") + + // temporary table should not be shared + val df = session1.range(10) + df.registerTempTable("test1") + assert(session1.tableNames().contains("test1")) + assert(!session2.tableNames().contains("test1")) + + // UDF should not be shared + def myadd(a: Int, b: Int): Int = a + b + session1.udf.register[Int, Int, Int]("myadd", myadd) + session1.sql("select myadd(1, 2)").explain() + intercept[AnalysisException] { + session2.sql("select myadd(1, 2)").explain() + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index 10e633f3cde46..c89a1516503e0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -31,23 +31,16 @@ private[sql] class TestSQLContext(sc: SparkContext) extends SQLContext(sc) { sel new SparkConf().set("spark.sql.testkey", "true"))) } - // Make sure we set those test specific confs correctly when we create - // the SQLConf as well as when we call clear. - protected[sql] override def createSession(): SQLSession = new this.SQLSession() + protected[sql] override lazy val conf: SQLConf = new SQLConf { - /** A special [[SQLSession]] that uses fewer shuffle partitions than normal. */ - protected[sql] class SQLSession extends super.SQLSession { - protected[sql] override lazy val conf: SQLConf = new SQLConf { + clear() - clear() + override def clear(): Unit = { + super.clear() - override def clear(): Unit = { - super.clear() - - // Make sure we start with the default test configs even after clear - TestSQLContext.overrideConfs.map { - case (key, value) => setConfString(key, value) - } + // Make sure we start with the default test configs even after clear + TestSQLContext.overrideConfs.map { + case (key, value) => setConfString(key, value) } } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index 306f98bcb5344..719b03e1c7c71 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -20,19 +20,15 @@ package org.apache.spark.sql.hive.thriftserver import java.security.PrivilegedExceptionAction import java.sql.{Date, Timestamp} import java.util.concurrent.RejectedExecutionException -import java.util.{Arrays, Map => JMap, UUID} +import java.util.{Arrays, UUID, Map => JMap} import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, Map => SMap} import scala.util.control.NonFatal -import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.metastore.api.FieldSchema -import org.apache.hive.service.cli._ -import org.apache.hadoop.hive.ql.metadata.Hive -import org.apache.hadoop.hive.ql.metadata.HiveException -import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.shims.Utils +import org.apache.hive.service.cli._ import org.apache.hive.service.cli.operation.ExecuteStatementOperation import org.apache.hive.service.cli.session.HiveSession @@ -40,7 +36,7 @@ import org.apache.spark.Logging import org.apache.spark.sql.execution.SetCommand import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} import org.apache.spark.sql.types._ -import org.apache.spark.sql.{DataFrame, Row => SparkRow, SQLConf} +import org.apache.spark.sql.{DataFrame, SQLConf, Row => SparkRow} private[hive] class SparkExecuteStatementOperation( @@ -143,30 +139,15 @@ private[hive] class SparkExecuteStatementOperation( if (!runInBackground) { runInternal() } else { - val parentSessionState = SessionState.get() - val hiveConf = getConfigForOperation() val sparkServiceUGI = Utils.getUGI() - val sessionHive = getCurrentHive() - val currentSqlSession = hiveContext.currentSession // Runnable impl to call runInternal asynchronously, // from a different thread val backgroundOperation = new Runnable() { override def run(): Unit = { - val doAsAction = new PrivilegedExceptionAction[Object]() { - override def run(): Object = { - - // User information is part of the metastore client member in Hive - hiveContext.setSession(currentSqlSession) - // Always use the latest class loader provided by executionHive's state. - val executionHiveClassLoader = - hiveContext.executionHive.state.getConf.getClassLoader - sessionHive.getConf.setClassLoader(executionHiveClassLoader) - parentSessionState.getConf.setClassLoader(executionHiveClassLoader) - - Hive.set(sessionHive) - SessionState.setCurrentSessionState(parentSessionState) + val doAsAction = new PrivilegedExceptionAction[Unit]() { + override def run(): Unit = { try { runInternal() } catch { @@ -174,7 +155,6 @@ private[hive] class SparkExecuteStatementOperation( setOperationException(e) log.error("Error running hive query: ", e) } - return null } } @@ -191,7 +171,7 @@ private[hive] class SparkExecuteStatementOperation( try { // This submit blocks if no background threads are available to run this operation val backgroundHandle = - getParentSession().getSessionManager().submitBackgroundOperation(backgroundOperation) + parentSession.getSessionManager().submitBackgroundOperation(backgroundOperation) setBackgroundHandle(backgroundHandle) } catch { case rejected: RejectedExecutionException => @@ -210,6 +190,11 @@ private[hive] class SparkExecuteStatementOperation( statementId = UUID.randomUUID().toString logInfo(s"Running query '$statement' with $statementId") setState(OperationState.RUNNING) + // Always use the latest class loader provided by executionHive's state. + val executionHiveClassLoader = + hiveContext.executionHive.state.getConf.getClassLoader + Thread.currentThread().setContextClassLoader(executionHiveClassLoader) + HiveThriftServer2.listener.onStatementStart( statementId, parentSession.getSessionHandle.getSessionId.toString, @@ -279,43 +264,4 @@ private[hive] class SparkExecuteStatementOperation( } } } - - /** - * If there are query specific settings to overlay, then create a copy of config - * There are two cases we need to clone the session config that's being passed to hive driver - * 1. Async query - - * If the client changes a config setting, that shouldn't reflect in the execution - * already underway - * 2. confOverlay - - * The query specific settings should only be applied to the query config and not session - * @return new configuration - * @throws HiveSQLException - */ - private def getConfigForOperation(): HiveConf = { - var sqlOperationConf = getParentSession().getHiveConf() - if (!getConfOverlay().isEmpty() || runInBackground) { - // clone the partent session config for this query - sqlOperationConf = new HiveConf(sqlOperationConf) - - // apply overlay query specific settings, if any - getConfOverlay().asScala.foreach { case (k, v) => - try { - sqlOperationConf.verifyAndSet(k, v) - } catch { - case e: IllegalArgumentException => - throw new HiveSQLException("Error applying statement specific settings", e) - } - } - } - return sqlOperationConf - } - - private def getCurrentHive(): Hive = { - try { - return Hive.get() - } catch { - case e: HiveException => - throw new HiveSQLException("Failed to get current Hive object", e); - } - } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala index 92ac0ec3fca29..33aaead3fbf96 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala @@ -36,7 +36,7 @@ private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, hiveContext: extends SessionManager(hiveServer) with ReflectedCompositeService { - private lazy val sparkSqlOperationManager = new SparkSQLOperationManager(hiveContext) + private lazy val sparkSqlOperationManager = new SparkSQLOperationManager() override def init(hiveConf: HiveConf) { setSuperField(this, "hiveConf", hiveConf) @@ -60,13 +60,15 @@ private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, hiveContext: sessionConf: java.util.Map[String, String], withImpersonation: Boolean, delegationToken: String): SessionHandle = { - hiveContext.openSession() val sessionHandle = super.openSession(protocol, username, passwd, ipAddress, sessionConf, withImpersonation, delegationToken) val session = super.getSession(sessionHandle) HiveThriftServer2.listener.onSessionCreated( session.getIpAddress, sessionHandle.getSessionId.toString, session.getUsername) + val ctx = hiveContext.newSession() + ctx.setConf("spark.sql.hive.version", HiveContext.hiveExecutionVersion) + sparkSqlOperationManager.sessionToContexts += sessionHandle -> ctx sessionHandle } @@ -74,7 +76,6 @@ private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, hiveContext: HiveThriftServer2.listener.onSessionClosed(sessionHandle.getSessionId.toString) super.closeSession(sessionHandle) sparkSqlOperationManager.sessionToActivePool -= sessionHandle - - hiveContext.detachSession() + sparkSqlOperationManager.sessionToContexts.remove(sessionHandle) } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index c8031ed0f3437..476651a559d2c 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -30,20 +30,21 @@ import org.apache.spark.sql.hive.thriftserver.{SparkExecuteStatementOperation, R /** * Executes queries using Spark SQL, and maintains a list of handles to active queries. */ -private[thriftserver] class SparkSQLOperationManager(hiveContext: HiveContext) +private[thriftserver] class SparkSQLOperationManager() extends OperationManager with Logging { val handleToOperation = ReflectionUtils .getSuperField[JMap[OperationHandle, Operation]](this, "handleToOperation") val sessionToActivePool = Map[SessionHandle, String]() + val sessionToContexts = Map[SessionHandle, HiveContext]() override def newExecuteStatementOperation( parentSession: HiveSession, statement: String, confOverlay: JMap[String, String], async: Boolean): ExecuteStatementOperation = synchronized { - + val hiveContext = sessionToContexts(parentSession.getSessionHandle) val runInBackground = async && hiveContext.hiveThriftServerAsync val operation = new SparkExecuteStatementOperation(parentSession, statement, confOverlay, runInBackground)(hiveContext, sessionToActivePool) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index e59a14ec00d5c..76d1591a235c2 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -96,7 +96,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { buffer += s"${new Timestamp(new Date().getTime)} - $source> $line" // If we haven't found all expected answers and another expected answer comes up... - if (next < expectedAnswers.size && line.startsWith(expectedAnswers(next))) { + if (next < expectedAnswers.size && line.contains(expectedAnswers(next))) { next += 1 // If all expected answers have been found... if (next == expectedAnswers.size) { @@ -159,7 +159,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE hive_test;" -> "OK", "CACHE TABLE hive_test;" - -> "Time taken: ", + -> "", "SELECT COUNT(*) FROM hive_test;" -> "5", "DROP TABLE hive_test;" @@ -180,7 +180,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { "CREATE TABLE hive_test(key INT, val STRING);" -> "OK", "SHOW TABLES;" - -> "Time taken: " + -> "hive_test" ) runCliWithin(2.minute, Seq("--database", "hive_test_db", "-e", "SHOW TABLES;"))( @@ -210,7 +210,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE sourceTable;" -> "OK", "INSERT INTO TABLE t1 SELECT key, val FROM sourceTable;" - -> "Time taken:", + -> "", "SELECT count(key) FROM t1;" -> "5", "DROP TABLE t1;" diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 19b2f24456ab0..ff8ca0150649d 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -205,6 +205,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { import org.apache.spark.sql.SQLConf var defaultV1: String = null var defaultV2: String = null + var data: ArrayBuffer[Int] = null withMultipleConnectionJdbcStatement( // create table @@ -214,10 +215,16 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { "DROP TABLE IF EXISTS test_map", "CREATE TABLE test_map(key INT, value STRING)", s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_map", - "CACHE TABLE test_table AS SELECT key FROM test_map ORDER BY key DESC") + "CACHE TABLE test_table AS SELECT key FROM test_map ORDER BY key DESC", + "CREATE DATABASE db1") queries.foreach(statement.execute) + val plan = statement.executeQuery("explain select * from test_table") + plan.next() + plan.next() + assert(plan.getString(1).contains("InMemoryColumnarTableScan")) + val rs1 = statement.executeQuery("SELECT key FROM test_table ORDER BY KEY DESC") val buf1 = new collection.mutable.ArrayBuffer[Int]() while (rs1.next()) { @@ -233,6 +240,8 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { rs2.close() assert(buf1 === buf2) + + data = buf1 }, // first session, we get the default value of the session status @@ -289,56 +298,51 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { rs2.close() }, - // accessing the cached data in another session + // try to access the cached data in another session { statement => - val rs1 = statement.executeQuery("SELECT key FROM test_table ORDER BY KEY DESC") - val buf1 = new collection.mutable.ArrayBuffer[Int]() - while (rs1.next()) { - buf1 += rs1.getInt(1) + // Cached temporary table can't be accessed by other sessions + intercept[SQLException] { + statement.executeQuery("SELECT key FROM test_table ORDER BY KEY DESC") } - rs1.close() - val rs2 = statement.executeQuery("SELECT key FROM test_map ORDER BY KEY DESC") - val buf2 = new collection.mutable.ArrayBuffer[Int]() - while (rs2.next()) { - buf2 += rs2.getInt(1) + val plan = statement.executeQuery("explain select key from test_map ORDER BY key DESC") + plan.next() + plan.next() + assert(plan.getString(1).contains("InMemoryColumnarTableScan")) + + val rs = statement.executeQuery("SELECT key FROM test_map ORDER BY KEY DESC") + val buf = new collection.mutable.ArrayBuffer[Int]() + while (rs.next()) { + buf += rs.getInt(1) } - rs2.close() + rs.close() + assert(buf === data) + }, - assert(buf1 === buf2) - statement.executeQuery("UNCACHE TABLE test_table") + // switch another database + { statement => + statement.execute("USE db1") - // TODO need to figure out how to determine if the data loaded from cache - val rs3 = statement.executeQuery("SELECT key FROM test_map ORDER BY KEY DESC") - val buf3 = new collection.mutable.ArrayBuffer[Int]() - while (rs3.next()) { - buf3 += rs3.getInt(1) + // there is no test_map table in db1 + intercept[SQLException] { + statement.executeQuery("SELECT key FROM test_map ORDER BY KEY DESC") } - rs3.close() - assert(buf1 === buf3) + statement.execute("CREATE TABLE test_map2(key INT, value STRING)") }, - // accessing the uncached table + // access default database { statement => - // TODO need to figure out how to determine if the data loaded from cache - val rs1 = statement.executeQuery("SELECT key FROM test_table ORDER BY KEY DESC") - val buf1 = new collection.mutable.ArrayBuffer[Int]() - while (rs1.next()) { - buf1 += rs1.getInt(1) - } - rs1.close() - - val rs2 = statement.executeQuery("SELECT key FROM test_map ORDER BY KEY DESC") - val buf2 = new collection.mutable.ArrayBuffer[Int]() - while (rs2.next()) { - buf2 += rs2.getInt(1) + // current database should still be `default` + intercept[SQLException] { + statement.executeQuery("SELECT key FROM test_map2") } - rs2.close() - assert(buf1 === buf2) + statement.execute("USE db1") + // access test_map2 + statement.executeQuery("SELECT key from test_map2") } ) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 17de8ef56f9a6..dad1e2347c387 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -25,7 +25,6 @@ import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ import scala.collection.mutable.HashMap import scala.language.implicitConversions -import scala.concurrent.duration._ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.hive.common.StatsSetupConst @@ -34,32 +33,49 @@ import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.metadata.Table import org.apache.hadoop.hive.ql.parse.VariableSubstitution -import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.serde2.io.{DateWritable, TimestampWritable} -import org.apache.spark.Logging -import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental -import org.apache.spark.sql._ +import org.apache.spark.api.java.JavaSparkContext import org.apache.spark.sql.SQLConf.SQLConfEntry import org.apache.spark.sql.SQLConf.SQLConfEntry._ -import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier, ParserDialect} +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUDFs, SetCommand} -import org.apache.spark.sql.execution.datasources.{PreWriteCheck, PreInsertCastAndRename, DataSourceStrategy} +import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, SqlParser} +import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PreInsertCastAndRename, PreWriteCheck} +import org.apache.spark.sql.execution.{CacheManager, ExecutedCommand, ExtractPythonUDFs, SetCommand} import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNativeCommand} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SparkContext} /** * This is the HiveQL Dialect, this dialect is strongly bind with HiveContext */ -private[hive] class HiveQLDialect extends ParserDialect { +private[hive] class HiveQLDialect(sqlContext: HiveContext) extends ParserDialect { override def parse(sqlText: String): LogicalPlan = { - HiveQl.parseSql(sqlText) + sqlContext.executionHive.withHiveState { + HiveQl.parseSql(sqlText) + } + } +} + +/** + * Returns the current database of metadataHive. + */ +private[hive] case class CurrentDatabase(ctx: HiveContext) + extends LeafExpression with CodegenFallback { + override def dataType: DataType = StringType + override def foldable: Boolean = true + override def nullable: Boolean = false + override def eval(input: InternalRow): Any = { + UTF8String.fromString(ctx.metadataHive.currentDatabase) } } @@ -69,13 +85,29 @@ private[hive] class HiveQLDialect extends ParserDialect { * * @since 1.0.0 */ -class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { +class HiveContext private[hive]( + sc: SparkContext, + cacheManager: CacheManager, + @transient execHive: ClientWrapper, + @transient metaHive: ClientInterface) extends SQLContext(sc, cacheManager) with Logging { self => - import HiveContext._ + def this(sc: SparkContext) = this(sc, new CacheManager, null, null) + def this(sc: JavaSparkContext) = this(sc.sc) + + import org.apache.spark.sql.hive.HiveContext._ logDebug("create HiveContext") + /** + * Returns a new HiveContext as new session, which will have separated SQLConf, UDF/UDAF, + * temporary tables and SessionState, but sharing the same CacheManager, IsolatedClientLoader + * and Hive client (both of execution and metadata) with existing HiveContext. + */ + override def newSession(): HiveContext = { + new HiveContext(sc, cacheManager, executionHive.newSession(), metadataHive.newSession()) + } + /** * When true, enables an experimental feature where metastore tables that use the parquet SerDe * are automatically converted to use the Spark SQL parquet table scan, instead of the Hive @@ -157,14 +189,18 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { * for storing persistent metadata, and only point to a dummy metastore in a temporary directory. */ @transient - protected[hive] lazy val executionHive: ClientWrapper = { + protected[hive] lazy val executionHive: ClientWrapper = if (execHive != null) { + execHive + } else { logInfo(s"Initializing execution hive, version $hiveExecutionVersion") - new ClientWrapper( + val loader = new IsolatedClientLoader( version = IsolatedClientLoader.hiveVersion(hiveExecutionVersion), + execJars = Seq(), config = newTemporaryConfiguration(), - initClassLoader = Utils.getContextOrSparkClassLoader) + isolationOn = false, + baseClassLoader = Utils.getContextOrSparkClassLoader) + loader.createClient().asInstanceOf[ClientWrapper] } - SessionState.setCurrentSessionState(executionHive.state) /** * Overrides default Hive configurations to avoid breaking changes to Spark SQL users. @@ -182,7 +218,9 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { * in the hive-site.xml file. */ @transient - protected[hive] lazy val metadataHive: ClientInterface = { + protected[hive] lazy val metadataHive: ClientInterface = if (metaHive != null) { + metaHive + } else { val metaVersion = IsolatedClientLoader.hiveVersion(hiveMetastoreVersion) // We instantiate a HiveConf here to read in the hive-site.xml file and then pass the options @@ -268,14 +306,10 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { barrierPrefixes = hiveMetastoreBarrierPrefixes, sharedPrefixes = hiveMetastoreSharedPrefixes) } - isolatedLoader.client + isolatedLoader.createClient() } protected[sql] override def parseSql(sql: String): LogicalPlan = { - var state = SessionState.get() - if (state == null) { - SessionState.setCurrentSessionState(tlSession.get().asInstanceOf[SQLSession].sessionState) - } super.parseSql(substitutor.substitute(hiveconf, sql)) } @@ -384,8 +418,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { } } - protected[hive] def hiveconf = tlSession.get().asInstanceOf[this.SQLSession].hiveconf - override def setConf(key: String, value: String): Unit = { super.setConf(key, value) executionHive.runSqlHive(s"SET $key=$value") @@ -402,7 +434,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { setConf(entry.key, entry.stringConverter(value)) } - /* A catalyst metadata catalog that points to the Hive Metastore. */ + /* A catalyst metadata catalog that points to the Hive Metastore. */ @transient override protected[sql] lazy val catalog = new HiveMetastoreCatalog(metadataHive, this) with OverrideCatalog @@ -410,7 +442,13 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { // Note that HiveUDFs will be overridden by functions registered in this context. @transient override protected[sql] lazy val functionRegistry: FunctionRegistry = - new HiveFunctionRegistry(FunctionRegistry.builtin) + new HiveFunctionRegistry(FunctionRegistry.builtin.copy()) + + // The Hive UDF current_database() is foldable, will be evaluated by optimizer, but the optimizer + // can't access the SessionState of metadataHive. + functionRegistry.registerFunction( + "current_database", + (expressions: Seq[Expression]) => new CurrentDatabase(this)) /* An analyzer that uses the Hive metastore. */ @transient @@ -430,10 +468,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { ) } - override protected[sql] def createSession(): SQLSession = { - new this.SQLSession() - } - /** Overridden by child classes that need to set configuration before the client init. */ protected def configure(): Map[String, String] = { // Hive 0.14.0 introduces timeout operations in HiveConf, and changes default values of a bunch @@ -488,41 +522,40 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { }.toMap } - protected[hive] class SQLSession extends super.SQLSession { - protected[sql] override lazy val conf: SQLConf = new SQLConf { - override def dialect: String = getConf(SQLConf.DIALECT, "hiveql") - override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false) - } - - /** - * SQLConf and HiveConf contracts: - * - * 1. reuse existing started SessionState if any - * 2. when the Hive session is first initialized, params in HiveConf will get picked up by the - * SQLConf. Additionally, any properties set by set() or a SET command inside sql() will be - * set in the SQLConf *as well as* in the HiveConf. - */ - protected[hive] lazy val sessionState: SessionState = { - var state = SessionState.get() - if (state == null) { - state = new SessionState(new HiveConf(classOf[SessionState])) - SessionState.start(state) - } - state - } + /** + * SQLConf and HiveConf contracts: + * + * 1. create a new SessionState for each HiveContext + * 2. when the Hive session is first initialized, params in HiveConf will get picked up by the + * SQLConf. Additionally, any properties set by set() or a SET command inside sql() will be + * set in the SQLConf *as well as* in the HiveConf. + */ + @transient + protected[hive] lazy val hiveconf: HiveConf = { + val c = executionHive.conf + setConf(c.getAllProperties) + c + } - protected[hive] lazy val hiveconf: HiveConf = { - setConf(sessionState.getConf.getAllProperties) - sessionState.getConf - } + protected[sql] override lazy val conf: SQLConf = new SQLConf { + override def dialect: String = getConf(SQLConf.DIALECT, "hiveql") + override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false) } - override protected[sql] def dialectClassName = if (conf.dialect == "hiveql") { + protected[sql] override def dialectClassName = if (conf.dialect == "hiveql") { classOf[HiveQLDialect].getCanonicalName } else { super.dialectClassName } + protected[sql] override def getSQLDialect(): ParserDialect = { + if (conf.dialect == "hiveql") { + new HiveQLDialect(this) + } else { + super.getSQLDialect() + } + } + @transient private val hivePlanner = new SparkPlanner with HiveStrategies { val hiveContext = self @@ -598,6 +631,14 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { case _ => super.simpleString } } + + protected[sql] override def addJar(path: String): Unit = { + // Add jar to Hive and classloader + executionHive.addJar(path) + metadataHive.addJar(path) + Thread.currentThread().setContextClassLoader(executionHive.clientLoader.classLoader) + super.addJar(path) + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 2bf22f5449641..250c232856885 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -25,29 +25,27 @@ import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hadoop.hive.serde.serdeConstants -import org.apache.hadoop.hive.ql.{ErrorMsg, Context} -import org.apache.hadoop.hive.ql.exec.{FunctionRegistry, FunctionInfo} +import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry} import org.apache.hadoop.hive.ql.lib.Node import org.apache.hadoop.hive.ql.parse._ import org.apache.hadoop.hive.ql.plan.PlanUtils import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.hive.ql.{Context, ErrorMsg} +import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.spark.Logging -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst +import org.apache.spark.sql.{AnalysisException, catalyst} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.plans.{logical, _} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.trees.CurrentOrigin import org.apache.spark.sql.execution.ExplainCommand import org.apache.spark.sql.execution.datasources.DescribeCommand import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.hive.client._ -import org.apache.spark.sql.hive.execution.{HiveNativeCommand, DropTable, AnalyzeTable, HiveScriptIOSchema} +import org.apache.spark.sql.hive.execution.{AnalyzeTable, DropTable, HiveNativeCommand, HiveScriptIOSchema} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.random.RandomSampler @@ -268,7 +266,7 @@ private[hive] object HiveQl extends Logging { node } - private def createContext(): Context = new Context(SessionState.get().getConf()) + private def createContext(): Context = new Context(hiveConf) private def getAst(sql: String, context: Context) = ParseUtils.findRootNonNullToken((new ParseDriver).parse(sql, context)) @@ -277,12 +275,16 @@ private[hive] object HiveQl extends Logging { * Returns the HiveConf */ private[this] def hiveConf: HiveConf = { - val ss = SessionState.get() // SessionState is lazy initialization, it can be null here + var ss = SessionState.get() + // SessionState is lazy initialization, it can be null here if (ss == null) { - new HiveConf() - } else { - ss.getConf + val original = Thread.currentThread().getContextClassLoader + val conf = new HiveConf(classOf[SessionState]) + conf.setClassLoader(original) + ss = new SessionState(conf) + SessionState.start(ss) } + ss.getConf } /** Returns a LogicalPlan for a given HiveQL string. */ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala index 915eae9d21e23..9d9a55edd7314 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala @@ -178,6 +178,15 @@ private[hive] trait ClientInterface { holdDDLTime: Boolean, listBucketingEnabled: Boolean): Unit + /** Add a jar into class loader */ + def addJar(path: String): Unit + + /** Return a ClientInterface as new session, that will share the class loader and Hive client */ + def newSession(): ClientInterface + + /** Run a function within Hive state (SessionState, HiveConf, Hive client and class loader) */ + def withHiveState[A](f: => A): A + /** Used for testing only. Removes all metadata from this instance of Hive. */ def reset(): Unit } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index 8f6d448b2aef4..3dce86c480747 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -60,7 +60,8 @@ import org.apache.spark.util.{CircularBuffer, Utils} private[hive] class ClientWrapper( override val version: HiveVersion, config: Map[String, String], - initClassLoader: ClassLoader) + initClassLoader: ClassLoader, + val clientLoader: IsolatedClientLoader) extends ClientInterface with Logging { @@ -150,31 +151,29 @@ private[hive] class ClientWrapper( // Switch to the initClassLoader. Thread.currentThread().setContextClassLoader(initClassLoader) val ret = try { - val oldState = SessionState.get() - if (oldState == null) { - val initialConf = new HiveConf(classOf[SessionState]) - // HiveConf is a Hadoop Configuration, which has a field of classLoader and - // the initial value will be the current thread's context class loader - // (i.e. initClassLoader at here). - // We call initialConf.setClassLoader(initClassLoader) at here to make - // this action explicit. - initialConf.setClassLoader(initClassLoader) - config.foreach { case (k, v) => - if (k.toLowerCase.contains("password")) { - logDebug(s"Hive Config: $k=xxx") - } else { - logDebug(s"Hive Config: $k=$v") - } - initialConf.set(k, v) + val initialConf = new HiveConf(classOf[SessionState]) + // HiveConf is a Hadoop Configuration, which has a field of classLoader and + // the initial value will be the current thread's context class loader + // (i.e. initClassLoader at here). + // We call initialConf.setClassLoader(initClassLoader) at here to make + // this action explicit. + initialConf.setClassLoader(initClassLoader) + config.foreach { case (k, v) => + if (k.toLowerCase.contains("password")) { + logDebug(s"Hive Config: $k=xxx") + } else { + logDebug(s"Hive Config: $k=$v") } - val newState = new SessionState(initialConf) - SessionState.start(newState) - newState.out = new PrintStream(outputBuffer, true, "UTF-8") - newState.err = new PrintStream(outputBuffer, true, "UTF-8") - newState - } else { - oldState + initialConf.set(k, v) + } + val state = new SessionState(initialConf) + if (clientLoader.cachedHive != null) { + Hive.set(clientLoader.cachedHive.asInstanceOf[Hive]) } + SessionState.start(state) + state.out = new PrintStream(outputBuffer, true, "UTF-8") + state.err = new PrintStream(outputBuffer, true, "UTF-8") + state } finally { Thread.currentThread().setContextClassLoader(original) } @@ -188,11 +187,6 @@ private[hive] class ClientWrapper( conf.get(key, defaultValue) } - // TODO: should be a def?s - // When we create this val client, the HiveConf of it (conf) is the one associated with state. - @GuardedBy("this") - private var client = Hive.get(conf) - // We use hive's conf for compatibility. private val retryLimit = conf.getIntVar(HiveConf.ConfVars.METASTORETHRIFTFAILURERETRIES) private val retryDelayMillis = shim.getMetastoreClientConnectRetryDelayMillis(conf) @@ -200,7 +194,7 @@ private[hive] class ClientWrapper( /** * Runs `f` with multiple retries in case the hive metastore is temporarily unreachable. */ - private def retryLocked[A](f: => A): A = synchronized { + private def retryLocked[A](f: => A): A = clientLoader.synchronized { // Hive sometimes retries internally, so set a deadline to avoid compounding delays. val deadline = System.nanoTime + (retryLimit * retryDelayMillis * 1e6).toLong var numTries = 0 @@ -215,13 +209,8 @@ private[hive] class ClientWrapper( logWarning( "HiveClientWrapper got thrift exception, destroying client and retrying " + s"(${retryLimit - numTries} tries remaining)", e) + clientLoader.cachedHive = null Thread.sleep(retryDelayMillis) - try { - client = Hive.get(state.getConf, true) - } catch { - case e: Exception if causedByThrift(e) => - logWarning("Failed to refresh hive client, will retry.", e) - } } } while (numTries <= retryLimit && System.nanoTime < deadline) if (System.nanoTime > deadline) { @@ -242,13 +231,26 @@ private[hive] class ClientWrapper( false } + def client: Hive = { + if (clientLoader.cachedHive != null) { + clientLoader.cachedHive.asInstanceOf[Hive] + } else { + val c = Hive.get(conf) + clientLoader.cachedHive = c + c + } + } + /** * Runs `f` with ThreadLocal session state and classloaders configured for this version of hive. */ - private def withHiveState[A](f: => A): A = retryLocked { + def withHiveState[A](f: => A): A = retryLocked { val original = Thread.currentThread().getContextClassLoader // Set the thread local metastore client to the client associated with this ClientWrapper. Hive.set(client) + // The classloader in clientLoader could be changed after addJar, always use the latest + // classloader + state.getConf.setClassLoader(clientLoader.classLoader) // setCurrentSessionState will use the classLoader associated // with the HiveConf in `state` to override the context class loader of the current // thread. @@ -545,6 +547,15 @@ private[hive] class ClientWrapper( listBucketingEnabled) } + def addJar(path: String): Unit = { + clientLoader.addJar(path) + runSqlHive(s"ADD JAR $path") + } + + def newSession(): ClientWrapper = { + clientLoader.createClient().asInstanceOf[ClientWrapper] + } + def reset(): Unit = withHiveState { client.getAllTables("default").asScala.foreach { t => logDebug(s"Deleting table $t") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 1fe4cba9571f3..567e4d7b411ec 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -22,6 +22,7 @@ import java.lang.reflect.InvocationTargetException import java.net.{URL, URLClassLoader} import java.util +import scala.collection.mutable import scala.language.reflectiveCalls import scala.util.Try @@ -148,53 +149,75 @@ private[hive] class IsolatedClientLoader( name.replaceAll("\\.", "/") + ".class" /** The classloader that is used to load an isolated version of Hive. */ - protected val classLoader: ClassLoader = new URLClassLoader(allJars, rootClassLoader) { - override def loadClass(name: String, resolve: Boolean): Class[_] = { - val loaded = findLoadedClass(name) - if (loaded == null) doLoadClass(name, resolve) else loaded - } - - def doLoadClass(name: String, resolve: Boolean): Class[_] = { - val classFileName = name.replaceAll("\\.", "/") + ".class" - if (isBarrierClass(name) && isolationOn) { - // For barrier classes, we construct a new copy of the class. - val bytes = IOUtils.toByteArray(baseClassLoader.getResourceAsStream(classFileName)) - logDebug(s"custom defining: $name - ${util.Arrays.hashCode(bytes)}") - defineClass(name, bytes, 0, bytes.length) - } else if (!isSharedClass(name)) { - logDebug(s"hive class: $name - ${getResource(classToPath(name))}") - super.loadClass(name, resolve) - } else { - // For shared classes, we delegate to baseClassLoader. - logDebug(s"shared class: $name") - baseClassLoader.loadClass(name) + private[hive] var classLoader: ClassLoader = if (isolationOn) { + new URLClassLoader(allJars, rootClassLoader) { + override def loadClass(name: String, resolve: Boolean): Class[_] = { + val loaded = findLoadedClass(name) + if (loaded == null) doLoadClass(name, resolve) else loaded + } + def doLoadClass(name: String, resolve: Boolean): Class[_] = { + val classFileName = name.replaceAll("\\.", "/") + ".class" + if (isBarrierClass(name)) { + // For barrier classes, we construct a new copy of the class. + val bytes = IOUtils.toByteArray(baseClassLoader.getResourceAsStream(classFileName)) + logDebug(s"custom defining: $name - ${util.Arrays.hashCode(bytes)}") + defineClass(name, bytes, 0, bytes.length) + } else if (!isSharedClass(name)) { + logDebug(s"hive class: $name - ${getResource(classToPath(name))}") + super.loadClass(name, resolve) + } else { + // For shared classes, we delegate to baseClassLoader. + logDebug(s"shared class: $name") + baseClassLoader.loadClass(name) + } } } + } else { + baseClassLoader } - // Pre-reflective instantiation setup. - logDebug("Initializing the logger to avoid disaster...") - Thread.currentThread.setContextClassLoader(classLoader) + private[hive] def addJar(path: String): Unit = synchronized { + val jarURL = new java.io.File(path).toURI.toURL + // TODO: we should avoid of stacking classloaders (use a single URLClassLoader and add jars + // to that) + classLoader = new java.net.URLClassLoader(Array(jarURL), classLoader) + } /** The isolated client interface to Hive. */ - val client: ClientInterface = try { - classLoader - .loadClass(classOf[ClientWrapper].getName) - .getConstructors.head - .newInstance(version, config, classLoader) - .asInstanceOf[ClientInterface] - } catch { - case e: InvocationTargetException => - if (e.getCause().isInstanceOf[NoClassDefFoundError]) { - val cnf = e.getCause().asInstanceOf[NoClassDefFoundError] - throw new ClassNotFoundException( - s"$cnf when creating Hive client using classpath: ${execJars.mkString(", ")}\n" + - "Please make sure that jars for your version of hive and hadoop are included in the " + - s"paths passed to ${HiveContext.HIVE_METASTORE_JARS}.") - } else { - throw e - } - } finally { - Thread.currentThread.setContextClassLoader(baseClassLoader) + private[hive] def createClient(): ClientInterface = { + if (!isolationOn) { + return new ClientWrapper(version, config, baseClassLoader, this) + } + // Pre-reflective instantiation setup. + logDebug("Initializing the logger to avoid disaster...") + val origLoader = Thread.currentThread().getContextClassLoader + Thread.currentThread.setContextClassLoader(classLoader) + + try { + classLoader + .loadClass(classOf[ClientWrapper].getName) + .getConstructors.head + .newInstance(version, config, classLoader, this) + .asInstanceOf[ClientInterface] + } catch { + case e: InvocationTargetException => + if (e.getCause().isInstanceOf[NoClassDefFoundError]) { + val cnf = e.getCause().asInstanceOf[NoClassDefFoundError] + throw new ClassNotFoundException( + s"$cnf when creating Hive client using classpath: ${execJars.mkString(", ")}\n" + + "Please make sure that jars for your version of hive and hadoop are included in the " + + s"paths passed to ${HiveContext.HIVE_METASTORE_JARS}.") + } else { + throw e + } + } finally { + Thread.currentThread.setContextClassLoader(origLoader) + } } + + /** + * The place holder for shared Hive client for all the HiveContext sessions (they share an + * IsolatedClientLoader). + */ + private[hive] var cachedHive: Any = null } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 9f654eed5761c..51ec92afd06ed 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -18,18 +18,18 @@ package org.apache.spark.sql.hive.execution import org.apache.hadoop.hive.metastore.MetaStoreUtils + import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.{TableIdentifier, SqlParser} +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.RunnableCommand -import org.apache.spark.sql.execution.datasources.{ResolvedDataSource, LogicalRelation} +import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils /** * Analyzes the given table in the current database to generate statistics, which will be @@ -86,26 +86,7 @@ case class AddJar(path: String) extends RunnableCommand { } override def run(sqlContext: SQLContext): Seq[Row] = { - val hiveContext = sqlContext.asInstanceOf[HiveContext] - val currentClassLoader = Utils.getContextOrSparkClassLoader - - // Add jar to current context - val jarURL = new java.io.File(path).toURI.toURL - val newClassLoader = new java.net.URLClassLoader(Array(jarURL), currentClassLoader) - Thread.currentThread.setContextClassLoader(newClassLoader) - // We need to explicitly set the class loader associated with the conf in executionHive's - // state because this class loader will be used as the context class loader of the current - // thread to execute any Hive command. - // We cannot use `org.apache.hadoop.hive.ql.metadata.Hive.get().getConf()` because Hive.get() - // returns the value of a thread local variable and its HiveConf may not be the HiveConf - // associated with `executionHive.state` (for example, HiveContext is created in one thread - // and then add jar is called from another thread). - hiveContext.executionHive.state.getConf.setClassLoader(newClassLoader) - // Add jar to isolated hive (metadataHive) class loader. - hiveContext.runSqlHive(s"ADD JAR $path") - - // Add jar to executors - hiveContext.sparkContext.addJar(path) + sqlContext.addJar(path) Seq(Row(0)) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index be335a47dcabd..ff39ccb7c1ea5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -116,27 +116,18 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { override def executePlan(plan: LogicalPlan): this.QueryExecution = new this.QueryExecution(plan) - // Make sure we set those test specific confs correctly when we create - // the SQLConf as well as when we call clear. - override protected[sql] def createSession(): SQLSession = { - new this.SQLSession() - } - - protected[hive] class SQLSession extends super.SQLSession { - protected[sql] override lazy val conf: SQLConf = new SQLConf { - // TODO as in unit test, conf.clear() probably be called, all of the value will be cleared. - // The super.getConf(SQLConf.DIALECT) is "sql" by default, we need to set it as "hiveql" - override def dialect: String = super.getConf(SQLConf.DIALECT, "hiveql") - override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false) + protected[sql] override lazy val conf: SQLConf = new SQLConf { + // The super.getConf(SQLConf.DIALECT) is "sql" by default, we need to set it as "hiveql" + override def dialect: String = super.getConf(SQLConf.DIALECT, "hiveql") + override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false) - clear() + clear() - override def clear(): Unit = { - super.clear() + override def clear(): Unit = { + super.clear() - TestHiveContext.overrideConfs.map { - case (key, value) => setConfString(key, value) - } + TestHiveContext.overrideConfs.map { + case (key, value) => setConfString(key, value) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala index 79cf40aba4bf2..528a7398b10df 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala @@ -17,22 +17,15 @@ package org.apache.spark.sql.hive -import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.serde.serdeConstants +import org.scalatest.BeforeAndAfterAll + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.hive.client.{ManagedTable, HiveColumn, ExternalTable, HiveTable} -import org.scalatest.BeforeAndAfterAll +import org.apache.spark.sql.hive.client.{ExternalTable, HiveColumn, HiveTable, ManagedTable} class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll { - override def beforeAll() { - if (SessionState.get() == null) { - SessionState.start(new HiveConf()) - } - } - private def extractTableDesc(sql: String): (HiveTable, Boolean) = { HiveQl.createPlan(sql).collect { case CreateTableAsSelect(desc, child, allowExisting) => (desc, allowExisting) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 2da22ec2379f3..c6d034a23a1c6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -53,7 +53,7 @@ class VersionsSuite extends SparkFunSuite with Logging { test("success sanity check") { val badClient = IsolatedClientLoader.forVersion(HiveContext.hiveExecutionVersion, buildConf(), - ivyPath).client + ivyPath).createClient() val db = new HiveDatabase("default", "") badClient.createDatabase(db) } @@ -83,7 +83,7 @@ class VersionsSuite extends SparkFunSuite with Logging { ignore("failure sanity check") { val e = intercept[Throwable] { val badClient = quietly { - IsolatedClientLoader.forVersion("13", buildConf(), ivyPath).client + IsolatedClientLoader.forVersion("13", buildConf(), ivyPath).createClient() } } assert(getNestedMessages(e) contains "Unknown column 'A0.OWNER_NAME' in 'field list'") @@ -97,7 +97,7 @@ class VersionsSuite extends SparkFunSuite with Logging { test(s"$version: create client") { client = null System.gc() // Hack to avoid SEGV on some JVM versions. - client = IsolatedClientLoader.forVersion(version, buildConf(), ivyPath).client + client = IsolatedClientLoader.forVersion(version, buildConf(), ivyPath).createClient() } test(s"$version: createDatabase") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index fe63ad5683195..2878500453141 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -1133,6 +1133,38 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { conf.clear() } + test("current_database with multiple sessions") { + sql("create database a") + sql("use a") + val s2 = newSession() + s2.sql("create database b") + s2.sql("use b") + + assert(sql("select current_database()").first() === Row("a")) + assert(s2.sql("select current_database()").first() === Row("b")) + + try { + sql("create table test_a(key INT, value STRING)") + s2.sql("create table test_b(key INT, value STRING)") + + sql("select * from test_a") + intercept[AnalysisException] { + sql("select * from test_b") + } + sql("select * from b.test_b") + + s2.sql("select * from test_b") + intercept[AnalysisException] { + s2.sql("select * from test_a") + } + s2.sql("select * from a.test_a") + } finally { + sql("DROP TABLE IF EXISTS test_a") + s2.sql("DROP TABLE IF EXISTS test_b") + } + + } + createQueryTest("select from thrift based table", "SELECT * from src_thrift") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index ec5b83b98e401..ccc15eaa63f42 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -160,10 +160,15 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("show functions") { - val allFunctions = + val allBuiltinFunctions = (FunctionRegistry.builtin.listFunction().toSet[String] ++ org.apache.hadoop.hive.ql.exec.FunctionRegistry.getFunctionNames.asScala).toList.sorted - checkAnswer(sql("SHOW functions"), allFunctions.map(Row(_))) + // The TestContext is shared by all the test cases, some functions may be registered before + // this, so we check that all the builtin functions are returned. + val allFunctions = sql("SHOW functions").collect().map(r => r(0)) + allBuiltinFunctions.foreach { f => + assert(allFunctions.contains(f)) + } checkAnswer(sql("SHOW functions abs"), Row("abs")) checkAnswer(sql("SHOW functions 'abs'"), Row("abs")) checkAnswer(sql("SHOW functions abc.abs"), Row("abs")) From 8e67882b905683a1f151679214ef0b575e77c7e1 Mon Sep 17 00:00:00 2001 From: zero323 Date: Thu, 8 Oct 2015 18:34:15 -0700 Subject: [PATCH 008/139] =?UTF-8?q?[SPARK-10973]=20[ML]=20[PYTHON]=20=5F?= =?UTF-8?q?=5Fgettitem=5F=5F=20method=20throws=20IndexError=20exception=20?= =?UTF-8?q?when=20we=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit __gettitem__ method throws IndexError exception when we try to access index after the last non-zero entry from pyspark.mllib.linalg import Vectors sv = Vectors.sparse(5, {1: 3}) sv[0] ## 0.0 sv[1] ## 3.0 sv[2] ## Traceback (most recent call last): ## File "", line 1, in ## File "/python/pyspark/mllib/linalg/__init__.py", line 734, in __getitem__ ## row_ind = inds[insert_index] ## IndexError: index out of bounds Author: zero323 Closes #9009 from zero323/sparse_vector_index_error. --- python/pyspark/mllib/linalg/__init__.py | 3 +++ python/pyspark/mllib/tests.py | 12 +++++++----- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index ea42127f1651f..d903b9030d8ce 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -770,6 +770,9 @@ def __getitem__(self, index): raise ValueError("Index %d out of bounds." % index) insert_index = np.searchsorted(inds, index) + if insert_index >= inds.size: + return 0. + row_ind = inds[insert_index] if row_ind == index: return vals[insert_index] diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 96cf13495aa95..2a6a5cd3fe40e 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -237,15 +237,17 @@ def test_conversion(self): self.assertTrue(dv.array.dtype == 'float64') def test_sparse_vector_indexing(self): - sv = SparseVector(4, {1: 1, 3: 2}) + sv = SparseVector(5, {1: 1, 3: 2}) self.assertEqual(sv[0], 0.) self.assertEqual(sv[3], 2.) self.assertEqual(sv[1], 1.) self.assertEqual(sv[2], 0.) - self.assertEqual(sv[-1], 2) - self.assertEqual(sv[-2], 0) - self.assertEqual(sv[-4], 0) - for ind in [4, -5]: + self.assertEqual(sv[4], 0.) + self.assertEqual(sv[-1], 0.) + self.assertEqual(sv[-2], 2.) + self.assertEqual(sv[-3], 0.) + self.assertEqual(sv[-5], 0.) + for ind in [5, -6]: self.assertRaises(ValueError, sv.__getitem__, ind) for ind in [7.8, '1']: self.assertRaises(TypeError, sv.__getitem__, ind) From fa3e4d8f52995bf632e7eda60dbb776c9f637546 Mon Sep 17 00:00:00 2001 From: Hari Shreedharan Date: Thu, 8 Oct 2015 18:50:27 -0700 Subject: [PATCH 009/139] =?UTF-8?q?[SPARK-11019]=20[STREAMING]=20[FLUME]?= =?UTF-8?q?=20Gracefully=20shutdown=20Flume=20receiver=20th=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …reads. Wait for a minute for the receiver threads to shutdown before interrupting them. Author: Hari Shreedharan Closes #9041 from harishreedharan/flume-graceful-shutdown. --- .../spark/streaming/flume/FlumePollingInputDStream.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala index 3b936d88abd3e..6737750c3d63e 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala @@ -18,7 +18,7 @@ package org.apache.spark.streaming.flume import java.net.InetSocketAddress -import java.util.concurrent.{LinkedBlockingQueue, Executors} +import java.util.concurrent.{Executors, LinkedBlockingQueue, TimeUnit} import scala.collection.JavaConverters._ import scala.reflect.ClassTag @@ -93,7 +93,11 @@ private[streaming] class FlumePollingReceiver( override def onStop(): Unit = { logInfo("Shutting down Flume Polling Receiver") - receiverExecutor.shutdownNow() + receiverExecutor.shutdown() + // Wait upto a minute for the threads to die + if (!receiverExecutor.awaitTermination(60, TimeUnit.SECONDS)) { + receiverExecutor.shutdownNow() + } connections.asScala.foreach(_.transceiver.close()) channelFactory.releaseExternalResources() } From 09841290055770a619a2e72fbaef1a5e694916ae Mon Sep 17 00:00:00 2001 From: Hari Shreedharan Date: Thu, 8 Oct 2015 18:53:38 -0700 Subject: [PATCH 010/139] [SPARK-10955] [STREAMING] Add a warning if dynamic allocation for Streaming applications Dynamic allocation can be painful for streaming apps and can lose data. Log a warning for streaming applications if dynamic allocation is enabled. Author: Hari Shreedharan Closes #8998 from harishreedharan/ss-log-error and squashes the following commits: 462b264 [Hari Shreedharan] Improve log message. 2733d94 [Hari Shreedharan] Minor change to warning message. eaa48cc [Hari Shreedharan] Log a warning instead of failing the application if dynamic allocation is enabled. 725f090 [Hari Shreedharan] Add config parameter to allow dynamic allocation if the user explicitly sets it. b3f9a95 [Hari Shreedharan] Disable dynamic allocation and kill app if it is enabled. a4a5212 [Hari Shreedharan] [streaming] SPARK-10955. Disable dynamic allocation for Streaming applications. --- .../org/apache/spark/streaming/StreamingContext.scala | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 94fea63f55b25..9b2632c229548 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -44,7 +44,7 @@ import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.receiver.{ActorReceiver, ActorSupervisorStrategy, Receiver} import org.apache.spark.streaming.scheduler.{JobScheduler, StreamingListener} import org.apache.spark.streaming.ui.{StreamingJobProgressListener, StreamingTab} -import org.apache.spark.util.{CallSite, ShutdownHookManager, ThreadUtils} +import org.apache.spark.util.{CallSite, ShutdownHookManager, ThreadUtils, Utils} /** * Main entry point for Spark Streaming functionality. It provides methods used to create @@ -564,6 +564,13 @@ class StreamingContext private[streaming] ( ) } } + + if (Utils.isDynamicAllocationEnabled(sc.conf)) { + logWarning("Dynamic Allocation is enabled for this application. " + + "Enabling Dynamic allocation for Spark Streaming applications can cause data loss if " + + "Write Ahead Log is not enabled for non-replayable sources like Flume. " + + "See the programming guide for details on how to enable the Write Ahead Log") + } } /** From 67fbecbf32fced87d3accd2618fef2af9f44fae2 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 8 Oct 2015 21:44:59 -0700 Subject: [PATCH 011/139] [SPARK-10956] Common MemoryManager interface for storage and execution This patch introduces a `MemoryManager` that is the central arbiter of how much memory to grant to storage and execution. This patch is primarily concerned only with refactoring while preserving the existing behavior as much as possible. This is the first step away from the existing rigid separation of storage and execution memory, which has several major drawbacks discussed on the [issue](https://issues.apache.org/jira/browse/SPARK-10956). It is the precursor of a series of patches that will attempt to address those drawbacks. Author: Andrew Or Author: Josh Rosen Author: andrewor14 Closes #9000 from andrewor14/memory-manager. --- .../scala/org/apache/spark/SparkEnv.scala | 11 +- .../apache/spark/memory/MemoryManager.scala | 117 ++++++++ .../spark/memory/StaticMemoryManager.scala | 202 +++++++++++++ .../spark/shuffle/ShuffleMemoryManager.scala | 69 +++-- .../apache/spark/storage/BlockManager.scala | 33 +-- .../apache/spark/storage/MemoryStore.scala | 272 +++++++++--------- .../memory/StaticMemoryManagerSuite.scala | 172 +++++++++++ .../BlockManagerReplicationSuite.scala | 29 +- .../spark/storage/BlockManagerSuite.scala | 34 ++- .../execution/TestShuffleMemoryManager.scala | 28 +- .../streaming/ReceivedBlockHandlerSuite.scala | 13 +- 11 files changed, 752 insertions(+), 228 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/memory/MemoryManager.scala create mode 100644 core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala create mode 100644 core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index cfde27fb2e7d3..df3d84a1f08e9 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -30,6 +30,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.python.PythonWorkerFactory import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.memory.{MemoryManager, StaticMemoryManager} import org.apache.spark.network.BlockTransferService import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.rpc.{RpcEndpointRef, RpcEndpoint, RpcEnv} @@ -69,6 +70,8 @@ class SparkEnv ( val httpFileServer: HttpFileServer, val sparkFilesDir: String, val metricsSystem: MetricsSystem, + // TODO: unify these *MemoryManager classes (SPARK-10984) + val memoryManager: MemoryManager, val shuffleMemoryManager: ShuffleMemoryManager, val executorMemoryManager: ExecutorMemoryManager, val outputCommitCoordinator: OutputCommitCoordinator, @@ -332,7 +335,8 @@ object SparkEnv extends Logging { val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName) val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass) - val shuffleMemoryManager = ShuffleMemoryManager.create(conf, numUsableCores) + val memoryManager = new StaticMemoryManager(conf) + val shuffleMemoryManager = ShuffleMemoryManager.create(conf, memoryManager, numUsableCores) val blockTransferService = new NettyBlockTransferService(conf, securityManager, numUsableCores) @@ -343,8 +347,8 @@ object SparkEnv extends Logging { // NB: blockManager is not valid until initialize() is called later. val blockManager = new BlockManager(executorId, rpcEnv, blockManagerMaster, - serializer, conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager, - numUsableCores) + serializer, conf, memoryManager, mapOutputTracker, shuffleManager, + blockTransferService, securityManager, numUsableCores) val broadcastManager = new BroadcastManager(isDriver, conf, securityManager) @@ -417,6 +421,7 @@ object SparkEnv extends Logging { httpFileServer, sparkFilesDir, metricsSystem, + memoryManager, shuffleMemoryManager, executorMemoryManager, outputCommitCoordinator, diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala new file mode 100644 index 0000000000000..4bf73b696920d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.memory + +import scala.collection.mutable + +import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore} + + +/** + * An abstract memory manager that enforces how memory is shared between execution and storage. + * + * In this context, execution memory refers to that used for computation in shuffles, joins, + * sorts and aggregations, while storage memory refers to that used for caching and propagating + * internal data across the cluster. There exists one of these per JVM. + */ +private[spark] abstract class MemoryManager { + + // The memory store used to evict cached blocks + private var _memoryStore: MemoryStore = _ + protected def memoryStore: MemoryStore = { + if (_memoryStore == null) { + throw new IllegalArgumentException("memory store not initialized yet") + } + _memoryStore + } + + /** + * Set the [[MemoryStore]] used by this manager to evict cached blocks. + * This must be set after construction due to initialization ordering constraints. + */ + def setMemoryStore(store: MemoryStore): Unit = { + _memoryStore = store + } + + /** + * Acquire N bytes of memory for execution. + * @return number of bytes successfully granted (<= N). + */ + def acquireExecutionMemory(numBytes: Long): Long + + /** + * Acquire N bytes of memory to cache the given block, evicting existing ones if necessary. + * Blocks evicted in the process, if any, are added to `evictedBlocks`. + * @return whether all N bytes were successfully granted. + */ + def acquireStorageMemory( + blockId: BlockId, + numBytes: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean + + /** + * Acquire N bytes of memory to unroll the given block, evicting existing ones if necessary. + * Blocks evicted in the process, if any, are added to `evictedBlocks`. + * @return whether all N bytes were successfully granted. + */ + def acquireUnrollMemory( + blockId: BlockId, + numBytes: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean + + /** + * Release N bytes of execution memory. + */ + def releaseExecutionMemory(numBytes: Long): Unit + + /** + * Release N bytes of storage memory. + */ + def releaseStorageMemory(numBytes: Long): Unit + + /** + * Release all storage memory acquired. + */ + def releaseStorageMemory(): Unit + + /** + * Release N bytes of unroll memory. + */ + def releaseUnrollMemory(numBytes: Long): Unit + + /** + * Total available memory for execution, in bytes. + */ + def maxExecutionMemory: Long + + /** + * Total available memory for storage, in bytes. + */ + def maxStorageMemory: Long + + /** + * Execution memory currently in use, in bytes. + */ + def executionMemoryUsed: Long + + /** + * Storage memory currently in use, in bytes. + */ + def storageMemoryUsed: Long + +} diff --git a/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala new file mode 100644 index 0000000000000..150445edb9578 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.memory + +import scala.collection.mutable + +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.storage.{BlockId, BlockStatus} + + +/** + * A [[MemoryManager]] that statically partitions the heap space into disjoint regions. + * + * The sizes of the execution and storage regions are determined through + * `spark.shuffle.memoryFraction` and `spark.storage.memoryFraction` respectively. The two + * regions are cleanly separated such that neither usage can borrow memory from the other. + */ +private[spark] class StaticMemoryManager( + conf: SparkConf, + override val maxExecutionMemory: Long, + override val maxStorageMemory: Long) + extends MemoryManager with Logging { + + // Max number of bytes worth of blocks to evict when unrolling + private val maxMemoryToEvictForUnroll: Long = { + (maxStorageMemory * conf.getDouble("spark.storage.unrollFraction", 0.2)).toLong + } + + // Amount of execution / storage memory in use + // Accesses must be synchronized on `this` + private var _executionMemoryUsed: Long = 0 + private var _storageMemoryUsed: Long = 0 + + def this(conf: SparkConf) { + this( + conf, + StaticMemoryManager.getMaxExecutionMemory(conf), + StaticMemoryManager.getMaxStorageMemory(conf)) + } + + /** + * Acquire N bytes of memory for execution. + * @return number of bytes successfully granted (<= N). + */ + override def acquireExecutionMemory(numBytes: Long): Long = synchronized { + assert(_executionMemoryUsed <= maxExecutionMemory) + val bytesToGrant = math.min(numBytes, maxExecutionMemory - _executionMemoryUsed) + _executionMemoryUsed += bytesToGrant + bytesToGrant + } + + /** + * Acquire N bytes of memory to cache the given block, evicting existing ones if necessary. + * Blocks evicted in the process, if any, are added to `evictedBlocks`. + * @return whether all N bytes were successfully granted. + */ + override def acquireStorageMemory( + blockId: BlockId, + numBytes: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = { + acquireStorageMemory(blockId, numBytes, numBytes, evictedBlocks) + } + + /** + * Acquire N bytes of memory to unroll the given block, evicting existing ones if necessary. + * + * This evicts at most M bytes worth of existing blocks, where M is a fraction of the storage + * space specified by `spark.storage.unrollFraction`. Blocks evicted in the process, if any, + * are added to `evictedBlocks`. + * + * @return whether all N bytes were successfully granted. + */ + override def acquireUnrollMemory( + blockId: BlockId, + numBytes: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = { + val currentUnrollMemory = memoryStore.currentUnrollMemory + val maxNumBytesToFree = math.max(0, maxMemoryToEvictForUnroll - currentUnrollMemory) + val numBytesToFree = math.min(numBytes, maxNumBytesToFree) + acquireStorageMemory(blockId, numBytes, numBytesToFree, evictedBlocks) + } + + /** + * Acquire N bytes of storage memory for the given block, evicting existing ones if necessary. + * + * @param blockId the ID of the block we are acquiring storage memory for + * @param numBytesToAcquire the size of this block + * @param numBytesToFree the size of space to be freed through evicting blocks + * @param evictedBlocks a holder for blocks evicted in the process + * @return whether all N bytes were successfully granted. + */ + private def acquireStorageMemory( + blockId: BlockId, + numBytesToAcquire: Long, + numBytesToFree: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = { + // Note: Keep this outside synchronized block to avoid potential deadlocks! + memoryStore.ensureFreeSpace(blockId, numBytesToFree, evictedBlocks) + synchronized { + assert(_storageMemoryUsed <= maxStorageMemory) + val enoughMemory = _storageMemoryUsed + numBytesToAcquire <= maxStorageMemory + if (enoughMemory) { + _storageMemoryUsed += numBytesToAcquire + } + enoughMemory + } + } + + /** + * Release N bytes of execution memory. + */ + override def releaseExecutionMemory(numBytes: Long): Unit = synchronized { + if (numBytes > _executionMemoryUsed) { + logWarning(s"Attempted to release $numBytes bytes of execution " + + s"memory when we only have ${_executionMemoryUsed} bytes") + _executionMemoryUsed = 0 + } else { + _executionMemoryUsed -= numBytes + } + } + + /** + * Release N bytes of storage memory. + */ + override def releaseStorageMemory(numBytes: Long): Unit = synchronized { + if (numBytes > _storageMemoryUsed) { + logWarning(s"Attempted to release $numBytes bytes of storage " + + s"memory when we only have ${_storageMemoryUsed} bytes") + _storageMemoryUsed = 0 + } else { + _storageMemoryUsed -= numBytes + } + } + + /** + * Release all storage memory acquired. + */ + override def releaseStorageMemory(): Unit = synchronized { + _storageMemoryUsed = 0 + } + + /** + * Release N bytes of unroll memory. + */ + override def releaseUnrollMemory(numBytes: Long): Unit = { + releaseStorageMemory(numBytes) + } + + /** + * Amount of execution memory currently in use, in bytes. + */ + override def executionMemoryUsed: Long = synchronized { + _executionMemoryUsed + } + + /** + * Amount of storage memory currently in use, in bytes. + */ + override def storageMemoryUsed: Long = synchronized { + _storageMemoryUsed + } + +} + + +private[spark] object StaticMemoryManager { + + /** + * Return the total amount of memory available for the storage region, in bytes. + */ + private def getMaxStorageMemory(conf: SparkConf): Long = { + val memoryFraction = conf.getDouble("spark.storage.memoryFraction", 0.6) + val safetyFraction = conf.getDouble("spark.storage.safetyFraction", 0.9) + (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong + } + + + /** + * Return the total amount of memory available for the execution region, in bytes. + */ + private def getMaxExecutionMemory(conf: SparkConf): Long = { + val memoryFraction = conf.getDouble("spark.shuffle.memoryFraction", 0.2) + val safetyFraction = conf.getDouble("spark.shuffle.safetyFraction", 0.8) + (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong + } + +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala index 9839c7640cc63..bb64bb3f35df0 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala @@ -21,8 +21,9 @@ import scala.collection.mutable import com.google.common.annotations.VisibleForTesting +import org.apache.spark._ +import org.apache.spark.memory.{StaticMemoryManager, MemoryManager} import org.apache.spark.unsafe.array.ByteArrayMethods -import org.apache.spark.{Logging, SparkException, SparkConf, TaskContext} /** * Allocates a pool of memory to tasks for use in shuffle operations. Each disk-spilling @@ -40,16 +41,17 @@ import org.apache.spark.{Logging, SparkException, SparkConf, TaskContext} * * Use `ShuffleMemoryManager.create()` factory method to create a new instance. * - * @param maxMemory total amount of memory available for execution, in bytes. + * @param memoryManager the interface through which this manager acquires execution memory * @param pageSizeBytes number of bytes for each page, by default. */ private[spark] class ShuffleMemoryManager protected ( - val maxMemory: Long, + memoryManager: MemoryManager, val pageSizeBytes: Long) extends Logging { private val taskMemory = new mutable.HashMap[Long, Long]() // taskAttemptId -> memory bytes + private val maxMemory = memoryManager.maxExecutionMemory private def currentTaskAttemptId(): Long = { // In case this is called on the driver, return an invalid task attempt id. @@ -71,7 +73,7 @@ class ShuffleMemoryManager protected ( // of active tasks, to let other tasks ramp down their memory in calls to tryToAcquire if (!taskMemory.contains(taskAttemptId)) { taskMemory(taskAttemptId) = 0L - notifyAll() // Will later cause waiting tasks to wake up and check numThreads again + notifyAll() // Will later cause waiting tasks to wake up and check numTasks again } // Keep looping until we're either sure that we don't want to grant this request (because this @@ -85,46 +87,57 @@ class ShuffleMemoryManager protected ( // How much we can grant this task; don't let it grow to more than 1 / numActiveTasks; // don't let it be negative val maxToGrant = math.min(numBytes, math.max(0, (maxMemory / numActiveTasks) - curMem)) + // Only give it as much memory as is free, which might be none if it reached 1 / numTasks + val toGrant = math.min(maxToGrant, freeMemory) if (curMem < maxMemory / (2 * numActiveTasks)) { // We want to let each task get at least 1 / (2 * numActiveTasks) before blocking; // if we can't give it this much now, wait for other tasks to free up memory // (this happens if older tasks allocated lots of memory before N grew) if (freeMemory >= math.min(maxToGrant, maxMemory / (2 * numActiveTasks) - curMem)) { - val toGrant = math.min(maxToGrant, freeMemory) - taskMemory(taskAttemptId) += toGrant - return toGrant + return acquire(toGrant) } else { logInfo( s"TID $taskAttemptId waiting for at least 1/2N of shuffle memory pool to be free") wait() } } else { - // Only give it as much memory as is free, which might be none if it reached 1 / numThreads - val toGrant = math.min(maxToGrant, freeMemory) - taskMemory(taskAttemptId) += toGrant - return toGrant + return acquire(toGrant) } } 0L // Never reached } + /** + * Acquire N bytes of execution memory from the memory manager for the current task. + * @return number of bytes actually acquired (<= N). + */ + private def acquire(numBytes: Long): Long = synchronized { + val taskAttemptId = currentTaskAttemptId() + val acquired = memoryManager.acquireExecutionMemory(numBytes) + taskMemory(taskAttemptId) += acquired + acquired + } + /** Release numBytes bytes for the current task. */ def release(numBytes: Long): Unit = synchronized { val taskAttemptId = currentTaskAttemptId() val curMem = taskMemory.getOrElse(taskAttemptId, 0L) if (curMem < numBytes) { throw new SparkException( - s"Internal error: release called on ${numBytes} bytes but task only has ${curMem}") + s"Internal error: release called on $numBytes bytes but task only has $curMem") } taskMemory(taskAttemptId) -= numBytes + memoryManager.releaseExecutionMemory(numBytes) notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed } /** Release all memory for the current task and mark it as inactive (e.g. when a task ends). */ def releaseMemoryForThisTask(): Unit = synchronized { val taskAttemptId = currentTaskAttemptId() - taskMemory.remove(taskAttemptId) + taskMemory.remove(taskAttemptId).foreach { numBytes => + memoryManager.releaseExecutionMemory(numBytes) + } notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed } @@ -138,30 +151,28 @@ class ShuffleMemoryManager protected ( private[spark] object ShuffleMemoryManager { - def create(conf: SparkConf, numCores: Int): ShuffleMemoryManager = { - val maxMemory = ShuffleMemoryManager.getMaxMemory(conf) + def create( + conf: SparkConf, + memoryManager: MemoryManager, + numCores: Int): ShuffleMemoryManager = { + val maxMemory = memoryManager.maxExecutionMemory val pageSize = ShuffleMemoryManager.getPageSize(conf, maxMemory, numCores) - new ShuffleMemoryManager(maxMemory, pageSize) + new ShuffleMemoryManager(memoryManager, pageSize) } + /** + * Create a dummy [[ShuffleMemoryManager]] with the specified capacity and page size. + */ def create(maxMemory: Long, pageSizeBytes: Long): ShuffleMemoryManager = { - new ShuffleMemoryManager(maxMemory, pageSizeBytes) + val conf = new SparkConf + val memoryManager = new StaticMemoryManager( + conf, maxExecutionMemory = maxMemory, maxStorageMemory = Long.MaxValue) + new ShuffleMemoryManager(memoryManager, pageSizeBytes) } @VisibleForTesting def createForTesting(maxMemory: Long): ShuffleMemoryManager = { - new ShuffleMemoryManager(maxMemory, 4 * 1024 * 1024) - } - - /** - * Figure out the shuffle memory limit from a SparkConf. We currently have both a fraction - * of the memory pool and a safety factor since collections can sometimes grow bigger than - * the size we target before we estimate their sizes again. - */ - private def getMaxMemory(conf: SparkConf): Long = { - val memoryFraction = conf.getDouble("spark.shuffle.memoryFraction", 0.2) - val safetyFraction = conf.getDouble("spark.shuffle.safetyFraction", 0.8) - (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong + create(maxMemory, 4 * 1024 * 1024) } /** diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 47bd2ef8b2941..9f5bd2abbdc5d 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -31,6 +31,7 @@ import sun.nio.ch.DirectBuffer import org.apache.spark._ import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics} import org.apache.spark.io.CompressionCodec +import org.apache.spark.memory.MemoryManager import org.apache.spark.network._ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.netty.SparkTransportConf @@ -64,8 +65,8 @@ private[spark] class BlockManager( rpcEnv: RpcEnv, val master: BlockManagerMaster, defaultSerializer: Serializer, - maxMemory: Long, val conf: SparkConf, + memoryManager: MemoryManager, mapOutputTracker: MapOutputTracker, shuffleManager: ShuffleManager, blockTransferService: BlockTransferService, @@ -82,12 +83,15 @@ private[spark] class BlockManager( // Actual storage of where blocks are kept private var externalBlockStoreInitialized = false - private[spark] val memoryStore = new MemoryStore(this, maxMemory) + private[spark] val memoryStore = new MemoryStore(this, memoryManager) private[spark] val diskStore = new DiskStore(this, diskBlockManager) private[spark] lazy val externalBlockStore: ExternalBlockStore = { externalBlockStoreInitialized = true new ExternalBlockStore(this, executorId) } + memoryManager.setMemoryStore(memoryStore) + + private val maxMemory = memoryManager.maxStorageMemory private[spark] val externalShuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) @@ -157,24 +161,6 @@ private[spark] class BlockManager( * loaded yet. */ private lazy val compressionCodec: CompressionCodec = CompressionCodec.createCodec(conf) - /** - * Construct a BlockManager with a memory limit set based on system properties. - */ - def this( - execId: String, - rpcEnv: RpcEnv, - master: BlockManagerMaster, - serializer: Serializer, - conf: SparkConf, - mapOutputTracker: MapOutputTracker, - shuffleManager: ShuffleManager, - blockTransferService: BlockTransferService, - securityManager: SecurityManager, - numUsableCores: Int) = { - this(execId, rpcEnv, master, serializer, BlockManager.getMaxMemory(conf), - conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager, numUsableCores) - } - /** * Initializes the BlockManager with the given appId. This is not performed in the constructor as * the appId may not be known at BlockManager instantiation time (in particular for the driver, @@ -1267,13 +1253,6 @@ private[spark] class BlockManager( private[spark] object BlockManager extends Logging { private val ID_GENERATOR = new IdGenerator - /** Return the total amount of storage memory available. */ - private def getMaxMemory(conf: SparkConf): Long = { - val memoryFraction = conf.getDouble("spark.storage.memoryFraction", 0.6) - val safetyFraction = conf.getDouble("spark.storage.safetyFraction", 0.9) - (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong - } - /** * Attempt to clean up a ByteBuffer if it is memory-mapped. This uses an *unsafe* Sun API that * might cause errors if one attempts to read from the unmapped buffer, but it's better than diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index 6f27f00307f8c..35c57b923c43a 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -24,6 +24,7 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.spark.TaskContext +import org.apache.spark.memory.MemoryManager import org.apache.spark.util.{SizeEstimator, Utils} import org.apache.spark.util.collection.SizeTrackingVector @@ -33,13 +34,12 @@ private case class MemoryEntry(value: Any, size: Long, deserialized: Boolean) * Stores blocks in memory, either as Arrays of deserialized Java objects or as * serialized ByteBuffers. */ -private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) +private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: MemoryManager) extends BlockStore(blockManager) { private val conf = blockManager.conf private val entries = new LinkedHashMap[BlockId, MemoryEntry](32, 0.75f, true) - - @volatile private var currentMemory = 0L + private val maxMemory = memoryManager.maxStorageMemory // Ensure only one thread is putting, and if necessary, dropping blocks at any given time private val accountingLock = new Object @@ -56,15 +56,6 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) // memory (SPARK-4777). private val pendingUnrollMemoryMap = mutable.HashMap[Long, Long]() - /** - * The amount of space ensured for unrolling values in memory, shared across all cores. - * This space is not reserved in advance, but allocated dynamically by dropping existing blocks. - */ - private val maxUnrollMemory: Long = { - val unrollFraction = conf.getDouble("spark.storage.unrollFraction", 0.2) - (maxMemory * unrollFraction).toLong - } - // Initial memory to request before unrolling any block private val unrollMemoryThreshold: Long = conf.getLong("spark.storage.unrollMemoryThreshold", 1024 * 1024) @@ -77,8 +68,14 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) logInfo("MemoryStore started with capacity %s".format(Utils.bytesToString(maxMemory))) - /** Free memory not occupied by existing blocks. Note that this does not include unroll memory. */ - def freeMemory: Long = maxMemory - currentMemory + /** Total storage memory used including unroll memory, in bytes. */ + private def memoryUsed: Long = memoryManager.storageMemoryUsed + + /** + * Amount of storage memory, in bytes, used for caching blocks. + * This does not include memory used for unrolling. + */ + private def blocksMemoryUsed: Long = memoryUsed - currentUnrollMemory override def getSize(blockId: BlockId): Long = { entries.synchronized { @@ -94,8 +91,9 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) val values = blockManager.dataDeserialize(blockId, bytes) putIterator(blockId, values, level, returnValues = true) } else { - val putAttempt = tryToPut(blockId, bytes, bytes.limit, deserialized = false) - PutResult(bytes.limit(), Right(bytes.duplicate()), putAttempt.droppedBlocks) + val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] + tryToPut(blockId, bytes, bytes.limit, deserialized = false, droppedBlocks) + PutResult(bytes.limit(), Right(bytes.duplicate()), droppedBlocks) } } @@ -108,15 +106,16 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) def putBytes(blockId: BlockId, size: Long, _bytes: () => ByteBuffer): PutResult = { // Work on a duplicate - since the original input might be used elsewhere. lazy val bytes = _bytes().duplicate().rewind().asInstanceOf[ByteBuffer] - val putAttempt = tryToPut(blockId, () => bytes, size, deserialized = false) + val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] + val putSuccess = tryToPut(blockId, () => bytes, size, deserialized = false, droppedBlocks) val data = - if (putAttempt.success) { + if (putSuccess) { assert(bytes.limit == size) Right(bytes.duplicate()) } else { null } - PutResult(size, data, putAttempt.droppedBlocks) + PutResult(size, data, droppedBlocks) } override def putArray( @@ -124,14 +123,15 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) values: Array[Any], level: StorageLevel, returnValues: Boolean): PutResult = { + val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] if (level.deserialized) { val sizeEstimate = SizeEstimator.estimate(values.asInstanceOf[AnyRef]) - val putAttempt = tryToPut(blockId, values, sizeEstimate, deserialized = true) - PutResult(sizeEstimate, Left(values.iterator), putAttempt.droppedBlocks) + tryToPut(blockId, values, sizeEstimate, deserialized = true, droppedBlocks) + PutResult(sizeEstimate, Left(values.iterator), droppedBlocks) } else { val bytes = blockManager.dataSerialize(blockId, values.iterator) - val putAttempt = tryToPut(blockId, bytes, bytes.limit, deserialized = false) - PutResult(bytes.limit(), Right(bytes.duplicate()), putAttempt.droppedBlocks) + tryToPut(blockId, bytes, bytes.limit, deserialized = false, droppedBlocks) + PutResult(bytes.limit(), Right(bytes.duplicate()), droppedBlocks) } } @@ -209,23 +209,22 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) } override def remove(blockId: BlockId): Boolean = { - entries.synchronized { - val entry = entries.remove(blockId) - if (entry != null) { - currentMemory -= entry.size - logDebug(s"Block $blockId of size ${entry.size} dropped from memory (free $freeMemory)") - true - } else { - false - } + val entry = entries.synchronized { entries.remove(blockId) } + if (entry != null) { + memoryManager.releaseStorageMemory(entry.size) + logDebug(s"Block $blockId of size ${entry.size} dropped " + + s"from memory (free ${maxMemory - blocksMemoryUsed})") + true + } else { + false } } override def clear() { entries.synchronized { entries.clear() - currentMemory = 0 } + memoryManager.releaseStorageMemory() logInfo("MemoryStore cleared") } @@ -265,7 +264,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) var vector = new SizeTrackingVector[Any] // Request enough memory to begin unrolling - keepUnrolling = reserveUnrollMemoryForThisTask(initialMemoryThreshold) + keepUnrolling = reserveUnrollMemoryForThisTask(blockId, initialMemoryThreshold, droppedBlocks) if (!keepUnrolling) { logWarning(s"Failed to reserve initial memory threshold of " + @@ -281,20 +280,8 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) val currentSize = vector.estimateSize() if (currentSize >= memoryThreshold) { val amountToRequest = (currentSize * memoryGrowthFactor - memoryThreshold).toLong - // Hold the accounting lock, in case another thread concurrently puts a block that - // takes up the unrolling space we just ensured here - accountingLock.synchronized { - if (!reserveUnrollMemoryForThisTask(amountToRequest)) { - // If the first request is not granted, try again after ensuring free space - // If there is still not enough space, give up and drop the partition - val spaceToEnsure = maxUnrollMemory - currentUnrollMemory - if (spaceToEnsure > 0) { - val result = ensureFreeSpace(blockId, spaceToEnsure) - droppedBlocks ++= result.droppedBlocks - } - keepUnrolling = reserveUnrollMemoryForThisTask(amountToRequest) - } - } + keepUnrolling = reserveUnrollMemoryForThisTask( + blockId, amountToRequest, droppedBlocks) // New threshold is currentSize * memoryGrowthFactor memoryThreshold += amountToRequest } @@ -317,10 +304,16 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) // Otherwise, if we return an iterator, we release the memory reserved here // later when the task finishes. if (keepUnrolling) { + val taskAttemptId = currentTaskAttemptId() accountingLock.synchronized { - val amountToRelease = currentUnrollMemoryForThisTask - previousMemoryReserved - releaseUnrollMemoryForThisTask(amountToRelease) - reservePendingUnrollMemoryForThisTask(amountToRelease) + // Here, we transfer memory from unroll to pending unroll because we expect to cache this + // block in `tryToPut`. We do not release and re-acquire memory from the MemoryManager in + // order to avoid race conditions where another component steals the memory that we're + // trying to transfer. + val amountToTransferToPending = currentUnrollMemoryForThisTask - previousMemoryReserved + unrollMemoryMap(taskAttemptId) -= amountToTransferToPending + pendingUnrollMemoryMap(taskAttemptId) = + pendingUnrollMemoryMap.getOrElse(taskAttemptId, 0L) + amountToTransferToPending } } } @@ -337,8 +330,9 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) blockId: BlockId, value: Any, size: Long, - deserialized: Boolean): ResultWithDroppedBlocks = { - tryToPut(blockId, () => value, size, deserialized) + deserialized: Boolean, + droppedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = { + tryToPut(blockId, () => value, size, deserialized, droppedBlocks) } /** @@ -354,13 +348,16 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) * blocks to free memory for one block, another thread may use up the freed space for * another block. * - * Return whether put was successful, along with the blocks dropped in the process. + * All blocks evicted in the process, if any, will be added to `droppedBlocks`. + * + * @return whether put was successful. */ private def tryToPut( blockId: BlockId, value: () => Any, size: Long, - deserialized: Boolean): ResultWithDroppedBlocks = { + deserialized: Boolean, + droppedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = { /* TODO: Its possible to optimize the locking by locking entries only when selecting blocks * to be dropped. Once the to-be-dropped blocks have been selected, and lock on entries has @@ -368,24 +365,27 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) * for freeing up more space for another block that needs to be put. Only then the actually * dropping of blocks (and writing to disk if necessary) can proceed in parallel. */ - var putSuccess = false - val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] - accountingLock.synchronized { - val freeSpaceResult = ensureFreeSpace(blockId, size) - val enoughFreeSpace = freeSpaceResult.success - droppedBlocks ++= freeSpaceResult.droppedBlocks - - if (enoughFreeSpace) { + // Note: if we have previously unrolled this block successfully, then pending unroll + // memory should be non-zero. This is the amount that we already reserved during the + // unrolling process. In this case, we can just reuse this space to cache our block. + // + // Note: the StaticMemoryManager counts unroll memory as storage memory. Here, the + // synchronization on `accountingLock` guarantees that the release of unroll memory and + // acquisition of storage memory happens atomically. However, if storage memory is acquired + // outside of MemoryStore or if unroll memory is counted as execution memory, then we will + // have to revisit this assumption. See SPARK-10983 for more context. + releasePendingUnrollMemoryForThisTask() + val enoughMemory = memoryManager.acquireStorageMemory(blockId, size, droppedBlocks) + if (enoughMemory) { + // We acquired enough memory for the block, so go ahead and put it val entry = new MemoryEntry(value(), size, deserialized) entries.synchronized { entries.put(blockId, entry) - currentMemory += size } val valuesOrBytes = if (deserialized) "values" else "bytes" logInfo("Block %s stored as %s in memory (estimated size %s, free %s)".format( - blockId, valuesOrBytes, Utils.bytesToString(size), Utils.bytesToString(freeMemory))) - putSuccess = true + blockId, valuesOrBytes, Utils.bytesToString(size), Utils.bytesToString(blocksMemoryUsed))) } else { // Tell the block manager that we couldn't put it in memory so that it can drop it to // disk if the block allows disk storage. @@ -397,10 +397,8 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) val droppedBlockStatus = blockManager.dropFromMemory(blockId, () => data) droppedBlockStatus.foreach { status => droppedBlocks += ((blockId, status)) } } - // Release the unroll memory used because we no longer need the underlying Array - releasePendingUnrollMemoryForThisTask() + enoughMemory } - ResultWithDroppedBlocks(putSuccess, droppedBlocks) } /** @@ -409,40 +407,42 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) * from the same RDD (which leads to a wasteful cyclic replacement pattern for RDDs that * don't fit into memory that we want to avoid). * - * Assume that `accountingLock` is held by the caller to ensure only one thread is dropping - * blocks. Otherwise, the freed space may fill up before the caller puts in their new value. - * - * Return whether there is enough free space, along with the blocks dropped in the process. + * @param blockId the ID of the block we are freeing space for + * @param space the size of this block + * @param droppedBlocks a holder for blocks evicted in the process + * @return whether there is enough free space. */ - private def ensureFreeSpace( - blockIdToAdd: BlockId, - space: Long): ResultWithDroppedBlocks = { - logInfo(s"ensureFreeSpace($space) called with curMem=$currentMemory, maxMem=$maxMemory") - - val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] + private[spark] def ensureFreeSpace( + blockId: BlockId, + space: Long, + droppedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = { + accountingLock.synchronized { + val freeMemory = maxMemory - memoryUsed + val rddToAdd = getRddId(blockId) + val selectedBlocks = new ArrayBuffer[BlockId] + var selectedMemory = 0L - if (space > maxMemory) { - logInfo(s"Will not store $blockIdToAdd as it is larger than our memory limit") - return ResultWithDroppedBlocks(success = false, droppedBlocks) - } + logInfo(s"Ensuring $space bytes of free space for block $blockId " + + s"(free: $freeMemory, max: $maxMemory)") - // Take into account the amount of memory currently occupied by unrolling blocks - // and minus the pending unroll memory for that block on current thread. - val taskAttemptId = currentTaskAttemptId() - val actualFreeMemory = freeMemory - currentUnrollMemory + - pendingUnrollMemoryMap.getOrElse(taskAttemptId, 0L) + // Fail fast if the block simply won't fit + if (space > maxMemory) { + logInfo(s"Will not store $blockId as the required space " + + s"($space bytes) than our memory limit ($maxMemory bytes)") + return false + } - if (actualFreeMemory < space) { - val rddToAdd = getRddId(blockIdToAdd) - val selectedBlocks = new ArrayBuffer[BlockId] - var selectedMemory = 0L + // No need to evict anything if there is already enough free space + if (freeMemory >= space) { + return true + } // This is synchronized to ensure that the set of entries is not changed // (because of getValue or getBytes) while traversing the iterator, as that // can lead to exceptions. entries.synchronized { val iterator = entries.entrySet().iterator() - while (actualFreeMemory + selectedMemory < space && iterator.hasNext) { + while (freeMemory + selectedMemory < space && iterator.hasNext) { val pair = iterator.next() val blockId = pair.getKey if (rddToAdd.isEmpty || rddToAdd != getRddId(blockId)) { @@ -452,7 +452,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) } } - if (actualFreeMemory + selectedMemory >= space) { + if (freeMemory + selectedMemory >= space) { logInfo(s"${selectedBlocks.size} blocks selected for dropping") for (blockId <- selectedBlocks) { val entry = entries.synchronized { entries.get(blockId) } @@ -469,14 +469,13 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) droppedBlockStatus.foreach { status => droppedBlocks += ((blockId, status)) } } } - return ResultWithDroppedBlocks(success = true, droppedBlocks) + true } else { - logInfo(s"Will not store $blockIdToAdd as it would require dropping another block " + + logInfo(s"Will not store $blockId as it would require dropping another block " + "from the same RDD") - return ResultWithDroppedBlocks(success = false, droppedBlocks) + false } } - ResultWithDroppedBlocks(success = true, droppedBlocks) } override def contains(blockId: BlockId): Boolean = { @@ -489,17 +488,21 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) } /** - * Reserve additional memory for unrolling blocks used by this task. - * Return whether the request is granted. + * Reserve memory for unrolling the given block for this task. + * @return whether the request is granted. */ - def reserveUnrollMemoryForThisTask(memory: Long): Boolean = { + def reserveUnrollMemoryForThisTask( + blockId: BlockId, + memory: Long, + droppedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = { accountingLock.synchronized { - val granted = freeMemory > currentUnrollMemory + memory - if (granted) { + // Note: all acquisitions of unroll memory must be synchronized on `accountingLock` + val success = memoryManager.acquireUnrollMemory(blockId, memory, droppedBlocks) + if (success) { val taskAttemptId = currentTaskAttemptId() unrollMemoryMap(taskAttemptId) = unrollMemoryMap.getOrElse(taskAttemptId, 0L) + memory } - granted + success } } @@ -507,40 +510,38 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) * Release memory used by this task for unrolling blocks. * If the amount is not specified, remove the current task's allocation altogether. */ - def releaseUnrollMemoryForThisTask(memory: Long = -1L): Unit = { + def releaseUnrollMemoryForThisTask(memory: Long = Long.MaxValue): Unit = { val taskAttemptId = currentTaskAttemptId() accountingLock.synchronized { - if (memory < 0) { - unrollMemoryMap.remove(taskAttemptId) - } else { - unrollMemoryMap(taskAttemptId) = unrollMemoryMap.getOrElse(taskAttemptId, memory) - memory - // If this task claims no more unroll memory, release it completely - if (unrollMemoryMap(taskAttemptId) <= 0) { - unrollMemoryMap.remove(taskAttemptId) + if (unrollMemoryMap.contains(taskAttemptId)) { + val memoryToRelease = math.min(memory, unrollMemoryMap(taskAttemptId)) + if (memoryToRelease > 0) { + unrollMemoryMap(taskAttemptId) -= memoryToRelease + if (unrollMemoryMap(taskAttemptId) == 0) { + unrollMemoryMap.remove(taskAttemptId) + } + memoryManager.releaseUnrollMemory(memoryToRelease) } } } } - /** - * Reserve the unroll memory of current unroll successful block used by this task - * until actually put the block into memory entry. - */ - def reservePendingUnrollMemoryForThisTask(memory: Long): Unit = { - val taskAttemptId = currentTaskAttemptId() - accountingLock.synchronized { - pendingUnrollMemoryMap(taskAttemptId) = - pendingUnrollMemoryMap.getOrElse(taskAttemptId, 0L) + memory - } - } - /** * Release pending unroll memory of current unroll successful block used by this task */ - def releasePendingUnrollMemoryForThisTask(): Unit = { + def releasePendingUnrollMemoryForThisTask(memory: Long = Long.MaxValue): Unit = { val taskAttemptId = currentTaskAttemptId() accountingLock.synchronized { - pendingUnrollMemoryMap.remove(taskAttemptId) + if (pendingUnrollMemoryMap.contains(taskAttemptId)) { + val memoryToRelease = math.min(memory, pendingUnrollMemoryMap(taskAttemptId)) + if (memoryToRelease > 0) { + pendingUnrollMemoryMap(taskAttemptId) -= memoryToRelease + if (pendingUnrollMemoryMap(taskAttemptId) == 0) { + pendingUnrollMemoryMap.remove(taskAttemptId) + } + memoryManager.releaseUnrollMemory(memoryToRelease) + } + } } } @@ -561,19 +562,16 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) /** * Return the number of tasks currently unrolling blocks. */ - def numTasksUnrolling: Int = accountingLock.synchronized { unrollMemoryMap.keys.size } + private def numTasksUnrolling: Int = accountingLock.synchronized { unrollMemoryMap.keys.size } /** * Log information about current memory usage. */ - def logMemoryUsage(): Unit = { - val blocksMemory = currentMemory - val unrollMemory = currentUnrollMemory - val totalMemory = blocksMemory + unrollMemory + private def logMemoryUsage(): Unit = { logInfo( - s"Memory use = ${Utils.bytesToString(blocksMemory)} (blocks) + " + - s"${Utils.bytesToString(unrollMemory)} (scratch space shared across " + - s"$numTasksUnrolling tasks(s)) = ${Utils.bytesToString(totalMemory)}. " + + s"Memory use = ${Utils.bytesToString(blocksMemoryUsed)} (blocks) + " + + s"${Utils.bytesToString(currentUnrollMemory)} (scratch space shared across " + + s"$numTasksUnrolling tasks(s)) = ${Utils.bytesToString(memoryUsed)}. " + s"Storage limit = ${Utils.bytesToString(maxMemory)}." ) } @@ -584,7 +582,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) * @param blockId ID of the block we are trying to unroll. * @param finalVectorSize Final size of the vector before unrolling failed. */ - def logUnrollFailureMessage(blockId: BlockId, finalVectorSize: Long): Unit = { + private def logUnrollFailureMessage(blockId: BlockId, finalVectorSize: Long): Unit = { logWarning( s"Not enough space to cache $blockId in memory! " + s"(computed ${Utils.bytesToString(finalVectorSize)} so far)" @@ -592,7 +590,3 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) logMemoryUsage() } } - -private[spark] case class ResultWithDroppedBlocks( - success: Boolean, - droppedBlocks: Seq[(BlockId, BlockStatus)]) diff --git a/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala new file mode 100644 index 0000000000000..c436a8b5c9f81 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.memory + +import scala.collection.mutable.ArrayBuffer + +import org.mockito.Mockito.{mock, reset, verify, when} +import org.mockito.Matchers.{any, eq => meq} + +import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore, TestBlockId} +import org.apache.spark.{SparkConf, SparkFunSuite} + + +class StaticMemoryManagerSuite extends SparkFunSuite { + private val conf = new SparkConf().set("spark.storage.unrollFraction", "0.4") + + test("basic execution memory") { + val maxExecutionMem = 1000L + val (mm, _) = makeThings(maxExecutionMem, Long.MaxValue) + assert(mm.executionMemoryUsed === 0L) + assert(mm.acquireExecutionMemory(10L) === 10L) + assert(mm.executionMemoryUsed === 10L) + assert(mm.acquireExecutionMemory(100L) === 100L) + // Acquire up to the max + assert(mm.acquireExecutionMemory(1000L) === 890L) + assert(mm.executionMemoryUsed === maxExecutionMem) + assert(mm.acquireExecutionMemory(1L) === 0L) + assert(mm.executionMemoryUsed === maxExecutionMem) + mm.releaseExecutionMemory(800L) + assert(mm.executionMemoryUsed === 200L) + // Acquire after release + assert(mm.acquireExecutionMemory(1L) === 1L) + assert(mm.executionMemoryUsed === 201L) + // Release beyond what was acquired + mm.releaseExecutionMemory(maxExecutionMem) + assert(mm.executionMemoryUsed === 0L) + } + + test("basic storage memory") { + val maxStorageMem = 1000L + val dummyBlock = TestBlockId("you can see the world you brought to live") + val evictedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] + val (mm, ms) = makeThings(Long.MaxValue, maxStorageMem) + assert(mm.storageMemoryUsed === 0L) + assert(mm.acquireStorageMemory(dummyBlock, 10L, evictedBlocks)) + // `ensureFreeSpace` should be called with the number of bytes requested + assertEnsureFreeSpaceCalled(ms, dummyBlock, 10L) + assert(mm.storageMemoryUsed === 10L) + assert(evictedBlocks.isEmpty) + assert(mm.acquireStorageMemory(dummyBlock, 100L, evictedBlocks)) + assertEnsureFreeSpaceCalled(ms, dummyBlock, 100L) + assert(mm.storageMemoryUsed === 110L) + // Acquire up to the max, not granted + assert(!mm.acquireStorageMemory(dummyBlock, 1000L, evictedBlocks)) + assertEnsureFreeSpaceCalled(ms, dummyBlock, 1000L) + assert(mm.storageMemoryUsed === 110L) + assert(mm.acquireStorageMemory(dummyBlock, 890L, evictedBlocks)) + assertEnsureFreeSpaceCalled(ms, dummyBlock, 890L) + assert(mm.storageMemoryUsed === 1000L) + assert(!mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks)) + assertEnsureFreeSpaceCalled(ms, dummyBlock, 1L) + assert(mm.storageMemoryUsed === 1000L) + mm.releaseStorageMemory(800L) + assert(mm.storageMemoryUsed === 200L) + // Acquire after release + assert(mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks)) + assertEnsureFreeSpaceCalled(ms, dummyBlock, 1L) + assert(mm.storageMemoryUsed === 201L) + mm.releaseStorageMemory() + assert(mm.storageMemoryUsed === 0L) + assert(mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks)) + assertEnsureFreeSpaceCalled(ms, dummyBlock, 1L) + assert(mm.storageMemoryUsed === 1L) + // Release beyond what was acquired + mm.releaseStorageMemory(100L) + assert(mm.storageMemoryUsed === 0L) + } + + test("execution and storage isolation") { + val maxExecutionMem = 200L + val maxStorageMem = 1000L + val dummyBlock = TestBlockId("ain't nobody love like you do") + val dummyBlocks = new ArrayBuffer[(BlockId, BlockStatus)] + val (mm, ms) = makeThings(maxExecutionMem, maxStorageMem) + // Only execution memory should increase + assert(mm.acquireExecutionMemory(100L) === 100L) + assert(mm.storageMemoryUsed === 0L) + assert(mm.executionMemoryUsed === 100L) + assert(mm.acquireExecutionMemory(1000L) === 100L) + assert(mm.storageMemoryUsed === 0L) + assert(mm.executionMemoryUsed === 200L) + // Only storage memory should increase + assert(mm.acquireStorageMemory(dummyBlock, 50L, dummyBlocks)) + assertEnsureFreeSpaceCalled(ms, dummyBlock, 50L) + assert(mm.storageMemoryUsed === 50L) + assert(mm.executionMemoryUsed === 200L) + // Only execution memory should be released + mm.releaseExecutionMemory(133L) + assert(mm.storageMemoryUsed === 50L) + assert(mm.executionMemoryUsed === 67L) + // Only storage memory should be released + mm.releaseStorageMemory() + assert(mm.storageMemoryUsed === 0L) + assert(mm.executionMemoryUsed === 67L) + } + + test("unroll memory") { + val maxStorageMem = 1000L + val dummyBlock = TestBlockId("lonely water") + val dummyBlocks = new ArrayBuffer[(BlockId, BlockStatus)] + val (mm, ms) = makeThings(Long.MaxValue, maxStorageMem) + assert(mm.acquireUnrollMemory(dummyBlock, 100L, dummyBlocks)) + assertEnsureFreeSpaceCalled(ms, dummyBlock, 100L) + assert(mm.storageMemoryUsed === 100L) + mm.releaseUnrollMemory(40L) + assert(mm.storageMemoryUsed === 60L) + when(ms.currentUnrollMemory).thenReturn(60L) + assert(mm.acquireUnrollMemory(dummyBlock, 500L, dummyBlocks)) + // `spark.storage.unrollFraction` is 0.4, so the max unroll space is 400 bytes. + // Since we already occupy 60 bytes, we will try to ensure only 400 - 60 = 340 bytes. + assertEnsureFreeSpaceCalled(ms, dummyBlock, 340L) + assert(mm.storageMemoryUsed === 560L) + when(ms.currentUnrollMemory).thenReturn(560L) + assert(!mm.acquireUnrollMemory(dummyBlock, 800L, dummyBlocks)) + assert(mm.storageMemoryUsed === 560L) + // We already have 560 bytes > the max unroll space of 400 bytes, so no bytes are freed + assertEnsureFreeSpaceCalled(ms, dummyBlock, 0L) + // Release beyond what was acquired + mm.releaseUnrollMemory(maxStorageMem) + assert(mm.storageMemoryUsed === 0L) + } + + /** + * Make a [[StaticMemoryManager]] and a [[MemoryStore]] with limited class dependencies. + */ + private def makeThings( + maxExecutionMem: Long, + maxStorageMem: Long): (StaticMemoryManager, MemoryStore) = { + val mm = new StaticMemoryManager( + conf, maxExecutionMemory = maxExecutionMem, maxStorageMemory = maxStorageMem) + val ms = mock(classOf[MemoryStore]) + mm.setMemoryStore(ms) + (mm, ms) + } + + /** + * Assert that [[MemoryStore.ensureFreeSpace]] is called with the given parameters. + */ + private def assertEnsureFreeSpaceCalled( + ms: MemoryStore, + blockId: BlockId, + numBytes: Long): Unit = { + verify(ms).ensureFreeSpace(meq(blockId), meq(numBytes: java.lang.Long), any()) + reset(ms) + } + +} diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index eb5af70d57aec..cc44c676b27ac 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -29,6 +29,7 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.rpc.RpcEnv import org.apache.spark._ +import org.apache.spark.memory.StaticMemoryManager import org.apache.spark.network.BlockTransferService import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.KryoSerializer @@ -39,29 +40,31 @@ import org.apache.spark.storage.StorageLevel._ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with BeforeAndAfter { private val conf = new SparkConf(false).set("spark.app.id", "test") - var rpcEnv: RpcEnv = null - var master: BlockManagerMaster = null - val securityMgr = new SecurityManager(conf) - val mapOutputTracker = new MapOutputTrackerMaster(conf) - val shuffleManager = new HashShuffleManager(conf) + private var rpcEnv: RpcEnv = null + private var master: BlockManagerMaster = null + private val securityMgr = new SecurityManager(conf) + private val mapOutputTracker = new MapOutputTrackerMaster(conf) + private val shuffleManager = new HashShuffleManager(conf) // List of block manager created during an unit test, so that all of the them can be stopped // after the unit test. - val allStores = new ArrayBuffer[BlockManager] + private val allStores = new ArrayBuffer[BlockManager] // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test conf.set("spark.kryoserializer.buffer", "1m") - val serializer = new KryoSerializer(conf) + private val serializer = new KryoSerializer(conf) // Implicitly convert strings to BlockIds for test clarity. - implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value) + private implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value) private def makeBlockManager( maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1) - val store = new BlockManager(name, rpcEnv, master, serializer, maxMem, conf, - mapOutputTracker, shuffleManager, transfer, securityMgr, 0) + val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem) + val store = new BlockManager(name, rpcEnv, master, serializer, conf, + memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) + memManager.setMemoryStore(store.memoryStore) store.initialize("app-id") allStores += store store @@ -258,8 +261,10 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo val failableTransfer = mock(classOf[BlockTransferService]) // this wont actually work when(failableTransfer.hostName).thenReturn("some-hostname") when(failableTransfer.port).thenReturn(1000) - val failableStore = new BlockManager("failable-store", rpcEnv, master, serializer, - 10000, conf, mapOutputTracker, shuffleManager, failableTransfer, securityMgr, 0) + val memManager = new StaticMemoryManager(conf, Long.MaxValue, 10000) + val failableStore = new BlockManager("failable-store", rpcEnv, master, serializer, conf, + memManager, mapOutputTracker, shuffleManager, failableTransfer, securityMgr, 0) + memManager.setMemoryStore(failableStore.memoryStore) failableStore.initialize("app-id") allStores += failableStore // so that this gets stopped after test assert(master.getPeers(store.blockManagerId).toSet === Set(failableStore.blockManagerId)) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 34bb4952e7246..f3fab33ca2e31 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -34,6 +34,7 @@ import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.rpc.RpcEnv import org.apache.spark._ import org.apache.spark.executor.DataReadMethod +import org.apache.spark.memory.StaticMemoryManager import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.shuffle.hash.HashShuffleManager @@ -67,10 +68,12 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1) - val manager = new BlockManager(name, rpcEnv, master, serializer, maxMem, conf, - mapOutputTracker, shuffleManager, transfer, securityMgr, 0) - manager.initialize("app-id") - manager + val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem) + val blockManager = new BlockManager(name, rpcEnv, master, serializer, conf, + memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) + memManager.setMemoryStore(blockManager.memoryStore) + blockManager.initialize("app-id") + blockManager } override def beforeEach(): Unit = { @@ -820,9 +823,11 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE test("block store put failure") { // Use Java serializer so we can create an unserializable error. val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1) + val memoryManager = new StaticMemoryManager(conf, Long.MaxValue, 1200) store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, rpcEnv, master, - new JavaSerializer(conf), 1200, conf, mapOutputTracker, shuffleManager, transfer, securityMgr, - 0) + new JavaSerializer(conf), conf, memoryManager, mapOutputTracker, + shuffleManager, transfer, securityMgr, 0) + memoryManager.setMemoryStore(store.memoryStore) // The put should fail since a1 is not serializable. class UnserializableClass @@ -1043,14 +1048,19 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(memoryStore.currentUnrollMemory === 0) assert(memoryStore.currentUnrollMemoryForThisTask === 0) + def reserveUnrollMemoryForThisTask(memory: Long): Boolean = { + memoryStore.reserveUnrollMemoryForThisTask( + TestBlockId(""), memory, new ArrayBuffer[(BlockId, BlockStatus)]) + } + // Reserve - memoryStore.reserveUnrollMemoryForThisTask(100) + assert(reserveUnrollMemoryForThisTask(100)) assert(memoryStore.currentUnrollMemoryForThisTask === 100) - memoryStore.reserveUnrollMemoryForThisTask(200) + assert(reserveUnrollMemoryForThisTask(200)) assert(memoryStore.currentUnrollMemoryForThisTask === 300) - memoryStore.reserveUnrollMemoryForThisTask(500) + assert(reserveUnrollMemoryForThisTask(500)) assert(memoryStore.currentUnrollMemoryForThisTask === 800) - memoryStore.reserveUnrollMemoryForThisTask(1000000) + assert(!reserveUnrollMemoryForThisTask(1000000)) assert(memoryStore.currentUnrollMemoryForThisTask === 800) // not granted // Release memoryStore.releaseUnrollMemoryForThisTask(100) @@ -1058,9 +1068,9 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE memoryStore.releaseUnrollMemoryForThisTask(100) assert(memoryStore.currentUnrollMemoryForThisTask === 600) // Reserve again - memoryStore.reserveUnrollMemoryForThisTask(4400) + assert(reserveUnrollMemoryForThisTask(4400)) assert(memoryStore.currentUnrollMemoryForThisTask === 5000) - memoryStore.reserveUnrollMemoryForThisTask(20000) + assert(!reserveUnrollMemoryForThisTask(20000)) assert(memoryStore.currentUnrollMemoryForThisTask === 5000) // not granted // Release again memoryStore.releaseUnrollMemoryForThisTask(1000) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala index 48c3938ff87ba..ff65d7bdf8b92 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala @@ -17,12 +17,18 @@ package org.apache.spark.sql.execution +import scala.collection.mutable + +import org.apache.spark.memory.MemoryManager import org.apache.spark.shuffle.ShuffleMemoryManager +import org.apache.spark.storage.{BlockId, BlockStatus} + /** * A [[ShuffleMemoryManager]] that can be controlled to run out of memory. */ -class TestShuffleMemoryManager extends ShuffleMemoryManager(Long.MaxValue, 4 * 1024 * 1024) { +class TestShuffleMemoryManager + extends ShuffleMemoryManager(new GrantEverythingMemoryManager, 4 * 1024 * 1024) { private var oom = false override def tryToAcquire(numBytes: Long): Long = { @@ -49,3 +55,23 @@ class TestShuffleMemoryManager extends ShuffleMemoryManager(Long.MaxValue, 4 * 1 oom = true } } + +private class GrantEverythingMemoryManager extends MemoryManager { + override def acquireExecutionMemory(numBytes: Long): Long = numBytes + override def acquireStorageMemory( + blockId: BlockId, + numBytes: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = true + override def acquireUnrollMemory( + blockId: BlockId, + numBytes: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = true + override def releaseExecutionMemory(numBytes: Long): Unit = { } + override def releaseStorageMemory(numBytes: Long): Unit = { } + override def releaseStorageMemory(): Unit = { } + override def releaseUnrollMemory(numBytes: Long): Unit = { } + override def maxExecutionMemory: Long = Long.MaxValue + override def maxStorageMemory: Long = Long.MaxValue + override def executionMemoryUsed: Long = Long.MaxValue + override def storageMemoryUsed: Long = Long.MaxValue +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index 13cfe29d7b304..b2b6848719639 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -29,6 +29,7 @@ import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ import org.apache.spark._ +import org.apache.spark.memory.StaticMemoryManager import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus @@ -253,12 +254,14 @@ class ReceivedBlockHandlerSuite maxMem: Long, conf: SparkConf, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { + val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem) val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1) - val manager = new BlockManager(name, rpcEnv, blockManagerMaster, serializer, maxMem, conf, - mapOutputTracker, shuffleManager, transfer, securityMgr, 0) - manager.initialize("app-id") - blockManagerBuffer += manager - manager + val blockManager = new BlockManager(name, rpcEnv, blockManagerMaster, serializer, conf, + memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) + memManager.setMemoryStore(blockManager.memoryStore) + blockManager.initialize("app-id") + blockManagerBuffer += blockManager + blockManager } /** From 5410747a84e9be1cea44159dfc2216d5e0728ab4 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 8 Oct 2015 22:21:07 -0700 Subject: [PATCH 012/139] [SPARK-10959] [PYSPARK] StreamingLogisticRegressionWithSGD does not train with given regParam and convergenceTol parameters These params were being passed into the StreamingLogisticRegressionWithSGD constructor, but not transferred to the call for model training. Same with StreamingLinearRegressionWithSGD. I added the params as named arguments to the call and also fixed the intercept parameter, which was being passed as regularization value. Author: Bryan Cutler Closes #9002 from BryanCutler/StreamingSGD-convergenceTol-bug-10959. --- python/pyspark/mllib/classification.py | 3 ++- python/pyspark/mllib/regression.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index cb4ee83678081..b77754500bded 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -639,7 +639,8 @@ def update(rdd): if not rdd.isEmpty(): self._model = LogisticRegressionWithSGD.train( rdd, self.numIterations, self.stepSize, - self.miniBatchFraction, self._model.weights) + self.miniBatchFraction, self._model.weights, + regParam=self.regParam, convergenceTol=self.convergenceTol) dstream.foreachRDD(update) diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 256b7537fef6b..961b5e80b013c 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -679,7 +679,7 @@ def update(rdd): self._model = LinearRegressionWithSGD.train( rdd, self.numIterations, self.stepSize, self.miniBatchFraction, self._model.weights, - self._model.intercept) + intercept=self._model.intercept, convergenceTol=self.convergenceTol) dstream.foreachRDD(update) From 5994cfe81271a39294aa29fd47aa94c99aa56743 Mon Sep 17 00:00:00 2001 From: Nick Pritchard Date: Thu, 8 Oct 2015 22:22:20 -0700 Subject: [PATCH 013/139] [SPARK-10875] [MLLIB] Computed covariance matrix should be symmetric Compute upper triangular values of the covariance matrix, then copy to lower triangular values. Author: Nick Pritchard Closes #8940 from pnpritchard/SPARK-10875. --- .../mllib/linalg/distributed/RowMatrix.scala | 6 ++++-- .../linalg/distributed/RowMatrixSuite.scala | 18 ++++++++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 7c7d900af3d5a..b8a7adceb15b6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -357,9 +357,11 @@ class RowMatrix @Since("1.0.0") ( var alpha = 0.0 while (i < n) { alpha = m / m1 * mean(i) - j = 0 + j = i while (j < n) { - G(i, j) = G(i, j) / m1 - alpha * mean(j) + val Gij = G(i, j) / m1 - alpha * mean(j) + G(i, j) = Gij + G(j, i) = Gij j += 1 } i += 1 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala index 283ffec1d49d7..4abb98fb6fe4e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala @@ -24,6 +24,7 @@ import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, norm => brzNorm, s import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Matrices, Vectors, Vector} +import org.apache.spark.mllib.random.RandomRDDs import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} class RowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -255,6 +256,23 @@ class RowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { assert(closeToZero(abs(expected.r) - abs(rOnly.R.toBreeze.asInstanceOf[BDM[Double]]))) } } + + test("compute covariance") { + for (mat <- Seq(denseMat, sparseMat)) { + val result = mat.computeCovariance() + val expected = breeze.linalg.cov(mat.toBreeze()) + assert(closeToZero(abs(expected) - abs(result.toBreeze.asInstanceOf[BDM[Double]]))) + } + } + + test("covariance matrix is symmetric (SPARK-10875)") { + val rdd = RandomRDDs.normalVectorRDD(sc, 100, 10, 0, 0) + val matrix = new RowMatrix(rdd) + val cov = matrix.computeCovariance() + for (i <- 0 until cov.numRows; j <- 0 until i) { + assert(cov(i, j) === cov(j, i)) + } + } } class RowMatrixClusterSuite extends SparkFunSuite with LocalClusterSparkContext { From 70f44ad2d836236c74e1336a7368982d5fe3abff Mon Sep 17 00:00:00 2001 From: Rerngvit Yanggratoke Date: Fri, 9 Oct 2015 09:36:40 -0700 Subject: [PATCH 014/139] [SPARK-10905] [SPARKR] Export freqItems() for DataFrameStatFunctions [SPARK-10905][SparkR]: Export freqItems() for DataFrameStatFunctions - Add function (together with roxygen2 doc) to DataFrame.R and generics.R - Expose the function in NAMESPACE - Add unit test for the function Author: Rerngvit Yanggratoke Closes #8962 from rerngvit/SPARK-10905. --- R/pkg/NAMESPACE | 1 + R/pkg/R/generics.R | 4 ++++ R/pkg/R/stats.R | 27 +++++++++++++++++++++++++++ R/pkg/inst/tests/test_sparkSQL.R | 21 +++++++++++++++++++++ 4 files changed, 53 insertions(+) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 9aad35469bbb7..255be2e76ff49 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -40,6 +40,7 @@ exportMethods("arrange", "fillna", "filter", "first", + "freqItems", "group_by", "groupBy", "head", diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index e9086fdbd18c6..c4474131804bb 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -63,6 +63,10 @@ setGeneric("countByValue", function(x) { standardGeneric("countByValue") }) # @export setGeneric("crosstab", function(x, col1, col2) { standardGeneric("crosstab") }) +# @rdname statfunctions +# @export +setGeneric("freqItems", function(x, cols, support = 0.01) { standardGeneric("freqItems") }) + # @rdname distinct # @export setGeneric("distinct", function(x, numPartitions = 1) { standardGeneric("distinct") }) diff --git a/R/pkg/R/stats.R b/R/pkg/R/stats.R index 06382d55d086e..4928cf4d4367d 100644 --- a/R/pkg/R/stats.R +++ b/R/pkg/R/stats.R @@ -100,3 +100,30 @@ setMethod("corr", statFunctions <- callJMethod(x@sdf, "stat") callJMethod(statFunctions, "corr", col1, col2, method) }) + +#' freqItems +#' +#' Finding frequent items for columns, possibly with false positives. +#' Using the frequent element count algorithm described in +#' \url{http://dx.doi.org/10.1145/762471.762473}, proposed by Karp, Schenker, and Papadimitriou. +#' +#' @param x A SparkSQL DataFrame. +#' @param cols A vector column names to search frequent items in. +#' @param support (Optional) The minimum frequency for an item to be considered `frequent`. +#' Should be greater than 1e-4. Default support = 0.01. +#' @return a local R data.frame with the frequent items in each column +#' +#' @rdname statfunctions +#' @name freqItems +#' @export +#' @examples +#' \dontrun{ +#' df <- jsonFile(sqlContext, "/path/to/file.json") +#' fi = freqItems(df, c("title", "gender")) +#' } +setMethod("freqItems", signature(x = "DataFrame", cols = "character"), + function(x, cols, support = 0.01) { + statFunctions <- callJMethod(x@sdf, "stat") + sct <- callJMethod(statFunctions, "freqItems", as.list(cols), support) + collect(dataFrame(sct)) + }) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index e85de2507085c..4804ecf177341 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -1350,6 +1350,27 @@ test_that("cov() and corr() on a DataFrame", { expect_true(abs(result - 1.0) < 1e-12) }) +test_that("freqItems() on a DataFrame", { + input <- 1:1000 + rdf <- data.frame(numbers = input, letters = as.character(input), + negDoubles = input * -1.0, stringsAsFactors = F) + rdf[ input %% 3 == 0, ] <- c(1, "1", -1) + df <- createDataFrame(sqlContext, rdf) + multiColResults <- freqItems(df, c("numbers", "letters"), support=0.1) + expect_true(1 %in% multiColResults$numbers[[1]]) + expect_true("1" %in% multiColResults$letters[[1]]) + singleColResult <- freqItems(df, "negDoubles", support=0.1) + expect_true(-1 %in% head(singleColResult$negDoubles)[[1]]) + + l <- lapply(c(0:99), function(i) { + if (i %% 2 == 0) { list(1L, -1.0) } + else { list(i, i * -1.0) }}) + df <- createDataFrame(sqlContext, l, c("a", "b")) + result <- freqItems(df, c("a", "b"), 0.4) + expect_identical(result[[1]], list(list(1L, 99L))) + expect_identical(result[[2]], list(list(-1, -99))) +}) + test_that("SQL error message is returned from JVM", { retError <- tryCatch(sql(sqlContext, "select * from blah"), error = function(e) e) expect_equal(grepl("Table Not Found: blah", retError), TRUE) From 015f7ef503d5544f79512b6333326749a1f0c48b Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 9 Oct 2015 15:28:09 -0500 Subject: [PATCH 015/139] [SPARK-8673] [LAUNCHER] API and infrastructure for communicating with child apps. This change adds an API that encapsulates information about an app launched using the library. It also creates a socket-based communication layer for apps that are launched as child processes; the launching application listens for connections from launched apps, and once communication is established, the channel can be used to send updates to the launching app, or to send commands to the child app. The change also includes hooks for local, standalone/client and yarn masters. Author: Marcelo Vanzin Closes #7052 from vanzin/SPARK-8673. --- .../spark/launcher/LauncherBackend.scala | 119 ++++++ .../cluster/SparkDeploySchedulerBackend.scala | 35 +- .../spark/scheduler/local/LocalBackend.scala | 19 +- .../spark/launcher/SparkLauncherSuite.java | 39 +- core/src/test/resources/log4j.properties | 11 +- .../spark/launcher/LauncherBackendSuite.scala | 81 +++++ launcher/pom.xml | 5 + .../launcher/AbstractCommandBuilder.java | 38 +- .../spark/launcher/ChildProcAppHandle.java | 159 ++++++++ .../spark/launcher/LauncherConnection.java | 110 ++++++ .../spark/launcher/LauncherProtocol.java | 93 +++++ .../apache/spark/launcher/LauncherServer.java | 341 ++++++++++++++++++ .../spark/launcher/NamedThreadFactory.java | 40 ++ .../spark/launcher/OutputRedirector.java | 78 ++++ .../apache/spark/launcher/SparkAppHandle.java | 126 +++++++ .../apache/spark/launcher/SparkLauncher.java | 106 +++++- .../launcher/SparkSubmitCommandBuilder.java | 22 +- .../apache/spark/launcher/package-info.java | 38 +- .../org/apache/spark/launcher/BaseSuite.java | 32 ++ .../spark/launcher/LauncherServerSuite.java | 188 ++++++++++ .../SparkSubmitCommandBuilderSuite.java | 4 +- .../SparkSubmitOptionParserSuite.java | 2 +- launcher/src/test/resources/log4j.properties | 13 +- .../org/apache/spark/deploy/yarn/Client.scala | 43 ++- .../cluster/YarnClientSchedulerBackend.scala | 10 + yarn/src/test/resources/log4j.properties | 7 +- .../deploy/yarn/BaseYarnClusterSuite.scala | 127 ++++--- .../spark/deploy/yarn/YarnClusterSuite.scala | 76 +++- .../yarn/YarnShuffleIntegrationSuite.scala | 4 +- 29 files changed, 1820 insertions(+), 146 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala create mode 100644 core/src/test/scala/org/apache/spark/launcher/LauncherBackendSuite.scala create mode 100644 launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java create mode 100644 launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java create mode 100644 launcher/src/main/java/org/apache/spark/launcher/LauncherProtocol.java create mode 100644 launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java create mode 100644 launcher/src/main/java/org/apache/spark/launcher/NamedThreadFactory.java create mode 100644 launcher/src/main/java/org/apache/spark/launcher/OutputRedirector.java create mode 100644 launcher/src/main/java/org/apache/spark/launcher/SparkAppHandle.java create mode 100644 launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java create mode 100644 launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java diff --git a/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala b/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala new file mode 100644 index 0000000000000..3ea984c501e02 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher + +import java.net.{InetAddress, Socket} + +import org.apache.spark.SPARK_VERSION +import org.apache.spark.launcher.LauncherProtocol._ +import org.apache.spark.util.ThreadUtils + +/** + * A class that can be used to talk to a launcher server. Users should extend this class to + * provide implementation for the abstract methods. + * + * See `LauncherServer` for an explanation of how launcher communication works. + */ +private[spark] abstract class LauncherBackend { + + private var clientThread: Thread = _ + private var connection: BackendConnection = _ + private var lastState: SparkAppHandle.State = _ + @volatile private var _isConnected = false + + def connect(): Unit = { + val port = sys.env.get(LauncherProtocol.ENV_LAUNCHER_PORT).map(_.toInt) + val secret = sys.env.get(LauncherProtocol.ENV_LAUNCHER_SECRET) + if (port != None && secret != None) { + val s = new Socket(InetAddress.getLoopbackAddress(), port.get) + connection = new BackendConnection(s) + connection.send(new Hello(secret.get, SPARK_VERSION)) + clientThread = LauncherBackend.threadFactory.newThread(connection) + clientThread.start() + _isConnected = true + } + } + + def close(): Unit = { + if (connection != null) { + try { + connection.close() + } finally { + if (clientThread != null) { + clientThread.join() + } + } + } + } + + def setAppId(appId: String): Unit = { + if (connection != null) { + connection.send(new SetAppId(appId)) + } + } + + def setState(state: SparkAppHandle.State): Unit = { + if (connection != null && lastState != state) { + connection.send(new SetState(state)) + lastState = state + } + } + + /** Return whether the launcher handle is still connected to this backend. */ + def isConnected(): Boolean = _isConnected + + /** + * Implementations should provide this method, which should try to stop the application + * as gracefully as possible. + */ + protected def onStopRequest(): Unit + + /** + * Callback for when the launcher handle disconnects from this backend. + */ + protected def onDisconnected() : Unit = { } + + + private class BackendConnection(s: Socket) extends LauncherConnection(s) { + + override protected def handle(m: Message): Unit = m match { + case _: Stop => + onStopRequest() + + case _ => + throw new IllegalArgumentException(s"Unexpected message type: ${m.getClass().getName()}") + } + + override def close(): Unit = { + try { + super.close() + } finally { + onDisconnected() + _isConnected = false + } + } + + } + +} + +private object LauncherBackend { + + val threadFactory = ThreadUtils.namedThreadFactory("LauncherBackend") + +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 27491ecf8b97d..2625c3e7ac718 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -23,6 +23,7 @@ import org.apache.spark.rpc.RpcAddress import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv} import org.apache.spark.deploy.{ApplicationDescription, Command} import org.apache.spark.deploy.client.{AppClient, AppClientListener} +import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle} import org.apache.spark.scheduler._ import org.apache.spark.util.Utils @@ -36,6 +37,9 @@ private[spark] class SparkDeploySchedulerBackend( private var client: AppClient = null private var stopping = false + private val launcherBackend = new LauncherBackend() { + override protected def onStopRequest(): Unit = stop(SparkAppHandle.State.KILLED) + } @volatile var shutdownCallback: SparkDeploySchedulerBackend => Unit = _ @volatile private var appId: String = _ @@ -47,6 +51,7 @@ private[spark] class SparkDeploySchedulerBackend( override def start() { super.start() + launcherBackend.connect() // The endpoint for executors to talk to us val driverUrl = rpcEnv.uriOf(SparkEnv.driverActorSystemName, @@ -87,24 +92,20 @@ private[spark] class SparkDeploySchedulerBackend( command, appUIAddress, sc.eventLogDir, sc.eventLogCodec, coresPerExecutor) client = new AppClient(sc.env.rpcEnv, masters, appDesc, this, conf) client.start() + launcherBackend.setState(SparkAppHandle.State.SUBMITTED) waitForRegistration() + launcherBackend.setState(SparkAppHandle.State.RUNNING) } - override def stop() { - stopping = true - super.stop() - client.stop() - - val callback = shutdownCallback - if (callback != null) { - callback(this) - } + override def stop(): Unit = synchronized { + stop(SparkAppHandle.State.FINISHED) } override def connected(appId: String) { logInfo("Connected to Spark cluster with app ID " + appId) this.appId = appId notifyContext() + launcherBackend.setAppId(appId) } override def disconnected() { @@ -117,6 +118,7 @@ private[spark] class SparkDeploySchedulerBackend( override def dead(reason: String) { notifyContext() if (!stopping) { + launcherBackend.setState(SparkAppHandle.State.KILLED) logError("Application has been killed. Reason: " + reason) try { scheduler.error(reason) @@ -188,4 +190,19 @@ private[spark] class SparkDeploySchedulerBackend( registrationBarrier.release() } + private def stop(finalState: SparkAppHandle.State): Unit = synchronized { + stopping = true + + launcherBackend.setState(finalState) + launcherBackend.close() + + super.stop() + client.stop() + + val callback = shutdownCallback + if (callback != null) { + callback(this) + } + } + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index 4d48fcfea44e7..c633d860ae6e5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -24,6 +24,7 @@ import java.nio.ByteBuffer import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState import org.apache.spark.executor.{Executor, ExecutorBackend} +import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle} import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo @@ -103,6 +104,9 @@ private[spark] class LocalBackend( private var localEndpoint: RpcEndpointRef = null private val userClassPath = getUserClasspath(conf) private val listenerBus = scheduler.sc.listenerBus + private val launcherBackend = new LauncherBackend() { + override def onStopRequest(): Unit = stop(SparkAppHandle.State.KILLED) + } /** * Returns a list of URLs representing the user classpath. @@ -114,6 +118,8 @@ private[spark] class LocalBackend( userClassPathStr.map(_.split(File.pathSeparator)).toSeq.flatten.map(new File(_).toURI.toURL) } + launcherBackend.connect() + override def start() { val rpcEnv = SparkEnv.get.rpcEnv val executorEndpoint = new LocalEndpoint(rpcEnv, userClassPath, scheduler, this, totalCores) @@ -122,10 +128,12 @@ private[spark] class LocalBackend( System.currentTimeMillis, executorEndpoint.localExecutorId, new ExecutorInfo(executorEndpoint.localExecutorHostname, totalCores, Map.empty))) + launcherBackend.setAppId(appId) + launcherBackend.setState(SparkAppHandle.State.RUNNING) } override def stop() { - localEndpoint.ask(StopExecutor) + stop(SparkAppHandle.State.FINISHED) } override def reviveOffers() { @@ -145,4 +153,13 @@ private[spark] class LocalBackend( override def applicationId(): String = appId + private def stop(finalState: SparkAppHandle.State): Unit = { + localEndpoint.ask(StopExecutor) + try { + launcherBackend.setState(finalState) + } finally { + launcherBackend.close() + } + } + } diff --git a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index d0c26dd05679b..aa15e792e2b27 100644 --- a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -27,6 +27,7 @@ import org.junit.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.slf4j.bridge.SLF4JBridgeHandler; import static org.junit.Assert.*; /** @@ -34,7 +35,13 @@ */ public class SparkLauncherSuite { + static { + SLF4JBridgeHandler.removeHandlersForRootLogger(); + SLF4JBridgeHandler.install(); + } + private static final Logger LOG = LoggerFactory.getLogger(SparkLauncherSuite.class); + private static final NamedThreadFactory TF = new NamedThreadFactory("SparkLauncherSuite-%d"); @Test public void testSparkArgumentHandling() throws Exception { @@ -94,14 +101,15 @@ public void testChildProcLauncher() throws Exception { .addSparkArg(opts.CONF, String.format("%s=-Dfoo=ShouldBeOverriddenBelow", SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS)) .setConf(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS, - "-Dfoo=bar -Dtest.name=-testChildProcLauncher") + "-Dfoo=bar -Dtest.appender=childproc") .setConf(SparkLauncher.DRIVER_EXTRA_CLASSPATH, System.getProperty("java.class.path")) .addSparkArg(opts.CLASS, "ShouldBeOverriddenBelow") .setMainClass(SparkLauncherTestApp.class.getName()) .addAppArgs("proc"); final Process app = launcher.launch(); - new Redirector("stdout", app.getInputStream()).start(); - new Redirector("stderr", app.getErrorStream()).start(); + + new OutputRedirector(app.getInputStream(), TF); + new OutputRedirector(app.getErrorStream(), TF); assertEquals(0, app.waitFor()); } @@ -116,29 +124,4 @@ public static void main(String[] args) throws Exception { } - private static class Redirector extends Thread { - - private final InputStream in; - - Redirector(String name, InputStream in) { - this.in = in; - setName(name); - setDaemon(true); - } - - @Override - public void run() { - try { - BufferedReader reader = new BufferedReader(new InputStreamReader(in, "UTF-8")); - String line; - while ((line = reader.readLine()) != null) { - LOG.warn(line); - } - } catch (Exception e) { - LOG.error("Error reading process output.", e); - } - } - - } - } diff --git a/core/src/test/resources/log4j.properties b/core/src/test/resources/log4j.properties index eb3b1999eb996..a54d27de91ed2 100644 --- a/core/src/test/resources/log4j.properties +++ b/core/src/test/resources/log4j.properties @@ -16,13 +16,22 @@ # # Set everything to be logged to the file target/unit-tests.log -log4j.rootCategory=INFO, file +test.appender=file +log4j.rootCategory=INFO, ${test.appender} log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=true log4j.appender.file.file=target/unit-tests.log log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n +# Tests that launch java subprocesses can set the "test.appender" system property to +# "console" to avoid having the child process's logs overwrite the unit test's +# log file. +log4j.appender.console=org.apache.log4j.ConsoleAppender +log4j.appender.console.target=System.err +log4j.appender.console.layout=org.apache.log4j.PatternLayout +log4j.appender.console.layout.ConversionPattern=%t: %m%n + # Ignore messages below warning level from Jetty, because it's a bit verbose log4j.logger.org.spark-project.jetty=WARN org.spark-project.jetty.LEVEL=WARN diff --git a/core/src/test/scala/org/apache/spark/launcher/LauncherBackendSuite.scala b/core/src/test/scala/org/apache/spark/launcher/LauncherBackendSuite.scala new file mode 100644 index 0000000000000..07e8869833e95 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/launcher/LauncherBackendSuite.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher + +import java.util.concurrent.TimeUnit + +import scala.concurrent.duration._ +import scala.language.postfixOps + +import org.scalatest.Matchers +import org.scalatest.concurrent.Eventually._ + +import org.apache.spark._ +import org.apache.spark.launcher._ + +class LauncherBackendSuite extends SparkFunSuite with Matchers { + + private val tests = Seq( + "local" -> "local", + "standalone/client" -> "local-cluster[1,1,1024]") + + tests.foreach { case (name, master) => + test(s"$name: launcher handle") { + testWithMaster(master) + } + } + + private def testWithMaster(master: String): Unit = { + val env = new java.util.HashMap[String, String]() + env.put("SPARK_PRINT_LAUNCH_COMMAND", "1") + val handle = new SparkLauncher(env) + .setSparkHome(sys.props("spark.test.home")) + .setConf(SparkLauncher.DRIVER_EXTRA_CLASSPATH, System.getProperty("java.class.path")) + .setConf("spark.ui.enabled", "false") + .setConf(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS, s"-Dtest.appender=console") + .setMaster(master) + .setAppResource("spark-internal") + .setMainClass(TestApp.getClass.getName().stripSuffix("$")) + .startApplication() + + try { + eventually(timeout(10 seconds), interval(100 millis)) { + handle.getAppId() should not be (null) + } + + handle.stop() + + eventually(timeout(10 seconds), interval(100 millis)) { + handle.getState() should be (SparkAppHandle.State.KILLED) + } + } finally { + handle.kill() + } + } + +} + +object TestApp { + + def main(args: Array[String]): Unit = { + new SparkContext(new SparkConf()).parallelize(Seq(1)).foreach { i => + Thread.sleep(TimeUnit.SECONDS.toMillis(20)) + } + } + +} diff --git a/launcher/pom.xml b/launcher/pom.xml index d595d74642ab2..5739bfc16958f 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -47,6 +47,11 @@ mockito-core test + + org.slf4j + jul-to-slf4j + test + org.slf4j slf4j-api diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java index 610e8bdaaa639..cf3729b7febc3 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -47,7 +47,7 @@ abstract class AbstractCommandBuilder { String javaHome; String mainClass; String master; - String propertiesFile; + protected String propertiesFile; final List appArgs; final List jars; final List files; @@ -55,6 +55,10 @@ abstract class AbstractCommandBuilder { final Map childEnv; final Map conf; + // The merged configuration for the application. Cached to avoid having to read / parse + // properties files multiple times. + private Map effectiveConfig; + public AbstractCommandBuilder() { this.appArgs = new ArrayList(); this.childEnv = new HashMap(); @@ -257,12 +261,38 @@ String getSparkHome() { return path; } + String getenv(String key) { + return firstNonEmpty(childEnv.get(key), System.getenv(key)); + } + + void setPropertiesFile(String path) { + effectiveConfig = null; + this.propertiesFile = path; + } + + Map getEffectiveConfig() throws IOException { + if (effectiveConfig == null) { + if (propertiesFile == null) { + effectiveConfig = conf; + } else { + effectiveConfig = new HashMap<>(conf); + Properties p = loadPropertiesFile(); + for (String key : p.stringPropertyNames()) { + if (!effectiveConfig.containsKey(key)) { + effectiveConfig.put(key, p.getProperty(key)); + } + } + } + } + return effectiveConfig; + } + /** * Loads the configuration file for the application, if it exists. This is either the * user-specified properties file, or the spark-defaults.conf file under the Spark configuration * directory. */ - Properties loadPropertiesFile() throws IOException { + private Properties loadPropertiesFile() throws IOException { Properties props = new Properties(); File propsFile; if (propertiesFile != null) { @@ -294,10 +324,6 @@ Properties loadPropertiesFile() throws IOException { return props; } - String getenv(String key) { - return firstNonEmpty(childEnv.get(key), System.getenv(key)); - } - private String findAssembly() { String sparkHome = getSparkHome(); File libdir; diff --git a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java new file mode 100644 index 0000000000000..de50f14fbdc87 --- /dev/null +++ b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ThreadFactory; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * Handle implementation for monitoring apps started as a child process. + */ +class ChildProcAppHandle implements SparkAppHandle { + + private static final Logger LOG = Logger.getLogger(ChildProcAppHandle.class.getName()); + private static final ThreadFactory REDIRECTOR_FACTORY = + new NamedThreadFactory("launcher-proc-%d"); + + private final String secret; + private final LauncherServer server; + + private Process childProc; + private boolean disposed; + private LauncherConnection connection; + private List listeners; + private State state; + private String appId; + private OutputRedirector redirector; + + ChildProcAppHandle(String secret, LauncherServer server) { + this.secret = secret; + this.server = server; + this.state = State.UNKNOWN; + } + + @Override + public synchronized void addListener(Listener l) { + if (listeners == null) { + listeners = new ArrayList<>(); + } + listeners.add(l); + } + + @Override + public State getState() { + return state; + } + + @Override + public String getAppId() { + return appId; + } + + @Override + public void stop() { + CommandBuilderUtils.checkState(connection != null, "Application is still not connected."); + try { + connection.send(new LauncherProtocol.Stop()); + } catch (IOException ioe) { + throw new RuntimeException(ioe); + } + } + + @Override + public synchronized void disconnect() { + if (!disposed) { + disposed = true; + if (connection != null) { + try { + connection.close(); + } catch (IOException ioe) { + // no-op. + } + } + server.unregister(this); + if (redirector != null) { + redirector.stop(); + } + } + } + + @Override + public synchronized void kill() { + if (!disposed) { + disconnect(); + } + if (childProc != null) { + childProc.destroy(); + childProc = null; + } + } + + String getSecret() { + return secret; + } + + void setChildProc(Process childProc, String loggerName) { + this.childProc = childProc; + this.redirector = new OutputRedirector(childProc.getInputStream(), loggerName, + REDIRECTOR_FACTORY); + } + + void setConnection(LauncherConnection connection) { + this.connection = connection; + } + + LauncherServer getServer() { + return server; + } + + LauncherConnection getConnection() { + return connection; + } + + void setState(State s) { + if (!state.isFinal()) { + state = s; + fireEvent(false); + } else { + LOG.log(Level.WARNING, "Backend requested transition from final state {0} to {1}.", + new Object[] { state, s }); + } + } + + void setAppId(String appId) { + this.appId = appId; + fireEvent(true); + } + + private synchronized void fireEvent(boolean isInfoChanged) { + if (listeners != null) { + for (Listener l : listeners) { + if (isInfoChanged) { + l.infoChanged(this); + } else { + l.stateChanged(this); + } + } + } + } + +} diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java new file mode 100644 index 0000000000000..eec264909bbb6 --- /dev/null +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +import java.io.Closeable; +import java.io.EOFException; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.net.Socket; +import java.util.logging.Level; +import java.util.logging.Logger; + +import static org.apache.spark.launcher.LauncherProtocol.*; + +/** + * Encapsulates a connection between a launcher server and client. This takes care of the + * communication (sending and receiving messages), while processing of messages is left for + * the implementations. + */ +abstract class LauncherConnection implements Closeable, Runnable { + + private static final Logger LOG = Logger.getLogger(LauncherConnection.class.getName()); + + private final Socket socket; + private final ObjectOutputStream out; + + private volatile boolean closed; + + LauncherConnection(Socket socket) throws IOException { + this.socket = socket; + this.out = new ObjectOutputStream(socket.getOutputStream()); + this.closed = false; + } + + protected abstract void handle(Message msg) throws IOException; + + @Override + public void run() { + try { + ObjectInputStream in = new ObjectInputStream(socket.getInputStream()); + while (!closed) { + Message msg = (Message) in.readObject(); + handle(msg); + } + } catch (EOFException eof) { + // Remote side has closed the connection, just cleanup. + try { + close(); + } catch (Exception unused) { + // no-op. + } + } catch (Exception e) { + if (!closed) { + LOG.log(Level.WARNING, "Error in inbound message handling.", e); + try { + close(); + } catch (Exception unused) { + // no-op. + } + } + } + } + + protected synchronized void send(Message msg) throws IOException { + try { + CommandBuilderUtils.checkState(!closed, "Disconnected."); + out.writeObject(msg); + out.flush(); + } catch (IOException ioe) { + if (!closed) { + LOG.log(Level.WARNING, "Error when sending message.", ioe); + try { + close(); + } catch (Exception unused) { + // no-op. + } + } + throw ioe; + } + } + + @Override + public void close() throws IOException { + if (!closed) { + synchronized (this) { + if (!closed) { + closed = true; + socket.close(); + } + } + } + } + +} diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherProtocol.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherProtocol.java new file mode 100644 index 0000000000000..50f136497ec1a --- /dev/null +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherProtocol.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +import java.io.Closeable; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.Serializable; +import java.net.Socket; +import java.util.Map; + +/** + * Message definitions for the launcher communication protocol. These messages must remain + * backwards-compatible, so that the launcher can talk to older versions of Spark that support + * the protocol. + */ +final class LauncherProtocol { + + /** Environment variable where the server port is stored. */ + static final String ENV_LAUNCHER_PORT = "_SPARK_LAUNCHER_PORT"; + + /** Environment variable where the secret for connecting back to the server is stored. */ + static final String ENV_LAUNCHER_SECRET = "_SPARK_LAUNCHER_SECRET"; + + static class Message implements Serializable { + + } + + /** + * Hello message, sent from client to server. + */ + static class Hello extends Message { + + final String secret; + final String sparkVersion; + + Hello(String secret, String version) { + this.secret = secret; + this.sparkVersion = version; + } + + } + + /** + * SetAppId message, sent from client to server. + */ + static class SetAppId extends Message { + + final String appId; + + SetAppId(String appId) { + this.appId = appId; + } + + } + + /** + * SetState message, sent from client to server. + */ + static class SetState extends Message { + + final SparkAppHandle.State state; + + SetState(SparkAppHandle.State state) { + this.state = state; + } + + } + + /** + * Stop message, send from server to client to stop the application. + */ + static class Stop extends Message { + + } + +} diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java new file mode 100644 index 0000000000000..c5fd40816d62f --- /dev/null +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java @@ -0,0 +1,341 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +import java.io.Closeable; +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.ServerSocket; +import java.net.Socket; +import java.security.SecureRandom; +import java.util.ArrayList; +import java.util.List; +import java.util.Timer; +import java.util.TimerTask; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.atomic.AtomicLong; +import java.util.logging.Level; +import java.util.logging.Logger; + +import static org.apache.spark.launcher.LauncherProtocol.*; + +/** + * A server that listens locally for connections from client launched by the library. Each client + * has a secret that it needs to send to the server to identify itself and establish the session. + * + * I/O is currently blocking (one thread per client). Clients have a limited time to connect back + * to the server, otherwise the server will ignore the connection. + * + * === Architecture Overview === + * + * The launcher server is used when Spark apps are launched as separate processes than the calling + * app. It looks more or less like the following: + * + * ----------------------- ----------------------- + * | User App | spark-submit | Spark App | + * | | -------------------> | | + * | ------------| |------------- | + * | | | hello | | | + * | | L. Server |<----------------------| L. Backend | | + * | | | | | | + * | ------------- ----------------------- + * | | | ^ + * | v | | + * | -------------| | + * | | | | + * | | App Handle |<------------------------------ + * | | | + * ----------------------- + * + * The server is started on demand and remains active while there are active or outstanding clients, + * to avoid opening too many ports when multiple clients are launched. Each client is given a unique + * secret, and have a limited amount of time to connect back + * ({@link SparkLauncher#CHILD_CONNECTION_TIMEOUT}), at which point the server will throw away + * that client's state. A client is only allowed to connect back to the server once. + * + * The launcher server listens on the localhost only, so it doesn't need access controls (aside from + * the per-app secret) nor encryption. It thus requires that the launched app has a local process + * that communicates with the server. In cluster mode, this means that the client that launches the + * application must remain alive for the duration of the application (or until the app handle is + * disconnected). + */ +class LauncherServer implements Closeable { + + private static final Logger LOG = Logger.getLogger(LauncherServer.class.getName()); + private static final String THREAD_NAME_FMT = "LauncherServer-%d"; + private static final long DEFAULT_CONNECT_TIMEOUT = 10000L; + + /** For creating secrets used for communication with child processes. */ + private static final SecureRandom RND = new SecureRandom(); + + private static volatile LauncherServer serverInstance; + + /** + * Creates a handle for an app to be launched. This method will start a server if one hasn't been + * started yet. The server is shared for multiple handles, and once all handles are disposed of, + * the server is shut down. + */ + static synchronized ChildProcAppHandle newAppHandle() throws IOException { + LauncherServer server = serverInstance != null ? serverInstance : new LauncherServer(); + server.ref(); + serverInstance = server; + + String secret = server.createSecret(); + while (server.pending.containsKey(secret)) { + secret = server.createSecret(); + } + + return server.newAppHandle(secret); + } + + static LauncherServer getServerInstance() { + return serverInstance; + } + + private final AtomicLong refCount; + private final AtomicLong threadIds; + private final ConcurrentMap pending; + private final List clients; + private final ServerSocket server; + private final Thread serverThread; + private final ThreadFactory factory; + private final Timer timeoutTimer; + + private volatile boolean running; + + private LauncherServer() throws IOException { + this.refCount = new AtomicLong(0); + + ServerSocket server = new ServerSocket(); + try { + server.setReuseAddress(true); + server.bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0)); + + this.clients = new ArrayList(); + this.threadIds = new AtomicLong(); + this.factory = new NamedThreadFactory(THREAD_NAME_FMT); + this.pending = new ConcurrentHashMap<>(); + this.timeoutTimer = new Timer("LauncherServer-TimeoutTimer", true); + this.server = server; + this.running = true; + + this.serverThread = factory.newThread(new Runnable() { + @Override + public void run() { + acceptConnections(); + } + }); + serverThread.start(); + } catch (IOException ioe) { + close(); + throw ioe; + } catch (Exception e) { + close(); + throw new IOException(e); + } + } + + /** + * Creates a new app handle. The handle will wait for an incoming connection for a configurable + * amount of time, and if one doesn't arrive, it will transition to an error state. + */ + ChildProcAppHandle newAppHandle(String secret) { + ChildProcAppHandle handle = new ChildProcAppHandle(secret, this); + ChildProcAppHandle existing = pending.putIfAbsent(secret, handle); + CommandBuilderUtils.checkState(existing == null, "Multiple handles with the same secret."); + return handle; + } + + @Override + public void close() throws IOException { + synchronized (this) { + if (running) { + running = false; + timeoutTimer.cancel(); + server.close(); + synchronized (clients) { + List copy = new ArrayList<>(clients); + clients.clear(); + for (ServerConnection client : copy) { + client.close(); + } + } + } + } + if (serverThread != null) { + try { + serverThread.join(); + } catch (InterruptedException ie) { + // no-op + } + } + } + + void ref() { + refCount.incrementAndGet(); + } + + void unref() { + synchronized(LauncherServer.class) { + if (refCount.decrementAndGet() == 0) { + try { + close(); + } catch (IOException ioe) { + // no-op. + } finally { + serverInstance = null; + } + } + } + } + + int getPort() { + return server.getLocalPort(); + } + + /** + * Removes the client handle from the pending list (in case it's still there), and unrefs + * the server. + */ + void unregister(ChildProcAppHandle handle) { + pending.remove(handle.getSecret()); + unref(); + } + + private void acceptConnections() { + try { + while (running) { + final Socket client = server.accept(); + TimerTask timeout = new TimerTask() { + @Override + public void run() { + LOG.warning("Timed out waiting for hello message from client."); + try { + client.close(); + } catch (IOException ioe) { + // no-op. + } + } + }; + ServerConnection clientConnection = new ServerConnection(client, timeout); + Thread clientThread = factory.newThread(clientConnection); + synchronized (timeout) { + clientThread.start(); + synchronized (clients) { + clients.add(clientConnection); + } + timeoutTimer.schedule(timeout, getConnectionTimeout()); + } + } + } catch (IOException ioe) { + if (running) { + LOG.log(Level.SEVERE, "Error in accept loop.", ioe); + } + } + } + + private long getConnectionTimeout() { + String value = SparkLauncher.launcherConfig.get(SparkLauncher.CHILD_CONNECTION_TIMEOUT); + return (value != null) ? Long.parseLong(value) : DEFAULT_CONNECT_TIMEOUT; + } + + private String createSecret() { + byte[] secret = new byte[128]; + RND.nextBytes(secret); + + StringBuilder sb = new StringBuilder(); + for (byte b : secret) { + int ival = b >= 0 ? b : Byte.MAX_VALUE - b; + if (ival < 0x10) { + sb.append("0"); + } + sb.append(Integer.toHexString(ival)); + } + return sb.toString(); + } + + private class ServerConnection extends LauncherConnection { + + private TimerTask timeout; + private ChildProcAppHandle handle; + + ServerConnection(Socket socket, TimerTask timeout) throws IOException { + super(socket); + this.timeout = timeout; + } + + @Override + protected void handle(Message msg) throws IOException { + try { + if (msg instanceof Hello) { + synchronized (timeout) { + timeout.cancel(); + } + timeout = null; + Hello hello = (Hello) msg; + ChildProcAppHandle handle = pending.remove(hello.secret); + if (handle != null) { + handle.setState(SparkAppHandle.State.CONNECTED); + handle.setConnection(this); + this.handle = handle; + } else { + throw new IllegalArgumentException("Received Hello for unknown client."); + } + } else { + if (handle == null) { + throw new IllegalArgumentException("Expected hello, got: " + + msg != null ? msg.getClass().getName() : null); + } + if (msg instanceof SetAppId) { + SetAppId set = (SetAppId) msg; + handle.setAppId(set.appId); + } else if (msg instanceof SetState) { + handle.setState(((SetState)msg).state); + } else { + throw new IllegalArgumentException("Invalid message: " + + msg != null ? msg.getClass().getName() : null); + } + } + } catch (Exception e) { + LOG.log(Level.INFO, "Error handling message from client.", e); + if (timeout != null) { + timeout.cancel(); + } + close(); + } finally { + timeoutTimer.purge(); + } + } + + @Override + public void close() throws IOException { + synchronized (clients) { + clients.remove(this); + } + super.close(); + if (handle != null) { + handle.disconnect(); + } + } + + } + +} diff --git a/launcher/src/main/java/org/apache/spark/launcher/NamedThreadFactory.java b/launcher/src/main/java/org/apache/spark/launcher/NamedThreadFactory.java new file mode 100644 index 0000000000000..995f4d73daaaf --- /dev/null +++ b/launcher/src/main/java/org/apache/spark/launcher/NamedThreadFactory.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.atomic.AtomicLong; + +class NamedThreadFactory implements ThreadFactory { + + private final String nameFormat; + private final AtomicLong threadIds; + + NamedThreadFactory(String nameFormat) { + this.nameFormat = nameFormat; + this.threadIds = new AtomicLong(); + } + + @Override + public Thread newThread(Runnable r) { + Thread t = new Thread(r, String.format(nameFormat, threadIds.incrementAndGet())); + t.setDaemon(true); + return t; + } + +} diff --git a/launcher/src/main/java/org/apache/spark/launcher/OutputRedirector.java b/launcher/src/main/java/org/apache/spark/launcher/OutputRedirector.java new file mode 100644 index 0000000000000..6e7120167d605 --- /dev/null +++ b/launcher/src/main/java/org/apache/spark/launcher/OutputRedirector.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +import java.io.BufferedReader; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.IOException; +import java.util.concurrent.ThreadFactory; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * Redirects lines read from a given input stream to a j.u.l.Logger (at INFO level). + */ +class OutputRedirector { + + private final BufferedReader reader; + private final Logger sink; + private final Thread thread; + + private volatile boolean active; + + OutputRedirector(InputStream in, ThreadFactory tf) { + this(in, OutputRedirector.class.getName(), tf); + } + + OutputRedirector(InputStream in, String loggerName, ThreadFactory tf) { + this.active = true; + this.reader = new BufferedReader(new InputStreamReader(in)); + this.thread = tf.newThread(new Runnable() { + @Override + public void run() { + redirect(); + } + }); + this.sink = Logger.getLogger(loggerName); + thread.start(); + } + + private void redirect() { + try { + String line; + while ((line = reader.readLine()) != null) { + if (active) { + sink.info(line.replaceFirst("\\s*$", "")); + } + } + } catch (IOException e) { + sink.log(Level.FINE, "Error reading child process output.", e); + } + } + + /** + * This method just stops the output of the process from showing up in the local logs. + * The child's output will still be read (and, thus, the redirect thread will still be + * alive) to avoid the child process hanging because of lack of output buffer. + */ + void stop() { + active = false; + } + +} diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/SparkAppHandle.java new file mode 100644 index 0000000000000..2896a91d5e793 --- /dev/null +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkAppHandle.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +/** + * A handle to a running Spark application. + *

+ * Provides runtime information about the underlying Spark application, and actions to control it. + * + * @since 1.6.0 + */ +public interface SparkAppHandle { + + /** + * Represents the application's state. A state can be "final", in which case it will not change + * after it's reached, and means the application is not running anymore. + * + * @since 1.6.0 + */ + public enum State { + /** The application has not reported back yet. */ + UNKNOWN(false), + /** The application has connected to the handle. */ + CONNECTED(false), + /** The application has been submitted to the cluster. */ + SUBMITTED(false), + /** The application is running. */ + RUNNING(false), + /** The application finished with a successful status. */ + FINISHED(true), + /** The application finished with a failed status. */ + FAILED(true), + /** The application was killed. */ + KILLED(true); + + private final boolean isFinal; + + State(boolean isFinal) { + this.isFinal = isFinal; + } + + /** + * Whether this state is a final state, meaning the application is not running anymore + * once it's reached. + */ + public boolean isFinal() { + return isFinal; + } + } + + /** + * Adds a listener to be notified of changes to the handle's information. Listeners will be called + * from the thread processing updates from the application, so they should avoid blocking or + * long-running operations. + * + * @param l Listener to add. + */ + void addListener(Listener l); + + /** Returns the current application state. */ + State getState(); + + /** Returns the application ID, or null if not yet known. */ + String getAppId(); + + /** + * Asks the application to stop. This is best-effort, since the application may fail to receive + * or act on the command. Callers should watch for a state transition that indicates the + * application has really stopped. + */ + void stop(); + + /** + * Tries to kill the underlying application. Implies {@link #disconnect()}. This will not send + * a {@link #stop()} message to the application, so it's recommended that users first try to + * stop the application cleanly and only resort to this method if that fails. + */ + void kill(); + + /** + * Disconnects the handle from the application, without stopping it. After this method is called, + * the handle will not be able to communicate with the application anymore. + */ + void disconnect(); + + /** + * Listener for updates to a handle's state. The callbacks do not receive information about + * what exactly has changed, just that an update has occurred. + * + * @since 1.6.0 + */ + public interface Listener { + + /** + * Callback for changes in the handle's state. + * + * @param handle The updated handle. + * @see {@link SparkAppHandle#getState()} + */ + void stateChanged(SparkAppHandle handle); + + /** + * Callback for changes in any information that is not the handle's state. + * + * @param handle The updated handle. + */ + void infoChanged(SparkAppHandle handle); + + } + +} diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java index 57993405e47be..5d74b37033a51 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java @@ -21,8 +21,10 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; import static org.apache.spark.launcher.CommandBuilderUtils.*; @@ -58,6 +60,33 @@ public class SparkLauncher { /** Configuration key for the number of executor CPU cores. */ public static final String EXECUTOR_CORES = "spark.executor.cores"; + /** Logger name to use when launching a child process. */ + public static final String CHILD_PROCESS_LOGGER_NAME = "spark.launcher.childProcLoggerName"; + + /** + * Maximum time (in ms) to wait for a child process to connect back to the launcher server + * when using @link{#start()}. + */ + public static final String CHILD_CONNECTION_TIMEOUT = "spark.launcher.childConectionTimeout"; + + /** Used internally to create unique logger names. */ + private static final AtomicInteger COUNTER = new AtomicInteger(); + + static final Map launcherConfig = new HashMap(); + + /** + * Set a configuration value for the launcher library. These config values do not affect the + * launched application, but rather the behavior of the launcher library itself when managing + * applications. + * + * @since 1.6.0 + * @param name Config name. + * @param value Config value. + */ + public static void setConfig(String name, String value) { + launcherConfig.put(name, value); + } + // Visible for testing. final SparkSubmitCommandBuilder builder; @@ -109,7 +138,7 @@ public SparkLauncher setSparkHome(String sparkHome) { */ public SparkLauncher setPropertiesFile(String path) { checkNotNull(path, "path"); - builder.propertiesFile = path; + builder.setPropertiesFile(path); return this; } @@ -197,6 +226,7 @@ public SparkLauncher setMainClass(String mainClass) { * Use this method with caution. It is possible to create an invalid Spark command by passing * unknown arguments to this method, since those are allowed for forward compatibility. * + * @since 1.5.0 * @param arg Argument to add. * @return This launcher. */ @@ -218,6 +248,7 @@ public SparkLauncher addSparkArg(String arg) { * Use this method with caution. It is possible to create an invalid Spark command by passing * unknown arguments to this method, since those are allowed for forward compatibility. * + * @since 1.5.0 * @param name Name of argument to add. * @param value Value of the argument. * @return This launcher. @@ -319,10 +350,81 @@ public SparkLauncher setVerbose(boolean verbose) { /** * Launches a sub-process that will start the configured Spark application. + *

+ * The {@link #startApplication(SparkAppHandle.Listener...)} method is preferred when launching + * Spark, since it provides better control of the child application. * * @return A process handle for the Spark app. */ public Process launch() throws IOException { + return createBuilder().start(); + } + + /** + * Starts a Spark application. + *

+ * This method returns a handle that provides information about the running application and can + * be used to do basic interaction with it. + *

+ * The returned handle assumes that the application will instantiate a single SparkContext + * during its lifetime. Once that context reports a final state (one that indicates the + * SparkContext has stopped), the handle will not perform new state transitions, so anything + * that happens after that cannot be monitored. If the underlying application is launched as + * a child process, {@link SparkAppHandle#kill()} can still be used to kill the child process. + *

+ * Currently, all applications are launched as child processes. The child's stdout and stderr + * are merged and written to a logger (see java.util.logging). The logger's name + * can be defined by setting {@link #CHILD_PROCESS_LOGGER_NAME} in the app's configuration. If + * that option is not set, the code will try to derive a name from the application's name or + * main class / script file. If those cannot be determined, an internal, unique name will be + * used. In all cases, the logger name will start with "org.apache.spark.launcher.app", to fit + * more easily into the configuration of commonly-used logging systems. + * + * @since 1.6.0 + * @param listeners Listeners to add to the handle before the app is launched. + * @return A handle for the launched application. + */ + public SparkAppHandle startApplication(SparkAppHandle.Listener... listeners) throws IOException { + ChildProcAppHandle handle = LauncherServer.newAppHandle(); + for (SparkAppHandle.Listener l : listeners) { + handle.addListener(l); + } + + String appName = builder.getEffectiveConfig().get(CHILD_PROCESS_LOGGER_NAME); + if (appName == null) { + if (builder.appName != null) { + appName = builder.appName; + } else if (builder.mainClass != null) { + int dot = builder.mainClass.lastIndexOf("."); + if (dot >= 0 && dot < builder.mainClass.length() - 1) { + appName = builder.mainClass.substring(dot + 1, builder.mainClass.length()); + } else { + appName = builder.mainClass; + } + } else if (builder.appResource != null) { + appName = new File(builder.appResource).getName(); + } else { + appName = String.valueOf(COUNTER.incrementAndGet()); + } + } + + String loggerPrefix = getClass().getPackage().getName(); + String loggerName = String.format("%s.app.%s", loggerPrefix, appName); + ProcessBuilder pb = createBuilder().redirectErrorStream(true); + pb.environment().put(LauncherProtocol.ENV_LAUNCHER_PORT, + String.valueOf(LauncherServer.getServerInstance().getPort())); + pb.environment().put(LauncherProtocol.ENV_LAUNCHER_SECRET, handle.getSecret()); + try { + handle.setChildProc(pb.start(), loggerName); + } catch (IOException ioe) { + handle.kill(); + throw ioe; + } + + return handle; + } + + private ProcessBuilder createBuilder() { List cmd = new ArrayList(); String script = isWindows() ? "spark-submit.cmd" : "spark-submit"; cmd.add(join(File.separator, builder.getSparkHome(), "bin", script)); @@ -343,7 +445,7 @@ public Process launch() throws IOException { for (Map.Entry e : builder.childEnv.entrySet()) { pb.environment().put(e.getKey(), e.getValue()); } - return pb.start(); + return pb; } private static class ArgumentValidator extends SparkSubmitOptionParser { diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index fc87814a59ed5..39b46e0db8cc2 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -188,10 +188,9 @@ private List buildSparkSubmitCommand(Map env) throws IOE // Load the properties file and check whether spark-submit will be running the app's driver // or just launching a cluster app. When running the driver, the JVM's argument will be // modified to cover the driver's configuration. - Properties props = loadPropertiesFile(); - boolean isClientMode = isClientMode(props); - String extraClassPath = isClientMode ? - firstNonEmptyValue(SparkLauncher.DRIVER_EXTRA_CLASSPATH, conf, props) : null; + Map config = getEffectiveConfig(); + boolean isClientMode = isClientMode(config); + String extraClassPath = isClientMode ? config.get(SparkLauncher.DRIVER_EXTRA_CLASSPATH) : null; List cmd = buildJavaCommand(extraClassPath); // Take Thrift Server as daemon @@ -212,14 +211,13 @@ private List buildSparkSubmitCommand(Map env) throws IOE // Take Thrift Server as daemon String tsMemory = isThriftServer(mainClass) ? System.getenv("SPARK_DAEMON_MEMORY") : null; - String memory = firstNonEmpty(tsMemory, - firstNonEmptyValue(SparkLauncher.DRIVER_MEMORY, conf, props), + String memory = firstNonEmpty(tsMemory, config.get(SparkLauncher.DRIVER_MEMORY), System.getenv("SPARK_DRIVER_MEMORY"), System.getenv("SPARK_MEM"), DEFAULT_MEM); cmd.add("-Xms" + memory); cmd.add("-Xmx" + memory); - addOptionString(cmd, firstNonEmptyValue(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS, conf, props)); + addOptionString(cmd, config.get(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS)); mergeEnvPathList(env, getLibPathEnvName(), - firstNonEmptyValue(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH, conf, props)); + config.get(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH)); } addPermGenSizeOpt(cmd); @@ -281,9 +279,8 @@ private List buildSparkRCommand(Map env) throws IOExcept private void constructEnvVarArgs( Map env, String submitArgsEnvVariable) throws IOException { - Properties props = loadPropertiesFile(); mergeEnvPathList(env, getLibPathEnvName(), - firstNonEmptyValue(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH, conf, props)); + getEffectiveConfig().get(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH)); StringBuilder submitArgs = new StringBuilder(); for (String arg : buildSparkSubmitArgs()) { @@ -295,9 +292,8 @@ private void constructEnvVarArgs( env.put(submitArgsEnvVariable, submitArgs.toString()); } - - private boolean isClientMode(Properties userProps) { - String userMaster = firstNonEmpty(master, (String) userProps.get(SparkLauncher.SPARK_MASTER)); + private boolean isClientMode(Map userProps) { + String userMaster = firstNonEmpty(master, userProps.get(SparkLauncher.SPARK_MASTER)); // Default master is "local[*]", so assume client mode in that case. return userMaster == null || "client".equals(deployMode) || diff --git a/launcher/src/main/java/org/apache/spark/launcher/package-info.java b/launcher/src/main/java/org/apache/spark/launcher/package-info.java index 7c97dba511b28..d1ac39bdc76a9 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/package-info.java +++ b/launcher/src/main/java/org/apache/spark/launcher/package-info.java @@ -17,17 +17,42 @@ /** * Library for launching Spark applications. - * + * *

* This library allows applications to launch Spark programmatically. There's only one entry * point to the library - the {@link org.apache.spark.launcher.SparkLauncher} class. *

* *

- * To launch a Spark application, just instantiate a {@link org.apache.spark.launcher.SparkLauncher} - * and configure the application to run. For example: + * The {@link org.apache.spark.launcher.SparkLauncher#startApplication( + * org.apache.spark.launcher.SparkAppHandle.Listener...)} can be used to start Spark and provide + * a handle to monitor and control the running application: *

- * + * + *
+ * {@code
+ *   import org.apache.spark.launcher.SparkAppHandle;
+ *   import org.apache.spark.launcher.SparkLauncher;
+ *
+ *   public class MyLauncher {
+ *     public static void main(String[] args) throws Exception {
+ *       SparkAppHandle handle = new SparkLauncher()
+ *         .setAppResource("/my/app.jar")
+ *         .setMainClass("my.spark.app.Main")
+ *         .setMaster("local")
+ *         .setConf(SparkLauncher.DRIVER_MEMORY, "2g")
+ *         .startApplication();
+ *       // Use handle API to monitor / control application.
+ *     }
+ *   }
+ * }
+ * 
+ * + *

+ * It's also possible to launch a raw child process, using the + * {@link org.apache.spark.launcher.SparkLauncher#launch()} method: + *

+ * *
  * {@code
  *   import org.apache.spark.launcher.SparkLauncher;
@@ -45,5 +70,10 @@
  *   }
  * }
  * 
+ * + *

This method requires the calling code to manually manage the child process, including its + * output streams (to avoid possible deadlocks). It's recommended that + * {@link org.apache.spark.launcher.SparkLauncher#startApplication( + * org.apache.spark.launcher.SparkAppHandle.Listener...)} be used instead.

*/ package org.apache.spark.launcher; diff --git a/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java b/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java new file mode 100644 index 0000000000000..23e2c64d6dcd7 --- /dev/null +++ b/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +import org.slf4j.bridge.SLF4JBridgeHandler; + +/** + * Handles configuring the JUL -> SLF4J bridge. + */ +class BaseSuite { + + static { + SLF4JBridgeHandler.removeHandlersForRootLogger(); + SLF4JBridgeHandler.install(); + } + +} diff --git a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java new file mode 100644 index 0000000000000..27cd1061a15b3 --- /dev/null +++ b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +import java.io.Closeable; +import java.io.IOException; +import java.net.InetAddress; +import java.net.Socket; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + +import org.junit.Test; +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +import static org.apache.spark.launcher.LauncherProtocol.*; + +public class LauncherServerSuite extends BaseSuite { + + @Test + public void testLauncherServerReuse() throws Exception { + ChildProcAppHandle handle1 = null; + ChildProcAppHandle handle2 = null; + ChildProcAppHandle handle3 = null; + + try { + handle1 = LauncherServer.newAppHandle(); + handle2 = LauncherServer.newAppHandle(); + LauncherServer server1 = handle1.getServer(); + assertSame(server1, handle2.getServer()); + + handle1.kill(); + handle2.kill(); + + handle3 = LauncherServer.newAppHandle(); + assertNotSame(server1, handle3.getServer()); + + handle3.kill(); + + assertNull(LauncherServer.getServerInstance()); + } finally { + kill(handle1); + kill(handle2); + kill(handle3); + } + } + + @Test + public void testCommunication() throws Exception { + ChildProcAppHandle handle = LauncherServer.newAppHandle(); + TestClient client = null; + try { + Socket s = new Socket(InetAddress.getLoopbackAddress(), + LauncherServer.getServerInstance().getPort()); + + final Object waitLock = new Object(); + handle.addListener(new SparkAppHandle.Listener() { + @Override + public void stateChanged(SparkAppHandle handle) { + wakeUp(); + } + + @Override + public void infoChanged(SparkAppHandle handle) { + wakeUp(); + } + + private void wakeUp() { + synchronized (waitLock) { + waitLock.notifyAll(); + } + } + }); + + client = new TestClient(s); + synchronized (waitLock) { + client.send(new Hello(handle.getSecret(), "1.4.0")); + waitLock.wait(TimeUnit.SECONDS.toMillis(10)); + } + + // Make sure the server matched the client to the handle. + assertNotNull(handle.getConnection()); + + synchronized (waitLock) { + client.send(new SetAppId("app-id")); + waitLock.wait(TimeUnit.SECONDS.toMillis(10)); + } + assertEquals("app-id", handle.getAppId()); + + synchronized (waitLock) { + client.send(new SetState(SparkAppHandle.State.RUNNING)); + waitLock.wait(TimeUnit.SECONDS.toMillis(10)); + } + assertEquals(SparkAppHandle.State.RUNNING, handle.getState()); + + handle.stop(); + Message stopMsg = client.inbound.poll(10, TimeUnit.SECONDS); + assertTrue(stopMsg instanceof Stop); + } finally { + kill(handle); + close(client); + client.clientThread.join(); + } + } + + @Test + public void testTimeout() throws Exception { + final long TEST_TIMEOUT = 10L; + + ChildProcAppHandle handle = null; + TestClient client = null; + try { + SparkLauncher.setConfig(SparkLauncher.CHILD_CONNECTION_TIMEOUT, String.valueOf(TEST_TIMEOUT)); + + handle = LauncherServer.newAppHandle(); + + Socket s = new Socket(InetAddress.getLoopbackAddress(), + LauncherServer.getServerInstance().getPort()); + client = new TestClient(s); + + Thread.sleep(TEST_TIMEOUT * 10); + try { + client.send(new Hello(handle.getSecret(), "1.4.0")); + fail("Expected exception caused by connection timeout."); + } catch (IllegalStateException e) { + // Expected. + } + } finally { + SparkLauncher.launcherConfig.remove(SparkLauncher.CHILD_CONNECTION_TIMEOUT); + kill(handle); + close(client); + } + } + + private void kill(SparkAppHandle handle) { + if (handle != null) { + handle.kill(); + } + } + + private void close(Closeable c) { + if (c != null) { + try { + c.close(); + } catch (Exception e) { + // no-op. + } + } + } + + private static class TestClient extends LauncherConnection { + + final BlockingQueue inbound; + final Thread clientThread; + + TestClient(Socket s) throws IOException { + super(s); + this.inbound = new LinkedBlockingQueue(); + this.clientThread = new Thread(this); + clientThread.setName("TestClient"); + clientThread.setDaemon(true); + clientThread.start(); + } + + @Override + protected void handle(Message msg) throws IOException { + inbound.offer(msg); + } + + } + +} diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java index 7329ac9f7fb8c..d5397b0685046 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java @@ -30,7 +30,7 @@ import org.junit.Test; import static org.junit.Assert.*; -public class SparkSubmitCommandBuilderSuite { +public class SparkSubmitCommandBuilderSuite extends BaseSuite { private static File dummyPropsFile; private static SparkSubmitOptionParser parser; @@ -161,7 +161,7 @@ private void testCmdBuilder(boolean isDriver) throws Exception { launcher.appResource = "/foo"; launcher.appName = "MyApp"; launcher.mainClass = "my.Class"; - launcher.propertiesFile = dummyPropsFile.getAbsolutePath(); + launcher.setPropertiesFile(dummyPropsFile.getAbsolutePath()); launcher.appArgs.add("foo"); launcher.appArgs.add("bar"); launcher.conf.put(SparkLauncher.DRIVER_MEMORY, "1g"); diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitOptionParserSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitOptionParserSuite.java index f3d2109917056..3ee5b8cf9689d 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitOptionParserSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitOptionParserSuite.java @@ -28,7 +28,7 @@ import static org.apache.spark.launcher.SparkSubmitOptionParser.*; -public class SparkSubmitOptionParserSuite { +public class SparkSubmitOptionParserSuite extends BaseSuite { private SparkSubmitOptionParser parser; diff --git a/launcher/src/test/resources/log4j.properties b/launcher/src/test/resources/log4j.properties index 67a6a98217118..c64b1565e1469 100644 --- a/launcher/src/test/resources/log4j.properties +++ b/launcher/src/test/resources/log4j.properties @@ -16,16 +16,19 @@ # # Set everything to be logged to the file core/target/unit-tests.log -log4j.rootCategory=INFO, file +test.appender=file +log4j.rootCategory=INFO, ${test.appender} log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=false - -# Some tests will set "test.name" to avoid overwriting the main log file. -log4j.appender.file.file=target/unit-tests${test.name}.log - +log4j.appender.file.file=target/unit-tests.log log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n +log4j.appender.childproc=org.apache.log4j.ConsoleAppender +log4j.appender.childproc.target=System.err +log4j.appender.childproc.layout=org.apache.log4j.PatternLayout +log4j.appender.childproc.layout.ConversionPattern=%t: %m%n + # Ignore messages below warning level from Jetty, because it's a bit verbose log4j.logger.org.spark-project.jetty=WARN org.spark-project.jetty.LEVEL=WARN diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index eb3b7fb885087..cec81b940644c 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -55,8 +55,8 @@ import org.apache.hadoop.yarn.exceptions.ApplicationNotFoundException import org.apache.hadoop.yarn.util.Records import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkException} +import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle, YarnCommandBuilderUtils} import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.launcher.YarnCommandBuilderUtils import org.apache.spark.util.Utils private[spark] class Client( @@ -70,8 +70,6 @@ private[spark] class Client( def this(clientArgs: ClientArguments, spConf: SparkConf) = this(clientArgs, SparkHadoopUtil.get.newConfiguration(spConf), spConf) - def this(clientArgs: ClientArguments) = this(clientArgs, new SparkConf()) - private val yarnClient = YarnClient.createYarnClient private val yarnConf = new YarnConfiguration(hadoopConf) private var credentials: Credentials = null @@ -84,10 +82,27 @@ private[spark] class Client( private var principal: String = null private var keytab: String = null + private val launcherBackend = new LauncherBackend() { + override def onStopRequest(): Unit = { + if (isClusterMode && appId != null) { + yarnClient.killApplication(appId) + } else { + setState(SparkAppHandle.State.KILLED) + stop() + } + } + } private val fireAndForget = isClusterMode && !sparkConf.getBoolean("spark.yarn.submit.waitAppCompletion", true) + private var appId: ApplicationId = null + + def reportLauncherState(state: SparkAppHandle.State): Unit = { + launcherBackend.setState(state) + } + def stop(): Unit = { + launcherBackend.close() yarnClient.stop() // Unset YARN mode system env variable, to allow switching between cluster types. System.clearProperty("SPARK_YARN_MODE") @@ -103,6 +118,7 @@ private[spark] class Client( def submitApplication(): ApplicationId = { var appId: ApplicationId = null try { + launcherBackend.connect() // Setup the credentials before doing anything else, // so we have don't have issues at any point. setupCredentials() @@ -116,6 +132,8 @@ private[spark] class Client( val newApp = yarnClient.createApplication() val newAppResponse = newApp.getNewApplicationResponse() appId = newAppResponse.getApplicationId() + reportLauncherState(SparkAppHandle.State.SUBMITTED) + launcherBackend.setAppId(appId.toString()) // Verify whether the cluster has enough resources for our AM verifyClusterResources(newAppResponse) @@ -881,6 +899,20 @@ private[spark] class Client( } } + if (lastState != state) { + state match { + case YarnApplicationState.RUNNING => + reportLauncherState(SparkAppHandle.State.RUNNING) + case YarnApplicationState.FINISHED => + reportLauncherState(SparkAppHandle.State.FINISHED) + case YarnApplicationState.FAILED => + reportLauncherState(SparkAppHandle.State.FAILED) + case YarnApplicationState.KILLED => + reportLauncherState(SparkAppHandle.State.KILLED) + case _ => + } + } + if (state == YarnApplicationState.FINISHED || state == YarnApplicationState.FAILED || state == YarnApplicationState.KILLED) { @@ -928,8 +960,8 @@ private[spark] class Client( * throw an appropriate SparkException. */ def run(): Unit = { - val appId = submitApplication() - if (fireAndForget) { + this.appId = submitApplication() + if (!launcherBackend.isConnected() && fireAndForget) { val report = getApplicationReport(appId) val state = report.getYarnApplicationState logInfo(s"Application report for $appId (state: $state)") @@ -971,6 +1003,7 @@ private[spark] class Client( } object Client extends Logging { + def main(argStrings: Array[String]) { if (!sys.props.contains("SPARK_SUBMIT")) { logWarning("WARNING: This client is deprecated and will be removed in a " + diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 36d5759554d98..20771f655473c 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -23,6 +23,7 @@ import org.apache.hadoop.yarn.api.records.{ApplicationId, YarnApplicationState} import org.apache.spark.{SparkException, Logging, SparkContext} import org.apache.spark.deploy.yarn.{Client, ClientArguments, YarnSparkHadoopUtil} +import org.apache.spark.launcher.SparkAppHandle import org.apache.spark.scheduler.TaskSchedulerImpl private[spark] class YarnClientSchedulerBackend( @@ -177,6 +178,15 @@ private[spark] class YarnClientSchedulerBackend( if (monitorThread != null) { monitorThread.stopMonitor() } + + // Report a final state to the launcher if one is connected. This is needed since in client + // mode this backend doesn't let the app monitor loop run to completion, so it does not report + // the final state itself. + // + // Note: there's not enough information at this point to provide a better final state, + // so assume the application was successful. + client.reportLauncherState(SparkAppHandle.State.FINISHED) + super.stop() YarnSparkHadoopUtil.get.stopExecutorDelegationTokenRenewer() client.stop() diff --git a/yarn/src/test/resources/log4j.properties b/yarn/src/test/resources/log4j.properties index 6b8a5dbf6373e..6b9a799954bf1 100644 --- a/yarn/src/test/resources/log4j.properties +++ b/yarn/src/test/resources/log4j.properties @@ -23,6 +23,9 @@ log4j.appender.file.file=target/unit-tests.log log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n -# Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN +# Ignore messages below warning level from a few verbose libraries. +log4j.logger.com.sun.jersey=WARN log4j.logger.org.apache.hadoop=WARN +log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.mortbay=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala index 17c59ff06e0c1..12494b01054ba 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala @@ -22,15 +22,18 @@ import java.util.Properties import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ +import scala.concurrent.duration._ +import scala.language.postfixOps import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.server.MiniYARNCluster import org.scalatest.{BeforeAndAfterAll, Matchers} +import org.scalatest.concurrent.Eventually._ import org.apache.spark._ -import org.apache.spark.launcher.TestClasspathBuilder +import org.apache.spark.launcher._ import org.apache.spark.util.Utils abstract class BaseYarnClusterSuite @@ -46,13 +49,14 @@ abstract class BaseYarnClusterSuite |log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n |log4j.logger.org.apache.hadoop=WARN |log4j.logger.org.eclipse.jetty=WARN + |log4j.logger.org.mortbay=WARN |log4j.logger.org.spark-project.jetty=WARN """.stripMargin private var yarnCluster: MiniYARNCluster = _ protected var tempDir: File = _ private var fakeSparkJar: File = _ - private var hadoopConfDir: File = _ + protected var hadoopConfDir: File = _ private var logConfDir: File = _ def newYarnConfig(): YarnConfiguration @@ -120,15 +124,77 @@ abstract class BaseYarnClusterSuite clientMode: Boolean, klass: String, appArgs: Seq[String] = Nil, - sparkArgs: Seq[String] = Nil, + sparkArgs: Seq[(String, String)] = Nil, extraClassPath: Seq[String] = Nil, extraJars: Seq[String] = Nil, extraConf: Map[String, String] = Map(), - extraEnv: Map[String, String] = Map()): Unit = { + extraEnv: Map[String, String] = Map()): SparkAppHandle.State = { val master = if (clientMode) "yarn-client" else "yarn-cluster" - val props = new Properties() + val propsFile = createConfFile(extraClassPath = extraClassPath, extraConf = extraConf) + val env = Map("YARN_CONF_DIR" -> hadoopConfDir.getAbsolutePath()) ++ extraEnv + + val launcher = new SparkLauncher(env.asJava) + if (klass.endsWith(".py")) { + launcher.setAppResource(klass) + } else { + launcher.setMainClass(klass) + launcher.setAppResource(fakeSparkJar.getAbsolutePath()) + } + launcher.setSparkHome(sys.props("spark.test.home")) + .setMaster(master) + .setConf("spark.executor.instances", "1") + .setPropertiesFile(propsFile) + .addAppArgs(appArgs.toArray: _*) + + sparkArgs.foreach { case (name, value) => + if (value != null) { + launcher.addSparkArg(name, value) + } else { + launcher.addSparkArg(name) + } + } + extraJars.foreach(launcher.addJar) - props.setProperty("spark.yarn.jar", "local:" + fakeSparkJar.getAbsolutePath()) + val handle = launcher.startApplication() + try { + eventually(timeout(2 minutes), interval(1 second)) { + assert(handle.getState().isFinal()) + } + } finally { + handle.kill() + } + + handle.getState() + } + + /** + * This is a workaround for an issue with yarn-cluster mode: the Client class will not provide + * any sort of error when the job process finishes successfully, but the job itself fails. So + * the tests enforce that something is written to a file after everything is ok to indicate + * that the job succeeded. + */ + protected def checkResult(finalState: SparkAppHandle.State, result: File): Unit = { + checkResult(finalState, result, "success") + } + + protected def checkResult( + finalState: SparkAppHandle.State, + result: File, + expected: String): Unit = { + finalState should be (SparkAppHandle.State.FINISHED) + val resultString = Files.toString(result, UTF_8) + resultString should be (expected) + } + + protected def mainClassName(klass: Class[_]): String = { + klass.getName().stripSuffix("$") + } + + protected def createConfFile( + extraClassPath: Seq[String] = Nil, + extraConf: Map[String, String] = Map()): String = { + val props = new Properties() + props.put("spark.yarn.jar", "local:" + fakeSparkJar.getAbsolutePath()) val testClasspath = new TestClasspathBuilder() .buildClassPath( @@ -138,69 +204,28 @@ abstract class BaseYarnClusterSuite .asScala .mkString(File.pathSeparator) - props.setProperty("spark.driver.extraClassPath", testClasspath) - props.setProperty("spark.executor.extraClassPath", testClasspath) + props.put("spark.driver.extraClassPath", testClasspath) + props.put("spark.executor.extraClassPath", testClasspath) // SPARK-4267: make sure java options are propagated correctly. props.setProperty("spark.driver.extraJavaOptions", "-Dfoo=\"one two three\"") props.setProperty("spark.executor.extraJavaOptions", "-Dfoo=\"one two three\"") - yarnCluster.getConfig.asScala.foreach { e => + yarnCluster.getConfig().asScala.foreach { e => props.setProperty("spark.hadoop." + e.getKey(), e.getValue()) } - sys.props.foreach { case (k, v) => if (k.startsWith("spark.")) { props.setProperty(k, v) } } - extraConf.foreach { case (k, v) => props.setProperty(k, v) } val propsFile = File.createTempFile("spark", ".properties", tempDir) val writer = new OutputStreamWriter(new FileOutputStream(propsFile), UTF_8) props.store(writer, "Spark properties.") writer.close() - - val extraJarArgs = if (extraJars.nonEmpty) Seq("--jars", extraJars.mkString(",")) else Nil - val mainArgs = - if (klass.endsWith(".py")) { - Seq(klass) - } else { - Seq("--class", klass, fakeSparkJar.getAbsolutePath()) - } - val argv = - Seq( - new File(sys.props("spark.test.home"), "bin/spark-submit").getAbsolutePath(), - "--master", master, - "--num-executors", "1", - "--properties-file", propsFile.getAbsolutePath()) ++ - extraJarArgs ++ - sparkArgs ++ - mainArgs ++ - appArgs - - Utils.executeAndGetOutput(argv, - extraEnvironment = Map("YARN_CONF_DIR" -> hadoopConfDir.getAbsolutePath()) ++ extraEnv) - } - - /** - * This is a workaround for an issue with yarn-cluster mode: the Client class will not provide - * any sort of error when the job process finishes successfully, but the job itself fails. So - * the tests enforce that something is written to a file after everything is ok to indicate - * that the job succeeded. - */ - protected def checkResult(result: File): Unit = { - checkResult(result, "success") - } - - protected def checkResult(result: File, expected: String): Unit = { - val resultString = Files.toString(result, UTF_8) - resultString should be (expected) - } - - protected def mainClassName(klass: Class[_]): String = { - klass.getName().stripSuffix("$") + propsFile.getAbsolutePath() } } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index f1601cd16100f..d1cd0c89b5d38 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -19,16 +19,20 @@ package org.apache.spark.deploy.yarn import java.io.File import java.net.URL +import java.util.{HashMap => JHashMap, Properties} import scala.collection.mutable +import scala.concurrent.duration._ +import scala.language.postfixOps import com.google.common.base.Charsets.UTF_8 import com.google.common.io.{ByteStreams, Files} import org.apache.hadoop.yarn.conf.YarnConfiguration import org.scalatest.Matchers +import org.scalatest.concurrent.Eventually._ import org.apache.spark._ -import org.apache.spark.launcher.TestClasspathBuilder +import org.apache.spark.launcher._ import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationStart, SparkListenerExecutorAdded} import org.apache.spark.scheduler.cluster.ExecutorInfo @@ -82,10 +86,8 @@ class YarnClusterSuite extends BaseYarnClusterSuite { test("run Spark in yarn-cluster mode unsuccessfully") { // Don't provide arguments so the driver will fail. - val exception = intercept[SparkException] { - runSpark(false, mainClassName(YarnClusterDriver.getClass)) - fail("Spark application should have failed.") - } + val finalState = runSpark(false, mainClassName(YarnClusterDriver.getClass)) + finalState should be (SparkAppHandle.State.FAILED) } test("run Python application in yarn-client mode") { @@ -104,11 +106,42 @@ class YarnClusterSuite extends BaseYarnClusterSuite { testUseClassPathFirst(false) } + test("monitor app using launcher library") { + val env = new JHashMap[String, String]() + env.put("YARN_CONF_DIR", hadoopConfDir.getAbsolutePath()) + + val propsFile = createConfFile() + val handle = new SparkLauncher(env) + .setSparkHome(sys.props("spark.test.home")) + .setConf("spark.ui.enabled", "false") + .setPropertiesFile(propsFile) + .setMaster("yarn-client") + .setAppResource("spark-internal") + .setMainClass(mainClassName(YarnLauncherTestApp.getClass)) + .startApplication() + + try { + eventually(timeout(30 seconds), interval(100 millis)) { + handle.getState() should be (SparkAppHandle.State.RUNNING) + } + + handle.getAppId() should not be (null) + handle.getAppId() should startWith ("application_") + handle.stop() + + eventually(timeout(30 seconds), interval(100 millis)) { + handle.getState() should be (SparkAppHandle.State.KILLED) + } + } finally { + handle.kill() + } + } + private def testBasicYarnApp(clientMode: Boolean): Unit = { val result = File.createTempFile("result", null, tempDir) - runSpark(clientMode, mainClassName(YarnClusterDriver.getClass), + val finalState = runSpark(clientMode, mainClassName(YarnClusterDriver.getClass), appArgs = Seq(result.getAbsolutePath())) - checkResult(result) + checkResult(finalState, result) } private def testPySpark(clientMode: Boolean): Unit = { @@ -143,11 +176,11 @@ class YarnClusterSuite extends BaseYarnClusterSuite { val pyFiles = Seq(pyModule.getAbsolutePath(), mod2Archive.getPath()).mkString(",") val result = File.createTempFile("result", null, tempDir) - runSpark(clientMode, primaryPyFile.getAbsolutePath(), - sparkArgs = Seq("--py-files", pyFiles), + val finalState = runSpark(clientMode, primaryPyFile.getAbsolutePath(), + sparkArgs = Seq("--py-files" -> pyFiles), appArgs = Seq(result.getAbsolutePath()), extraEnv = extraEnv) - checkResult(result) + checkResult(finalState, result) } private def testUseClassPathFirst(clientMode: Boolean): Unit = { @@ -156,15 +189,15 @@ class YarnClusterSuite extends BaseYarnClusterSuite { val userJar = TestUtils.createJarWithFiles(Map("test.resource" -> "OVERRIDDEN"), tempDir) val driverResult = File.createTempFile("driver", null, tempDir) val executorResult = File.createTempFile("executor", null, tempDir) - runSpark(clientMode, mainClassName(YarnClasspathTest.getClass), + val finalState = runSpark(clientMode, mainClassName(YarnClasspathTest.getClass), appArgs = Seq(driverResult.getAbsolutePath(), executorResult.getAbsolutePath()), extraClassPath = Seq(originalJar.getPath()), extraJars = Seq("local:" + userJar.getPath()), extraConf = Map( "spark.driver.userClassPathFirst" -> "true", "spark.executor.userClassPathFirst" -> "true")) - checkResult(driverResult, "OVERRIDDEN") - checkResult(executorResult, "OVERRIDDEN") + checkResult(finalState, driverResult, "OVERRIDDEN") + checkResult(finalState, executorResult, "OVERRIDDEN") } } @@ -211,8 +244,8 @@ private object YarnClusterDriver extends Logging with Matchers { data should be (Set(1, 2, 3, 4)) result = "success" } finally { - sc.stop() Files.write(result, status, UTF_8) + sc.stop() } // verify log urls are present @@ -297,3 +330,18 @@ private object YarnClasspathTest extends Logging { } } + +private object YarnLauncherTestApp { + + def main(args: Array[String]): Unit = { + // Do not stop the application; the test will stop it using the launcher lib. Just run a task + // that will prevent the process from exiting. + val sc = new SparkContext(new SparkConf()) + sc.parallelize(Seq(1)).foreach { i => + this.synchronized { + wait() + } + } + } + +} diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala index a85e5772a0fa4..c17e8695c24fb 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala @@ -53,7 +53,7 @@ class YarnShuffleIntegrationSuite extends BaseYarnClusterSuite { logInfo("Shuffle service port = " + shuffleServicePort) val result = File.createTempFile("result", null, tempDir) - runSpark( + val finalState = runSpark( false, mainClassName(YarnExternalShuffleDriver.getClass), appArgs = Seq(result.getAbsolutePath(), registeredExecFile.getAbsolutePath), @@ -62,7 +62,7 @@ class YarnShuffleIntegrationSuite extends BaseYarnClusterSuite { "spark.shuffle.service.port" -> shuffleServicePort.toString ) ) - checkResult(result) + checkResult(finalState, result) assert(YarnTestAccessor.getRegisteredExecutorFile(shuffleService).exists()) } } From 12b7191d2075ae870c73529de450cbb5725872ec Mon Sep 17 00:00:00 2001 From: Rick Hillegas Date: Fri, 9 Oct 2015 13:36:51 -0700 Subject: [PATCH 016/139] [SPARK-10855] [SQL] Add a JDBC dialect for Apache Derby marmbrus rxin This patch adds a JdbcDialect class, which customizes the datatype mappings for Derby backends. The patch also adds unit tests for the new dialect, corresponding to the existing tests for other JDBC dialects. JDBCSuite runs cleanly for me with this patch. So does JDBCWriteSuite, although it produces noise as described here: https://issues.apache.org/jira/browse/SPARK-10890 This patch is my original work, which I license to the ASF. I am a Derby contributor, so my ICLA is on file under SVN id "rhillegas": http://people.apache.org/committer-index.html Touches the following files: --------------------------------- org.apache.spark.sql.jdbc.JdbcDialects Adds a DerbyDialect. --------------------------------- org.apache.spark.sql.jdbc.JDBCSuite Adds unit tests for the new DerbyDialect. Author: Rick Hillegas Closes #8982 from rick-ibm/b_10855. --- .../apache/spark/sql/jdbc/JdbcDialects.scala | 28 +++++++++++++++++++ .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 14 +++++++++- 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 0cd356f222984..a2ff4cc1c91f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -138,6 +138,7 @@ object JdbcDialects { registerDialect(PostgresDialect) registerDialect(DB2Dialect) registerDialect(MsSqlServerDialect) + registerDialect(DerbyDialect) /** @@ -287,3 +288,30 @@ case object MsSqlServerDialect extends JdbcDialect { case _ => None } } + +/** + * :: DeveloperApi :: + * Default Apache Derby dialect, mapping real on read + * and string/byte/short/boolean/decimal on write. + */ +@DeveloperApi +case object DerbyDialect extends JdbcDialect { + override def canHandle(url: String): Boolean = url.startsWith("jdbc:derby") + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + if (sqlType == Types.REAL) Option(FloatType) else None + } + + override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { + case StringType => Some(JdbcType("CLOB", java.sql.Types.CLOB)) + case ByteType => Some(JdbcType("SMALLINT", java.sql.Types.SMALLINT)) + case ShortType => Some(JdbcType("SMALLINT", java.sql.Types.SMALLINT)) + case BooleanType => Some(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN)) + // 31 is the maximum precision and 5 is the default scale for a Derby DECIMAL + case (t: DecimalType) if (t.precision > 31) => + Some(JdbcType("DECIMAL(31,5)", java.sql.Types.DECIMAL)) + case _ => None + } + +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index bbf705ce95933..d530b1a469ce2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -409,18 +409,22 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext assert(JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") == PostgresDialect) assert(JdbcDialects.get("jdbc:db2://127.0.0.1/db") == DB2Dialect) assert(JdbcDialects.get("jdbc:sqlserver://127.0.0.1/db") == MsSqlServerDialect) + assert(JdbcDialects.get("jdbc:derby:db") == DerbyDialect) assert(JdbcDialects.get("test.invalid") == NoopDialect) } test("quote column names by jdbc dialect") { val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") + val Derby = JdbcDialects.get("jdbc:derby:db") val columns = Seq("abc", "key") val MySQLColumns = columns.map(MySQL.quoteIdentifier(_)) val PostgresColumns = columns.map(Postgres.quoteIdentifier(_)) + val DerbyColumns = columns.map(Derby.quoteIdentifier(_)) assert(MySQLColumns === Seq("`abc`", "`key`")) assert(PostgresColumns === Seq(""""abc"""", """"key"""")) + assert(DerbyColumns === Seq(""""abc"""", """"key"""")) } test("Dialect unregister") { @@ -454,16 +458,23 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext test("PostgresDialect type mapping") { val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") - // SPARK-7869: Testing JSON types handling assert(Postgres.getCatalystType(java.sql.Types.OTHER, "json", 1, null) === Some(StringType)) assert(Postgres.getCatalystType(java.sql.Types.OTHER, "jsonb", 1, null) === Some(StringType)) } + test("DerbyDialect jdbc type mapping") { + val derbyDialect = JdbcDialects.get("jdbc:derby:db") + assert(derbyDialect.getJDBCType(StringType).map(_.databaseTypeDefinition).get == "CLOB") + assert(derbyDialect.getJDBCType(ByteType).map(_.databaseTypeDefinition).get == "SMALLINT") + assert(derbyDialect.getJDBCType(BooleanType).map(_.databaseTypeDefinition).get == "BOOLEAN") + } + test("table exists query by jdbc dialect") { val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") val db2 = JdbcDialects.get("jdbc:db2://127.0.0.1/db") val h2 = JdbcDialects.get(url) + val derby = JdbcDialects.get("jdbc:derby:db") val table = "weblogs" val defaultQuery = s"SELECT * FROM $table WHERE 1=0" val limitQuery = s"SELECT 1 FROM $table LIMIT 1" @@ -471,5 +482,6 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext assert(Postgres.getTableExistsQuery(table) == limitQuery) assert(db2.getTableExistsQuery(table) == defaultQuery) assert(h2.getTableExistsQuery(table) == defaultQuery) + assert(derby.getTableExistsQuery(table) == defaultQuery) } } From 63c340a710b24869410d56602b712fbfe443e6f0 Mon Sep 17 00:00:00 2001 From: Tom Graves Date: Fri, 9 Oct 2015 14:06:25 -0700 Subject: [PATCH 017/139] [SPARK-10858] YARN: archives/jar/files rename with # doesn't work unl https://issues.apache.org/jira/browse/SPARK-10858 The issue here is that in resolveURI we default to calling new File(path).getAbsoluteFile().toURI(). But if the path passed in already has a # in it then File(path) will think that is supposed to be part of the actual file path and not a fragment so it changes # to %23. Then when we try to parse that later in Client as a URI it doesn't recognize there is a fragment. so to fix we just check if there is a fragment, still create the File like we did before and then add the fragment back on. Author: Tom Graves Closes #9035 from tgravescs/SPARK-10858. --- core/src/main/scala/org/apache/spark/util/Utils.scala | 7 +++++++ core/src/test/scala/org/apache/spark/util/UtilsSuite.scala | 6 +++--- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 2bab4af2e73ab..e60c1b355a73e 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1749,6 +1749,13 @@ private[spark] object Utils extends Logging { if (uri.getScheme() != null) { return uri } + // make sure to handle if the path has a fragment (applies to yarn + // distributed cache) + if (uri.getFragment() != null) { + val absoluteURI = new File(uri.getPath()).getAbsoluteFile().toURI() + return new URI(absoluteURI.getScheme(), absoluteURI.getHost(), absoluteURI.getPath(), + uri.getFragment()) + } } catch { case e: URISyntaxException => } diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 1fb81ad565b41..68b0da76bc134 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -384,7 +384,7 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { assertResolves("hdfs:/root/spark.jar", "hdfs:/root/spark.jar") assertResolves("hdfs:///root/spark.jar#app.jar", "hdfs:/root/spark.jar#app.jar") assertResolves("spark.jar", s"file:$cwd/spark.jar") - assertResolves("spark.jar#app.jar", s"file:$cwd/spark.jar%23app.jar") + assertResolves("spark.jar#app.jar", s"file:$cwd/spark.jar#app.jar") assertResolves("path to/file.txt", s"file:$cwd/path%20to/file.txt") if (Utils.isWindows) { assertResolves("C:\\path\\to\\file.txt", "file:/C:/path/to/file.txt") @@ -414,10 +414,10 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { assertResolves("file:/jar1,file:/jar2", "file:/jar1,file:/jar2") assertResolves("hdfs:/jar1,file:/jar2,jar3", s"hdfs:/jar1,file:/jar2,file:$cwd/jar3") assertResolves("hdfs:/jar1,file:/jar2,jar3,jar4#jar5,path to/jar6", - s"hdfs:/jar1,file:/jar2,file:$cwd/jar3,file:$cwd/jar4%23jar5,file:$cwd/path%20to/jar6") + s"hdfs:/jar1,file:/jar2,file:$cwd/jar3,file:$cwd/jar4#jar5,file:$cwd/path%20to/jar6") if (Utils.isWindows) { assertResolves("""hdfs:/jar1,file:/jar2,jar3,C:\pi.py#py.pi,C:\path to\jar4""", - s"hdfs:/jar1,file:/jar2,file:$cwd/jar3,file:/C:/pi.py%23py.pi,file:/C:/path%20to/jar4") + s"hdfs:/jar1,file:/jar2,file:$cwd/jar3,file:/C:/pi.py#py.pi,file:/C:/path%20to/jar4") } } From c1b4ce43264fa8b9945df3c599a51d4d2a675705 Mon Sep 17 00:00:00 2001 From: Vladimir Vladimirov Date: Fri, 9 Oct 2015 14:16:13 -0700 Subject: [PATCH 018/139] [SPARK-10535] Sync up API for matrix factorization model between Scala and PySpark Support for recommendUsersForProducts and recommendProductsForUsers in matrix factorization model for PySpark Author: Vladimir Vladimirov Closes #8700 from smartkiwi/SPARK-10535_. --- .../MatrixFactorizationModelWrapper.scala | 8 +++++ python/pyspark/mllib/recommendation.py | 32 ++++++++++++++++--- 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala index 534edac56bc5a..eeb7cba882ce2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala @@ -42,4 +42,12 @@ private[python] class MatrixFactorizationModelWrapper(model: MatrixFactorization case (product, feature) => (product, Vectors.dense(feature)) }.asInstanceOf[RDD[(Any, Any)]]) } + + def wrappedRecommendProductsForUsers(num: Int): RDD[Array[Any]] = { + SerDe.fromTuple2RDD(recommendProductsForUsers(num).asInstanceOf[RDD[(Any, Any)]]) + } + + def wrappedRecommendUsersForProducts(num: Int): RDD[Array[Any]] = { + SerDe.fromTuple2RDD(recommendUsersForProducts(num).asInstanceOf[RDD[(Any, Any)]]) + } } diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 95047b5b7b4b7..b9442b0d16c0f 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -76,16 +76,28 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader): >>> first_user = model.userFeatures().take(1)[0] >>> latents = first_user[1] - >>> len(latents) == 4 - True + >>> len(latents) + 4 >>> model.productFeatures().collect() [(1, array('d', [...])), (2, array('d', [...]))] >>> first_product = model.productFeatures().take(1)[0] >>> latents = first_product[1] - >>> len(latents) == 4 - True + >>> len(latents) + 4 + + >>> products_for_users = model.recommendProductsForUsers(1).collect() + >>> len(products_for_users) + 2 + >>> products_for_users[0] + (1, (Rating(user=1, product=2, rating=...),)) + + >>> users_for_products = model.recommendUsersForProducts(1).collect() + >>> len(users_for_products) + 2 + >>> users_for_products[0] + (1, (Rating(user=2, product=1, rating=...),)) >>> model = ALS.train(ratings, 1, nonnegative=True, seed=10) >>> model.predict(2, 2) @@ -166,6 +178,18 @@ def recommendProducts(self, user, num): """ return list(self.call("recommendProducts", user, num)) + def recommendProductsForUsers(self, num): + """ + Recommends top "num" products for all users. The number returned may be less than this. + """ + return self.call("wrappedRecommendProductsForUsers", num) + + def recommendUsersForProducts(self, num): + """ + Recommends top "num" users for all products. The number returned may be less than this. + """ + return self.call("wrappedRecommendUsersForProducts", num) + @property @since("1.4.0") def rank(self): From 864de3bf4041c829e95d278b9569e91448bab0cc Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Fri, 9 Oct 2015 23:05:38 -0700 Subject: [PATCH 019/139] [SPARK-10079] [SPARKR] Make 'column' and 'col' functions be S4 functions. 1. Add a "col" function into DataFrame. 2. Move the current "col" function in Column.R to functions.R, convert it to S4 function. 3. Add a s4 "column" function in functions.R. 4. Convert the "column" function in Column.R to S4 function. This is for private use. Author: Sun Rui Closes #8864 from sun-rui/SPARK-10079. --- R/pkg/NAMESPACE | 1 + R/pkg/R/column.R | 12 +++++------- R/pkg/R/functions.R | 22 ++++++++++++++++++++++ R/pkg/R/generics.R | 4 ++++ R/pkg/inst/tests/test_sparkSQL.R | 4 ++-- 5 files changed, 34 insertions(+), 9 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 255be2e76ff49..95d949ee3e5a4 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -107,6 +107,7 @@ exportMethods("%in%", "cbrt", "ceil", "ceiling", + "column", "concat", "concat_ws", "contains", diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 42e9d12179db7..20de3907b7dd9 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -36,13 +36,11 @@ setMethod("initialize", "Column", function(.Object, jc) { .Object }) -column <- function(jc) { - new("Column", jc) -} - -col <- function(x) { - column(callJStatic("org.apache.spark.sql.functions", "col", x)) -} +setMethod("column", + signature(x = "jobj"), + function(x) { + new("Column", x) + }) #' @rdname show #' @name show diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 94687edb05442..a220ad8b9f58b 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -233,6 +233,28 @@ setMethod("ceil", column(jc) }) +#' Though scala functions has "col" function, we don't expose it in SparkR +#' because we don't want to conflict with the "col" function in the R base +#' package and we also have "column" function exported which is an alias of "col". +col <- function(x) { + column(callJStatic("org.apache.spark.sql.functions", "col", x)) +} + +#' column +#' +#' Returns a Column based on the given column name. +#' +#' @rdname col +#' @name column +#' @family normal_funcs +#' @export +#' @examples \dontrun{column(df)} +setMethod("column", + signature(x = "character"), + function(x) { + col(x) + }) + #' cos #' #' Computes the cosine of the given value. diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index c4474131804bb..8fad17026c06f 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -686,6 +686,10 @@ setGeneric("cbrt", function(x) { standardGeneric("cbrt") }) #' @export setGeneric("ceil", function(x) { standardGeneric("ceil") }) +#' @rdname col +#' @export +setGeneric("column", function(x) { standardGeneric("column") }) + #' @rdname concat #' @export setGeneric("concat", function(x, ...) { standardGeneric("concat") }) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 4804ecf177341..3a04edbb4c116 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -787,7 +787,7 @@ test_that("test HiveContext", { }) test_that("column operators", { - c <- SparkR:::col("a") + c <- column("a") c2 <- (- c + 1 - 2) * 3 / 4.0 c3 <- (c + c2 - c2) * c2 %% c2 c4 <- (c > c2) & (c2 <= c3) | (c == c2) & (c2 != c3) @@ -795,7 +795,7 @@ test_that("column operators", { }) test_that("column functions", { - c <- SparkR:::col("a") + c <- column("a") c1 <- abs(c) + acos(c) + approxCountDistinct(c) + ascii(c) + asin(c) + atan(c) c2 <- avg(c) + base64(c) + bin(c) + bitwiseNOT(c) + cbrt(c) + ceil(c) + cos(c) c3 <- cosh(c) + count(c) + crc32(c) + exp(c) From a16396df76cc27099011bfb96b28cbdd7f964ca8 Mon Sep 17 00:00:00 2001 From: Jacker Hu Date: Sat, 10 Oct 2015 11:36:18 +0100 Subject: [PATCH 020/139] [SPARK-10772] [STREAMING] [SCALA] NullPointerException when transform function in DStream returns NULL Currently, the ```TransformedDStream``` will using ```Some(transformFunc(parentRDDs, validTime))``` as compute return value, when the ```transformFunc``` somehow returns null as return value, the followed operator will have NullPointerExeception. This fix uses the ```Option()``` instead of ```Some()``` to deal with the possible null value. When ```transformFunc``` returns ```null```, the option will transform null to ```None```, the downstream can handle ```None``` correctly. NOTE (2015-09-25): The latest fix will check the return value of transform function, if it is ```NULL```, a spark exception will be thrown out Author: Jacker Hu Author: jhu-chang Closes #8881 from jhu-chang/Fix_Transform. --- .../streaming/dstream/TransformedDStream.scala | 12 ++++++++++-- .../spark/streaming/BasicOperationsSuite.scala | 13 +++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala index 5d46ca0715ffd..ab01f47d5cf99 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala @@ -17,9 +17,11 @@ package org.apache.spark.streaming.dstream +import scala.reflect.ClassTag + +import org.apache.spark.SparkException import org.apache.spark.rdd.{PairRDDFunctions, RDD} import org.apache.spark.streaming.{Duration, Time} -import scala.reflect.ClassTag private[streaming] class TransformedDStream[U: ClassTag] ( @@ -38,6 +40,12 @@ class TransformedDStream[U: ClassTag] ( override def compute(validTime: Time): Option[RDD[U]] = { val parentRDDs = parents.map(_.getOrCompute(validTime).orNull).toSeq - Some(transformFunc(parentRDDs, validTime)) + val transformedRDD = transformFunc(parentRDDs, validTime) + if (transformedRDD == null) { + throw new SparkException("Transform function must not return null. " + + "Return SparkContext.emptyRDD() instead to represent no element " + + "as the result of transformation.") + } + Some(transformedRDD) } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index 255376807c957..9988f410f0bc1 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -211,6 +211,19 @@ class BasicOperationsSuite extends TestSuiteBase { ) } + test("transform with NULL") { + val input = Seq(1 to 4) + intercept[SparkException] { + testOperation( + input, + (r: DStream[Int]) => r.transform(rdd => null.asInstanceOf[RDD[Int]]), + Seq(Seq()), + 1, + false + ) + } + } + test("transformWith") { val inputData1 = Seq( Seq("a", "b"), Seq("a", ""), Seq(""), Seq() ) val inputData2 = Seq( Seq("a", "b"), Seq("b", ""), Seq(), Seq("") ) From 595012ea8b9c6afcc2fc024d5a5e198df765bd75 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 11 Oct 2015 18:11:08 -0700 Subject: [PATCH 021/139] [SPARK-11053] Remove use of KVIterator in SortBasedAggregationIterator SortBasedAggregationIterator uses a KVIterator interface in order to process input rows as key-value pairs, but this use of KVIterator is unnecessary, slightly complicates the code, and might hurt performance. This patch refactors this code to remove the use of this extra layer of iterator wrapping and simplifies other parts of the code in the process. Author: Josh Rosen Closes #9066 from JoshRosen/sort-iterator-cleanup. --- .../aggregate/AggregationIterator.scala | 83 ----------------- .../aggregate/SortBasedAggregate.scala | 20 +++-- .../SortBasedAggregationIterator.scala | 89 +++++-------------- 3 files changed, 33 insertions(+), 159 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala index 5f7341e88c7c9..8e0fbd109b413 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala @@ -21,7 +21,6 @@ import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.unsafe.KVIterator import scala.collection.mutable.ArrayBuffer @@ -412,85 +411,3 @@ abstract class AggregationIterator( */ protected def newBuffer: MutableRow } - -object AggregationIterator { - def kvIterator( - groupingExpressions: Seq[NamedExpression], - newProjection: (Seq[Expression], Seq[Attribute]) => Projection, - inputAttributes: Seq[Attribute], - inputIter: Iterator[InternalRow]): KVIterator[InternalRow, InternalRow] = { - new KVIterator[InternalRow, InternalRow] { - private[this] val groupingKeyGenerator = newProjection(groupingExpressions, inputAttributes) - - private[this] var groupingKey: InternalRow = _ - - private[this] var value: InternalRow = _ - - override def next(): Boolean = { - if (inputIter.hasNext) { - // Read the next input row. - val inputRow = inputIter.next() - // Get groupingKey based on groupingExpressions. - groupingKey = groupingKeyGenerator(inputRow) - // The value is the inputRow. - value = inputRow - true - } else { - false - } - } - - override def getKey(): InternalRow = { - groupingKey - } - - override def getValue(): InternalRow = { - value - } - - override def close(): Unit = { - // Do nothing - } - } - } - - def unsafeKVIterator( - groupingExpressions: Seq[NamedExpression], - inputAttributes: Seq[Attribute], - inputIter: Iterator[InternalRow]): KVIterator[UnsafeRow, InternalRow] = { - new KVIterator[UnsafeRow, InternalRow] { - private[this] val groupingKeyGenerator = - UnsafeProjection.create(groupingExpressions, inputAttributes) - - private[this] var groupingKey: UnsafeRow = _ - - private[this] var value: InternalRow = _ - - override def next(): Boolean = { - if (inputIter.hasNext) { - // Read the next input row. - val inputRow = inputIter.next() - // Get groupingKey based on groupingExpressions. - groupingKey = groupingKeyGenerator.apply(inputRow) - // The value is the inputRow. - value = inputRow - true - } else { - false - } - } - - override def getKey(): UnsafeRow = { - groupingKey - } - - override def getValue(): InternalRow = { - value - } - - override def close(): Unit = { - // Do nothing - } - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala index f4c14a9b3556f..4d37106e007f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala @@ -23,9 +23,8 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution} -import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan, UnaryNode} +import org.apache.spark.sql.execution.{SparkPlan, UnaryNode} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.types.StructType case class SortBasedAggregate( requiredChildDistributionExpressions: Option[Seq[Expression]], @@ -79,18 +78,23 @@ case class SortBasedAggregate( // so return an empty iterator. Iterator[InternalRow]() } else { - val outputIter = SortBasedAggregationIterator.createFromInputIterator( - groupingExpressions, + val groupingKeyProjection = if (UnsafeProjection.canSupport(groupingExpressions)) { + UnsafeProjection.create(groupingExpressions, child.output) + } else { + newMutableProjection(groupingExpressions, child.output)() + } + val outputIter = new SortBasedAggregationIterator( + groupingKeyProjection, + groupingExpressions.map(_.toAttribute), + child.output, + iter, nonCompleteAggregateExpressions, nonCompleteAggregateAttributes, completeAggregateExpressions, completeAggregateAttributes, initialInputBufferOffset, resultExpressions, - newMutableProjection _, - newProjection _, - child.output, - iter, + newMutableProjection, outputsUnsafeRows, numInputRows, numOutputRows) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index a9e5d175bf895..64c673064f576 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -21,16 +21,16 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2, AggregateFunction2} import org.apache.spark.sql.execution.metric.LongSQLMetric -import org.apache.spark.unsafe.KVIterator /** * An iterator used to evaluate [[AggregateFunction2]]. It assumes the input rows have been * sorted by values of [[groupingKeyAttributes]]. */ class SortBasedAggregationIterator( + groupingKeyProjection: InternalRow => InternalRow, groupingKeyAttributes: Seq[Attribute], valueAttributes: Seq[Attribute], - inputKVIterator: KVIterator[InternalRow, InternalRow], + inputIterator: Iterator[InternalRow], nonCompleteAggregateExpressions: Seq[AggregateExpression2], nonCompleteAggregateAttributes: Seq[Attribute], completeAggregateExpressions: Seq[AggregateExpression2], @@ -90,6 +90,22 @@ class SortBasedAggregationIterator( // The aggregation buffer used by the sort-based aggregation. private[this] val sortBasedAggregationBuffer: MutableRow = newBuffer + protected def initialize(): Unit = { + if (inputIterator.hasNext) { + initializeBuffer(sortBasedAggregationBuffer) + val inputRow = inputIterator.next() + nextGroupingKey = groupingKeyProjection(inputRow).copy() + firstRowInNextGroup = inputRow.copy() + numInputRows += 1 + sortedInputHasNewGroup = true + } else { + // This inputIter is empty. + sortedInputHasNewGroup = false + } + } + + initialize() + /** Processes rows in the current group. It will stop when it find a new group. */ protected def processCurrentSortedGroup(): Unit = { currentGroupingKey = nextGroupingKey @@ -101,18 +117,15 @@ class SortBasedAggregationIterator( // The search will stop when we see the next group or there is no // input row left in the iter. - var hasNext = inputKVIterator.next() - while (!findNextPartition && hasNext) { + while (!findNextPartition && inputIterator.hasNext) { // Get the grouping key. - val groupingKey = inputKVIterator.getKey - val currentRow = inputKVIterator.getValue + val currentRow = inputIterator.next() + val groupingKey = groupingKeyProjection(currentRow) numInputRows += 1 // Check if the current row belongs the current input row. if (currentGroupingKey == groupingKey) { processRow(sortBasedAggregationBuffer, currentRow) - - hasNext = inputKVIterator.next() } else { // We find a new group. findNextPartition = true @@ -149,68 +162,8 @@ class SortBasedAggregationIterator( } } - protected def initialize(): Unit = { - if (inputKVIterator.next()) { - initializeBuffer(sortBasedAggregationBuffer) - - nextGroupingKey = inputKVIterator.getKey().copy() - firstRowInNextGroup = inputKVIterator.getValue().copy() - numInputRows += 1 - sortedInputHasNewGroup = true - } else { - // This inputIter is empty. - sortedInputHasNewGroup = false - } - } - - initialize() - def outputForEmptyGroupingKeyWithoutInput(): InternalRow = { initializeBuffer(sortBasedAggregationBuffer) generateOutput(new GenericInternalRow(0), sortBasedAggregationBuffer) } } - -object SortBasedAggregationIterator { - // scalastyle:off - def createFromInputIterator( - groupingExprs: Seq[NamedExpression], - nonCompleteAggregateExpressions: Seq[AggregateExpression2], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], - completeAggregateAttributes: Seq[Attribute], - initialInputBufferOffset: Int, - resultExpressions: Seq[NamedExpression], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - newProjection: (Seq[Expression], Seq[Attribute]) => Projection, - inputAttributes: Seq[Attribute], - inputIter: Iterator[InternalRow], - outputsUnsafeRows: Boolean, - numInputRows: LongSQLMetric, - numOutputRows: LongSQLMetric): SortBasedAggregationIterator = { - val kvIterator = if (UnsafeProjection.canSupport(groupingExprs)) { - AggregationIterator.unsafeKVIterator( - groupingExprs, - inputAttributes, - inputIter).asInstanceOf[KVIterator[InternalRow, InternalRow]] - } else { - AggregationIterator.kvIterator(groupingExprs, newProjection, inputAttributes, inputIter) - } - - new SortBasedAggregationIterator( - groupingExprs.map(_.toAttribute), - inputAttributes, - kvIterator, - nonCompleteAggregateExpressions, - nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, - initialInputBufferOffset, - resultExpressions, - newMutableProjection, - outputsUnsafeRows, - numInputRows, - numOutputRows) - } - // scalastyle:on -} From fcb37a04177edc2376e39dd0b910f0268f7c72ec Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 12 Oct 2015 09:16:14 -0700 Subject: [PATCH 022/139] [SPARK-10960] [SQL] SQL with windowing function should be able to refer column in inner select JIRA: https://issues.apache.org/jira/browse/SPARK-10960 When accessing a column in inner select from a select with window function, `AnalysisException` will be thrown. For example, an query like this: select area, rank() over (partition by area order by tmp.month) + tmp.tmp1 as c1 from (select month, area, product, 1 as tmp1 from windowData) tmp Currently, the rule `ExtractWindowExpressions` in `Analyzer` only extracts regular expressions from `WindowFunction`, `WindowSpecDefinition` and `AggregateExpression`. We need to also extract other attributes as the one in `Alias` as shown in the above query. Author: Liang-Chi Hsieh Closes #9011 from viirya/fix-window-inner-column. --- .../sql/catalyst/analysis/Analyzer.scala | 4 +++ .../sql/hive/execution/SQLQuerySuite.scala | 27 +++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index bf72d47ce1ea6..f5597a08d3595 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -831,6 +831,10 @@ class Analyzer( val withName = Alias(agg, s"_w${extractedExprBuffer.length}")() extractedExprBuffer += withName withName.toAttribute + + // Extracts other attributes + case attr: Attribute => extractExpr(attr) + }.asInstanceOf[NamedExpression] } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index ccc15eaa63f42..51b63f3688783 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -838,6 +838,33 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { ).map(i => Row(i._1, i._2, i._3))) } + test("window function: refer column in inner select block") { + val data = Seq( + WindowData(1, "a", 5), + WindowData(2, "a", 6), + WindowData(3, "b", 7), + WindowData(4, "b", 8), + WindowData(5, "c", 9), + WindowData(6, "c", 10) + ) + sparkContext.parallelize(data).toDF().registerTempTable("windowData") + + checkAnswer( + sql( + """ + |select area, rank() over (partition by area order by tmp.month) + tmp.tmp1 as c1 + |from (select month, area, product, 1 as tmp1 from windowData) tmp + """.stripMargin), + Seq( + ("a", 2), + ("a", 3), + ("b", 2), + ("b", 3), + ("c", 2), + ("c", 3) + ).map(i => Row(i._1, i._2))) + } + test("window function: partition and order expressions") { val data = Seq( WindowData(1, "a", 5), From 64b1d00e1a7c1dc52c08a5e97baf6e7117f1a94f Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Mon, 12 Oct 2015 10:17:19 -0700 Subject: [PATCH 023/139] [SPARK-11007] [SQL] Adds dictionary aware Parquet decimal converters For Parquet decimal columns that are encoded using plain-dictionary encoding, we can make the upper level converter aware of the dictionary, so that we can pre-instantiate all the decimals to avoid duplicated instantiation. Note that plain-dictionary encoding isn't available for `FIXED_LEN_BYTE_ARRAY` for Parquet writer version `PARQUET_1_0`. So currently only decimals written as `INT32` and `INT64` can benefit from this optimization. Author: Cheng Lian Closes #9040 from liancheng/spark-11007.decimal-converter-dict-support. --- .../parquet/CatalystRowConverter.scala | 83 +++++++++++++++--- .../src/test/resources/dec-in-i32.parquet | Bin 0 -> 420 bytes .../src/test/resources/dec-in-i64.parquet | Bin 0 -> 437 bytes .../datasources/parquet/ParquetIOSuite.scala | 19 ++++ .../ParquetProtobufCompatibilitySuite.scala | 22 ++--- .../datasources/parquet/ParquetTest.scala | 5 ++ 6 files changed, 103 insertions(+), 26 deletions(-) create mode 100755 sql/core/src/test/resources/dec-in-i32.parquet create mode 100755 sql/core/src/test/resources/dec-in-i64.parquet diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala index 247d35363b862..49007e45ecf87 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala @@ -26,7 +26,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.parquet.column.Dictionary import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveConverter} import org.apache.parquet.schema.OriginalType.{INT_32, LIST, UTF8} -import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.DOUBLE +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.{DOUBLE, INT32, INT64, BINARY, FIXED_LEN_BYTE_ARRAY} import org.apache.parquet.schema.{GroupType, MessageType, PrimitiveType, Type} import org.apache.spark.Logging @@ -222,8 +222,25 @@ private[parquet] class CatalystRowConverter( updater.setShort(value.asInstanceOf[ShortType#InternalType]) } + // For INT32 backed decimals + case t: DecimalType if parquetType.asPrimitiveType().getPrimitiveTypeName == INT32 => + new CatalystIntDictionaryAwareDecimalConverter(t.precision, t.scale, updater) + + // For INT64 backed decimals + case t: DecimalType if parquetType.asPrimitiveType().getPrimitiveTypeName == INT64 => + new CatalystLongDictionaryAwareDecimalConverter(t.precision, t.scale, updater) + + // For BINARY and FIXED_LEN_BYTE_ARRAY backed decimals + case t: DecimalType + if parquetType.asPrimitiveType().getPrimitiveTypeName == FIXED_LEN_BYTE_ARRAY || + parquetType.asPrimitiveType().getPrimitiveTypeName == BINARY => + new CatalystBinaryDictionaryAwareDecimalConverter(t.precision, t.scale, updater) + case t: DecimalType => - new CatalystDecimalConverter(t, updater) + throw new RuntimeException( + s"Unable to create Parquet converter for decimal type ${t.json} whose Parquet type is " + + s"$parquetType. Parquet DECIMAL type can only be backed by INT32, INT64, " + + "FIXED_LEN_BYTE_ARRAY, or BINARY.") case StringType => new CatalystStringConverter(updater) @@ -274,9 +291,10 @@ private[parquet] class CatalystRowConverter( override def set(value: Any): Unit = updater.set(value.asInstanceOf[InternalRow].copy()) }) - case _ => + case t => throw new RuntimeException( - s"Unable to create Parquet converter for data type ${catalystType.json}") + s"Unable to create Parquet converter for data type ${t.json} " + + s"whose Parquet type is $parquetType") } } @@ -314,11 +332,18 @@ private[parquet] class CatalystRowConverter( /** * Parquet converter for fixed-precision decimals. */ - private final class CatalystDecimalConverter( - decimalType: DecimalType, - updater: ParentContainerUpdater) + private abstract class CatalystDecimalConverter( + precision: Int, scale: Int, updater: ParentContainerUpdater) extends CatalystPrimitiveConverter(updater) { + protected var expandedDictionary: Array[Decimal] = _ + + override def hasDictionarySupport: Boolean = true + + override def addValueFromDictionary(dictionaryId: Int): Unit = { + updater.set(expandedDictionary(dictionaryId)) + } + // Converts decimals stored as INT32 override def addInt(value: Int): Unit = { addLong(value: Long) @@ -326,18 +351,19 @@ private[parquet] class CatalystRowConverter( // Converts decimals stored as INT64 override def addLong(value: Long): Unit = { - updater.set(Decimal(value, decimalType.precision, decimalType.scale)) + updater.set(decimalFromLong(value)) } // Converts decimals stored as either FIXED_LENGTH_BYTE_ARRAY or BINARY override def addBinary(value: Binary): Unit = { - updater.set(toDecimal(value)) + updater.set(decimalFromBinary(value)) } - private def toDecimal(value: Binary): Decimal = { - val precision = decimalType.precision - val scale = decimalType.scale + protected def decimalFromLong(value: Long): Decimal = { + Decimal(value, precision, scale) + } + protected def decimalFromBinary(value: Binary): Decimal = { if (precision <= CatalystSchemaConverter.MAX_PRECISION_FOR_INT64) { // Constructs a `Decimal` with an unscaled `Long` value if possible. val unscaled = binaryToUnscaledLong(value) @@ -371,6 +397,39 @@ private[parquet] class CatalystRowConverter( } } + private class CatalystIntDictionaryAwareDecimalConverter( + precision: Int, scale: Int, updater: ParentContainerUpdater) + extends CatalystDecimalConverter(precision, scale, updater) { + + override def setDictionary(dictionary: Dictionary): Unit = { + this.expandedDictionary = Array.tabulate(dictionary.getMaxId + 1) { id => + decimalFromLong(dictionary.decodeToInt(id).toLong) + } + } + } + + private class CatalystLongDictionaryAwareDecimalConverter( + precision: Int, scale: Int, updater: ParentContainerUpdater) + extends CatalystDecimalConverter(precision, scale, updater) { + + override def setDictionary(dictionary: Dictionary): Unit = { + this.expandedDictionary = Array.tabulate(dictionary.getMaxId + 1) { id => + decimalFromLong(dictionary.decodeToLong(id)) + } + } + } + + private class CatalystBinaryDictionaryAwareDecimalConverter( + precision: Int, scale: Int, updater: ParentContainerUpdater) + extends CatalystDecimalConverter(precision, scale, updater) { + + override def setDictionary(dictionary: Dictionary): Unit = { + this.expandedDictionary = Array.tabulate(dictionary.getMaxId + 1) { id => + decimalFromBinary(dictionary.decodeToBinary(id)) + } + } + } + /** * Parquet converter for arrays. Spark SQL arrays are represented as Parquet lists. Standard * Parquet lists are represented as a 3-level group annotated by `LIST`: diff --git a/sql/core/src/test/resources/dec-in-i32.parquet b/sql/core/src/test/resources/dec-in-i32.parquet new file mode 100755 index 0000000000000000000000000000000000000000..bb5d4af8dd36817bfb2cc16746417f0f2cbec759 GIT binary patch literal 420 zcmWG=3^EjD5e*Pc^AQyhWno~D@8)2DfaHXP1P{hX!VYT=pEzK^*s-6XkVTmJu$)3z zLRvxwV-mx>L&=vkfNDhP{zKtf8Q zr~#MeY(^CZr+?eF2>?}yGD+%q@Dvv$7G=j5CugMQCWub(penut~xdh_Z+&h@D{+YhzO5u)%NwNJfD{Qbs~EzbIWVu^ zi5}QKz2d?gJ)p&frKu%)Mfv4=xv3?IDTyVC63Nv{C6xuKN>)n6B}JvlB}zI=l}^8Q8rNy83~RSW{3$AFryg6KmtfcCnY2VB%~yY z8gOaOW>jHt`nPSH0LUmNNgWTK;)2AY?D*p3jMUsjQ6>ga7F8w*_DnOA_>|OSRW6_{ zA`D^*k}{GqY8*16ERv=y9Bh(s1)?ls3S#S+#HKN+aoFH=3P^*c1FB&H;mBub=IE0t6hq$*h{6_*s1CYLDbD5Yhl z=A;xWSw&YX hadoopConfiguration.set(entry.getKey, entry.getValue)) } } + + test("read dictionary encoded decimals written as INT32") { + checkAnswer( + // Decimal column in this file is encoded using plain dictionary + readResourceParquetFile("dec-in-i32.parquet"), + sqlContext.range(1 << 4).select('id % 10 cast DecimalType(5, 2) as 'i32_dec)) + } + + test("read dictionary encoded decimals written as INT64") { + checkAnswer( + // Decimal column in this file is encoded using plain dictionary + readResourceParquetFile("dec-in-i64.parquet"), + sqlContext.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'i64_dec)) + } + + // TODO Adds test case for reading dictionary encoded decimals written as `FIXED_LEN_BYTE_ARRAY` + // The Parquet writer version Spark 1.6 and prior versions use is `PARQUET_1_0`, which doesn't + // provide dictionary encoding support for `FIXED_LEN_BYTE_ARRAY`. Should add a test here once + // we upgrade to `PARQUET_2_0`. } class JobCommitFailureParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala index b290429c2a021..98333e58cada8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala @@ -17,23 +17,17 @@ package org.apache.spark.sql.execution.datasources.parquet -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.Row import org.apache.spark.sql.test.SharedSQLContext class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext { - - private def readParquetProtobufFile(name: String): DataFrame = { - val url = Thread.currentThread().getContextClassLoader.getResource(name) - sqlContext.read.parquet(url.toString) - } - test("unannotated array of primitive type") { - checkAnswer(readParquetProtobufFile("old-repeated-int.parquet"), Row(Seq(1, 2, 3))) + checkAnswer(readResourceParquetFile("old-repeated-int.parquet"), Row(Seq(1, 2, 3))) } test("unannotated array of struct") { checkAnswer( - readParquetProtobufFile("old-repeated-message.parquet"), + readResourceParquetFile("old-repeated-message.parquet"), Row( Seq( Row("First inner", null, null), @@ -41,14 +35,14 @@ class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest with Sh Row(null, null, "Third inner")))) checkAnswer( - readParquetProtobufFile("proto-repeated-struct.parquet"), + readResourceParquetFile("proto-repeated-struct.parquet"), Row( Seq( Row("0 - 1", "0 - 2", "0 - 3"), Row("1 - 1", "1 - 2", "1 - 3")))) checkAnswer( - readParquetProtobufFile("proto-struct-with-array-many.parquet"), + readResourceParquetFile("proto-struct-with-array-many.parquet"), Seq( Row( Seq( @@ -66,13 +60,13 @@ class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest with Sh test("struct with unannotated array") { checkAnswer( - readParquetProtobufFile("proto-struct-with-array.parquet"), + readResourceParquetFile("proto-struct-with-array.parquet"), Row(10, 9, Seq.empty, null, Row(9), Seq(Row(9), Row(10)))) } test("unannotated array of struct with unannotated array") { checkAnswer( - readParquetProtobufFile("nested-array-struct.parquet"), + readResourceParquetFile("nested-array-struct.parquet"), Seq( Row(2, Seq(Row(1, Seq(Row(3))))), Row(5, Seq(Row(4, Seq(Row(6))))), @@ -81,7 +75,7 @@ class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest with Sh test("unannotated array of string") { checkAnswer( - readParquetProtobufFile("proto-repeated-string.parquet"), + readResourceParquetFile("proto-repeated-string.parquet"), Seq( Row(Seq("hello", "world")), Row(Seq("good", "bye")), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala index 9840ad919e510..8ffb01fc5b584 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala @@ -139,4 +139,9 @@ private[sql] trait ParquetTest extends SQLTestUtils { withSQLConf(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "true") { f } } } + + protected def readResourceParquetFile(name: String): DataFrame = { + val url = Thread.currentThread().getContextClassLoader.getResource(name) + sqlContext.read.parquet(url.toString) + } } From 149472a01d12828c64b0a852982d48c123984182 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 12 Oct 2015 10:21:57 -0700 Subject: [PATCH 024/139] [SPARK-11023] [YARN] Avoid creating URIs from local paths directly. The issue is that local paths on Windows, when provided with drive letters or backslashes, are not valid URIs. Instead of trying to figure out whether paths are URIs or not, use Utils.resolveURI() which does that for us. Author: Marcelo Vanzin Closes #9049 from vanzin/SPARK-11023 and squashes the following commits: 77021f2 [Marcelo Vanzin] [SPARK-11023] [yarn] Avoid creating URIs from local paths directly. --- .../scala/org/apache/spark/deploy/yarn/Client.scala | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index cec81b940644c..1fbd18aa466d4 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -358,7 +358,8 @@ private[spark] class Client( destName: Option[String] = None, targetDir: Option[String] = None, appMasterOnly: Boolean = false): (Boolean, String) = { - val localURI = new URI(path.trim()) + val trimmedPath = path.trim() + val localURI = Utils.resolveURI(trimmedPath) if (localURI.getScheme != LOCAL_SCHEME) { if (addDistributedUri(localURI)) { val localPath = getQualifiedLocalPath(localURI, hadoopConf) @@ -374,7 +375,7 @@ private[spark] class Client( (false, null) } } else { - (true, path.trim()) + (true, trimmedPath) } } @@ -595,10 +596,10 @@ private[spark] class Client( LOCALIZED_PYTHON_DIR) } (pySparkArchives ++ pyArchives).foreach { path => - val uri = new URI(path) + val uri = Utils.resolveURI(path) if (uri.getScheme != LOCAL_SCHEME) { pythonPath += buildPath(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), - new Path(path).getName()) + new Path(uri).getName()) } else { pythonPath += uri.getPath() } @@ -1229,7 +1230,7 @@ object Client extends Logging { private def getMainJarUri(mainJar: Option[String]): Option[URI] = { mainJar.flatMap { path => - val uri = new URI(path) + val uri = Utils.resolveURI(path) if (uri.getScheme == LOCAL_SCHEME) Some(uri) else None }.orElse(Some(new URI(APP_JAR))) } From 2e572c4135c3f5ad3061c1f58cdb8a70bed0a9d3 Mon Sep 17 00:00:00 2001 From: Ashwin Shankar Date: Mon, 12 Oct 2015 11:06:21 -0700 Subject: [PATCH 025/139] [SPARK-8170] [PYTHON] Add signal handler to trap Ctrl-C in pyspark and cancel all running jobs This patch adds a signal handler to trap Ctrl-C and cancels running job. Author: Ashwin Shankar Closes #9033 from ashwinshankar77/master. --- python/pyspark/context.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index a0a1ccbeefb09..4969d85f52b23 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -19,6 +19,7 @@ import os import shutil +import signal import sys from threading import Lock from tempfile import NamedTemporaryFile @@ -217,6 +218,12 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, else: self.profiler_collector = None + # create a signal handler which would be invoked on receiving SIGINT + def signal_handler(signal, frame): + self.cancelAllJobs() + + signal.signal(signal.SIGINT, signal_handler) + def _initialize_context(self, jconf): """ Initialize SparkContext in function to allow subclass specific initialization From 8a354bef55ce9cc0fa77fa1c3a9d62c16438ca1b Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 12 Oct 2015 13:50:34 -0700 Subject: [PATCH 026/139] [SPARK-11042] [SQL] Add a mechanism to ban creating multiple root SQLContexts/HiveContexts in a JVM https://issues.apache.org/jira/browse/SPARK-11042 Author: Yin Huai Closes #9058 from yhuai/SPARK-11042. --- .../scala/org/apache/spark/sql/SQLConf.scala | 10 ++ .../org/apache/spark/sql/SQLContext.scala | 42 +++++++- .../spark/sql/MultiSQLContextsSuite.scala | 99 +++++++++++++++++++ .../apache/spark/sql/hive/HiveContext.scala | 12 ++- 4 files changed, 156 insertions(+), 7 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 47397c4be3cb6..f62df9bdebcc0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -186,6 +186,16 @@ private[spark] object SQLConf { import SQLConfEntry._ + val ALLOW_MULTIPLE_CONTEXTS = booleanConf("spark.sql.allowMultipleContexts", + defaultValue = Some(true), + doc = "When set to true, creating multiple SQLContexts/HiveContexts is allowed." + + "When set to false, only one SQLContext/HiveContext is allowed to be created " + + "through the constructor (new SQLContexts/HiveContexts created through newSession " + + "method is allowed). Please note that this conf needs to be set in Spark Conf. Once" + + "a SQLContext/HiveContext has been created, changing the value of this conf will not" + + "have effect.", + isPublic = true) + val COMPRESS_CACHED = booleanConf("spark.sql.inMemoryColumnarStorage.compressed", defaultValue = Some(true), doc = "When set to true Spark SQL will automatically select a compression codec for each " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 2bdfd82af0adb..1bd291389241a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -26,7 +26,7 @@ import scala.collection.immutable import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal -import org.apache.spark.SparkContext +import org.apache.spark.{SparkException, SparkContext} import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.rdd.RDD @@ -64,14 +64,37 @@ import org.apache.spark.util.Utils */ class SQLContext private[sql]( @transient val sparkContext: SparkContext, - @transient protected[sql] val cacheManager: CacheManager) + @transient protected[sql] val cacheManager: CacheManager, + val isRootContext: Boolean) extends org.apache.spark.Logging with Serializable { self => - def this(sparkContext: SparkContext) = this(sparkContext, new CacheManager) + def this(sparkContext: SparkContext) = this(sparkContext, new CacheManager, true) def this(sparkContext: JavaSparkContext) = this(sparkContext.sc) + // If spark.sql.allowMultipleContexts is true, we will throw an exception if a user + // wants to create a new root SQLContext (a SLQContext that is not created by newSession). + private val allowMultipleContexts = + sparkContext.conf.getBoolean( + SQLConf.ALLOW_MULTIPLE_CONTEXTS.key, + SQLConf.ALLOW_MULTIPLE_CONTEXTS.defaultValue.get) + + // Assert no root SQLContext is running when allowMultipleContexts is false. + { + if (!allowMultipleContexts && isRootContext) { + SQLContext.getInstantiatedContextOption() match { + case Some(rootSQLContext) => + val errMsg = "Only one SQLContext/HiveContext may be running in this JVM. " + + s"It is recommended to use SQLContext.getOrCreate to get the instantiated " + + s"SQLContext/HiveContext. To ignore this error, " + + s"set ${SQLConf.ALLOW_MULTIPLE_CONTEXTS.key} = true in SparkConf." + throw new SparkException(errMsg) + case None => // OK + } + } + } + /** * Returns a SQLContext as new session, with separated SQL configurations, temporary tables, * registered functions, but sharing the same SparkContext and CacheManager. @@ -79,7 +102,10 @@ class SQLContext private[sql]( * @since 1.6.0 */ def newSession(): SQLContext = { - new SQLContext(sparkContext, cacheManager) + new SQLContext( + sparkContext = sparkContext, + cacheManager = cacheManager, + isRootContext = false) } /** @@ -1239,6 +1265,10 @@ object SQLContext { instantiatedContext.compareAndSet(null, sqlContext) } + private[sql] def getInstantiatedContextOption(): Option[SQLContext] = { + Option(instantiatedContext.get()) + } + /** * Changes the SQLContext that will be returned in this thread and its children when * SQLContext.getOrCreate() is called. This can be used to ensure that a given thread receives @@ -1260,6 +1290,10 @@ object SQLContext { activeContext.remove() } + private[sql] def getActiveContextOption(): Option[SQLContext] = { + Option(activeContext.get()) + } + /** * Converts an iterator of Java Beans to InternalRow using the provided * bean info & schema. This is not related to the singleton, but is a static diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala new file mode 100644 index 0000000000000..0e8fcb6a858b1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala @@ -0,0 +1,99 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import org.apache.spark._ +import org.scalatest.BeforeAndAfterAll + +class MultiSQLContextsSuite extends SparkFunSuite with BeforeAndAfterAll { + + private var originalActiveSQLContext: Option[SQLContext] = _ + private var originalInstantiatedSQLContext: Option[SQLContext] = _ + private var sparkConf: SparkConf = _ + + override protected def beforeAll(): Unit = { + originalActiveSQLContext = SQLContext.getActiveContextOption() + originalInstantiatedSQLContext = SQLContext.getInstantiatedContextOption() + + SQLContext.clearActive() + originalInstantiatedSQLContext.foreach(ctx => SQLContext.clearInstantiatedContext(ctx)) + sparkConf = + new SparkConf(false) + .setMaster("local[*]") + .setAppName("test") + .set("spark.ui.enabled", "false") + .set("spark.driver.allowMultipleContexts", "true") + } + + override protected def afterAll(): Unit = { + // Set these states back. + originalActiveSQLContext.foreach(ctx => SQLContext.setActive(ctx)) + originalInstantiatedSQLContext.foreach(ctx => SQLContext.setInstantiatedContext(ctx)) + } + + def testNewSession(rootSQLContext: SQLContext): Unit = { + // Make sure we can successfully create new Session. + rootSQLContext.newSession() + + // Reset the state. It is always safe to clear the active context. + SQLContext.clearActive() + } + + def testCreatingNewSQLContext(allowsMultipleContexts: Boolean): Unit = { + val conf = + sparkConf + .clone + .set(SQLConf.ALLOW_MULTIPLE_CONTEXTS.key, allowsMultipleContexts.toString) + val sparkContext = new SparkContext(conf) + + try { + if (allowsMultipleContexts) { + new SQLContext(sparkContext) + SQLContext.clearActive() + } else { + // If allowsMultipleContexts is false, make sure we can get the error. + val message = intercept[SparkException] { + new SQLContext(sparkContext) + }.getMessage + assert(message.contains("Only one SQLContext/HiveContext may be running")) + } + } finally { + sparkContext.stop() + } + } + + test("test the flag to disallow creating multiple root SQLContext") { + Seq(false, true).foreach { allowMultipleSQLContexts => + val conf = + sparkConf + .clone + .set(SQLConf.ALLOW_MULTIPLE_CONTEXTS.key, allowMultipleSQLContexts.toString) + val sc = new SparkContext(conf) + try { + val rootSQLContext = new SQLContext(sc) + testNewSession(rootSQLContext) + testNewSession(rootSQLContext) + testCreatingNewSQLContext(allowMultipleSQLContexts) + + SQLContext.clearInstantiatedContext(rootSQLContext) + } finally { + sc.stop() + } + } + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index dad1e2347c387..ddeadd3eb737d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -89,10 +89,11 @@ class HiveContext private[hive]( sc: SparkContext, cacheManager: CacheManager, @transient execHive: ClientWrapper, - @transient metaHive: ClientInterface) extends SQLContext(sc, cacheManager) with Logging { + @transient metaHive: ClientInterface, + isRootContext: Boolean) extends SQLContext(sc, cacheManager, isRootContext) with Logging { self => - def this(sc: SparkContext) = this(sc, new CacheManager, null, null) + def this(sc: SparkContext) = this(sc, new CacheManager, null, null, true) def this(sc: JavaSparkContext) = this(sc.sc) import org.apache.spark.sql.hive.HiveContext._ @@ -105,7 +106,12 @@ class HiveContext private[hive]( * and Hive client (both of execution and metadata) with existing HiveContext. */ override def newSession(): HiveContext = { - new HiveContext(sc, cacheManager, executionHive.newSession(), metadataHive.newSession()) + new HiveContext( + sc = sc, + cacheManager = cacheManager, + execHive = executionHive.newSession(), + metaHive = metadataHive.newSession(), + isRootContext = false) } /** From 091c2c3ecd69803d78c2b15a1487046701059d38 Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Mon, 12 Oct 2015 14:23:29 -0700 Subject: [PATCH 027/139] [SPARK-11056] Improve documentation of SBT build. This commit improves the documentation around building Spark to (1) recommend using SBT interactive mode to avoid the overhead of launching SBT and (2) refer to the wiki page that documents using SPARK_PREPEND_CLASSES to avoid creating the assembly jar for each compile. cc srowen Author: Kay Ousterhout Closes #9068 from kayousterhout/SPARK-11056. --- docs/building-spark.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/building-spark.md b/docs/building-spark.md index 4d929ee10a33f..743643cbcc62f 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -216,6 +216,11 @@ can be set to control the SBT build. For example: build/sbt -Pyarn -Phadoop-2.3 assembly +To avoid the overhead of launching sbt each time you need to re-compile, you can launch sbt +in interactive mode by running `build/sbt`, and then run all build commands at the command +prompt. For more recommendations on reducing build time, refer to the +[wiki page](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools#UsefulDeveloperTools-ReducingBuildTimes). + # Testing with SBT Some of the tests require Spark to be packaged first, so always run `build/sbt assembly` the first time. The following is an example of a correct (build, test) sequence: From f97e9323b526b3d0b0fee0ca03f4276f37bb5750 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Mon, 12 Oct 2015 18:17:28 -0700 Subject: [PATCH 028/139] [SPARK-10739] [YARN] Add application attempt window for Spark on Yarn Add application attempt window for Spark on Yarn to ignore old out of window failures, this is useful for long running applications to recover from failures. Author: jerryshao Closes #8857 from jerryshao/SPARK-10739 and squashes the following commits: 36eabdc [jerryshao] change the doc 7f9b77d [jerryshao] Style change 1c9afd0 [jerryshao] Address the comments caca695 [jerryshao] Add application attempt window for Spark on Yarn --- docs/running-on-yarn.md | 9 +++++++++ .../org/apache/spark/deploy/yarn/Client.scala | 14 ++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 6d77db6a3271e..677c0000440ac 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -305,6 +305,15 @@ If you need a reference to the proper location to put log files in the YARN so t It should be no larger than the global number of max attempts in the YARN configuration. + + spark.yarn.am.attemptFailuresValidityInterval + (none) + + Defines the validity interval for AM failure tracking. + If the AM has been running for at least the defined interval, the AM failure count will be reset. + This feature is not enabled if not configured, and only supported in Hadoop 2.6+. + + spark.yarn.submit.waitAppCompletion true diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 1fbd18aa466d4..d25d830fd4349 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -208,6 +208,20 @@ private[spark] class Client( case None => logDebug("spark.yarn.maxAppAttempts is not set. " + "Cluster's default value will be used.") } + + if (sparkConf.contains("spark.yarn.am.attemptFailuresValidityInterval")) { + try { + val interval = sparkConf.getTimeAsMs("spark.yarn.am.attemptFailuresValidityInterval") + val method = appContext.getClass().getMethod( + "setAttemptFailuresValidityInterval", classOf[Long]) + method.invoke(appContext, interval: java.lang.Long) + } catch { + case e: NoSuchMethodException => + logWarning("Ignoring spark.yarn.am.attemptFailuresValidityInterval because the version " + + "of YARN does not support it") + } + } + val capability = Records.newRecord(classOf[Resource]) capability.setMemory(args.amMemory + amMemoryOverhead) capability.setVirtualCores(args.amCores) From c4da5345a0ef643a7518756caaa18ff3f3ea9acc Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 12 Oct 2015 21:12:59 -0700 Subject: [PATCH 029/139] [SPARK-10990] [SPARK-11018] [SQL] improve unrolling of complex types This PR improve the unrolling and read of complex types in columnar cache: 1) Using UnsafeProjection to do serialization of complex types, so they will not be serialized three times (two for actualSize) 2) Copy the bytes from UnsafeRow/UnsafeArrayData to ByteBuffer directly, avoiding the immediate byte[] 3) Using the underlying array in ByteBuffer to create UTF8String/UnsafeRow/UnsafeArrayData without copy. Combine these optimizations, we can reduce the unrolling time from 25s to 21s (20% less), reduce the scanning time from 3.5s to 2.5s (28% less). ``` df = sqlContext.read.parquet(path) t = time.time() df.cache() df.count() print 'unrolling', time.time() - t for i in range(10): t = time.time() print df.select("*")._jdf.queryExecution().toRdd().count() print time.time() - t ``` The schema is ``` root |-- a: struct (nullable = true) | |-- b: long (nullable = true) | |-- c: string (nullable = true) |-- d: array (nullable = true) | |-- element: long (containsNull = true) |-- e: map (nullable = true) | |-- key: long | |-- value: string (valueContainsNull = true) ``` Now the columnar cache depends on that UnsafeProjection support all the data types (including UDT), this PR also fix that. Author: Davies Liu Closes #9016 from davies/complex2. --- .../catalyst/expressions/UnsafeArrayData.java | 12 ++ .../sql/catalyst/expressions/UnsafeRow.java | 12 ++ .../expressions/codegen/CodeGenerator.scala | 5 + .../codegen/GenerateSafeProjection.scala | 1 + .../codegen/GenerateUnsafeProjection.scala | 29 ++- .../spark/sql/columnar/ColumnAccessor.scala | 9 +- .../spark/sql/columnar/ColumnType.scala | 187 +++++++++--------- .../columnar/InMemoryColumnarTableScan.scala | 6 +- .../spark/sql/columnar/ColumnTypeSuite.scala | 37 ++-- .../NullableColumnAccessorSuite.scala | 7 +- .../columnar/NullableColumnBuilderSuite.scala | 13 +- .../apache/spark/unsafe/types/UTF8String.java | 10 + 12 files changed, 188 insertions(+), 140 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index fdd9125613a26..796f8abec9a1d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -19,6 +19,7 @@ import java.math.BigDecimal; import java.math.BigInteger; +import java.nio.ByteBuffer; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.Platform; @@ -145,6 +146,8 @@ public Object get(int ordinal, DataType dataType) { return getArray(ordinal); } else if (dataType instanceof MapType) { return getMap(ordinal); + } else if (dataType instanceof UserDefinedType) { + return get(ordinal, ((UserDefinedType)dataType).sqlType()); } else { throw new UnsupportedOperationException("Unsupported data type " + dataType.simpleString()); } @@ -306,6 +309,15 @@ public void writeToMemory(Object target, long targetOffset) { Platform.copyMemory(baseObject, baseOffset, target, targetOffset, sizeInBytes); } + public void writeTo(ByteBuffer buffer) { + assert(buffer.hasArray()); + byte[] target = buffer.array(); + int offset = buffer.arrayOffset(); + int pos = buffer.position(); + writeToMemory(target, Platform.BYTE_ARRAY_OFFSET + offset + pos); + buffer.position(pos + sizeInBytes); + } + @Override public UnsafeArrayData copy() { UnsafeArrayData arrayCopy = new UnsafeArrayData(); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 5af7ed5d6eb6d..36859fbab9744 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -20,6 +20,7 @@ import java.io.*; import java.math.BigDecimal; import java.math.BigInteger; +import java.nio.ByteBuffer; import java.util.Arrays; import java.util.Collections; import java.util.HashSet; @@ -326,6 +327,8 @@ public Object get(int ordinal, DataType dataType) { return getArray(ordinal); } else if (dataType instanceof MapType) { return getMap(ordinal); + } else if (dataType instanceof UserDefinedType) { + return get(ordinal, ((UserDefinedType)dataType).sqlType()); } else { throw new UnsupportedOperationException("Unsupported data type " + dataType.simpleString()); } @@ -602,6 +605,15 @@ public void writeToMemory(Object target, long targetOffset) { Platform.copyMemory(baseObject, baseOffset, target, targetOffset, sizeInBytes); } + public void writeTo(ByteBuffer buffer) { + assert (buffer.hasArray()); + byte[] target = buffer.array(); + int offset = buffer.arrayOffset(); + int pos = buffer.position(); + writeToMemory(target, Platform.BYTE_ARRAY_OFFSET + offset + pos); + buffer.position(pos + sizeInBytes); + } + @Override public void writeExternal(ObjectOutput out) throws IOException { byte[] bytes = getBytes(); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index a0fe5bd77e3aa..7544d27e3dc15 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -129,6 +129,7 @@ class CodeGenContext { case _: ArrayType => s"$input.getArray($ordinal)" case _: MapType => s"$input.getMap($ordinal)" case NullType => "null" + case udt: UserDefinedType[_] => getValue(input, udt.sqlType, ordinal) case _ => s"($jt)$input.get($ordinal, null)" } } @@ -143,6 +144,7 @@ class CodeGenContext { case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})" // The UTF8String may came from UnsafeRow, otherwise clone is cheap (re-use the bytes) case StringType => s"$row.update($ordinal, $value.clone())" + case udt: UserDefinedType[_] => setColumn(row, udt.sqlType, ordinal, value) case _ => s"$row.update($ordinal, $value)" } } @@ -177,6 +179,7 @@ class CodeGenContext { case _: MapType => "MapData" case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName + case udt: UserDefinedType[_] => javaType(udt.sqlType) case ObjectType(cls) if cls.isArray => s"${javaType(ObjectType(cls.getComponentType))}[]" case ObjectType(cls) => cls.getName case _ => "Object" @@ -222,6 +225,7 @@ class CodeGenContext { case FloatType => s"(java.lang.Float.isNaN($c1) && java.lang.Float.isNaN($c2)) || $c1 == $c2" case DoubleType => s"(java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2" case dt: DataType if isPrimitiveType(dt) => s"$c1 == $c2" + case udt: UserDefinedType[_] => genEqual(udt.sqlType, c1, c2) case other => s"$c1.equals($c2)" } @@ -255,6 +259,7 @@ class CodeGenContext { addNewFunction(compareFunc, funcCode) s"this.$compareFunc($c1, $c2)" case other if other.isInstanceOf[AtomicType] => s"$c1.compare($c2)" + case udt: UserDefinedType[_] => genComp(udt.sqlType, c1, c2) case _ => throw new IllegalArgumentException("cannot generate compare code for un-comparable type") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 9873630937d31..ee50587ed097e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -124,6 +124,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType) // UTF8String act as a pointer if it's inside UnsafeRow, so copy it to make it safe. case StringType => GeneratedExpressionCode("", "false", s"$input.clone()") + case udt: UserDefinedType[_] => convertToSafe(ctx, input, udt.sqlType) case _ => GeneratedExpressionCode("", "false", input) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 3e0e81733fb1f..1b957a508d10e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -39,6 +39,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case t: StructType => t.toSeq.forall(field => canSupport(field.dataType)) case t: ArrayType if canSupport(t.elementType) => true case MapType(kt, vt, _) if canSupport(kt) && canSupport(vt) => true + case dt: OpenHashSetUDT => false // it's not a standard UDT + case udt: UserDefinedType[_] => canSupport(udt.sqlType) case _ => false } @@ -77,7 +79,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx.addMutableState(rowWriterClass, rowWriter, s"this.$rowWriter = new $rowWriterClass();") val writeFields = inputs.zip(inputTypes).zipWithIndex.map { - case ((input, dt), index) => + case ((input, dataType), index) => + val dt = dataType match { + case udt: UserDefinedType[_] => udt.sqlType + case other => other + } val tmpCursor = ctx.freshName("tmpCursor") val setNull = dt match { @@ -167,15 +173,20 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val index = ctx.freshName("index") val element = ctx.freshName("element") - val jt = ctx.javaType(elementType) + val et = elementType match { + case udt: UserDefinedType[_] => udt.sqlType + case other => other + } + + val jt = ctx.javaType(et) - val fixedElementSize = elementType match { + val fixedElementSize = et match { case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => 8 - case _ if ctx.isPrimitiveType(jt) => elementType.defaultSize + case _ if ctx.isPrimitiveType(jt) => et.defaultSize case _ => 0 } - val writeElement = elementType match { + val writeElement = et match { case t: StructType => s""" $arrayWriter.setOffset($index); @@ -194,13 +205,13 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ${writeMapToBuffer(ctx, element, kt, vt, bufferHolder)} """ - case _ if ctx.isPrimitiveType(elementType) => + case _ if ctx.isPrimitiveType(et) => // Should we do word align? - val dataSize = elementType.defaultSize + val dataSize = et.defaultSize s""" $arrayWriter.setOffset($index); - ${writePrimitiveType(ctx, element, elementType, + ${writePrimitiveType(ctx, element, et, s"$bufferHolder.buffer", s"$bufferHolder.cursor")} $bufferHolder.cursor += $dataSize; """ @@ -237,7 +248,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro if ($input.isNullAt($index)) { $arrayWriter.setNullAt($index); } else { - final $jt $element = ${ctx.getValue(input, elementType, index)}; + final $jt $element = ${ctx.getValue(input, et, index)}; $writeElement } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala index 62478667eb4fb..42ec4d3433f16 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala @@ -19,8 +19,7 @@ package org.apache.spark.sql.columnar import java.nio.{ByteBuffer, ByteOrder} -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.MutableRow +import org.apache.spark.sql.catalyst.expressions.{MutableRow, UnsafeArrayData, UnsafeMapData, UnsafeRow} import org.apache.spark.sql.columnar.compression.CompressibleColumnAccessor import org.apache.spark.sql.types._ @@ -109,15 +108,15 @@ private[sql] class DecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalTy with NullableColumnAccessor private[sql] class StructColumnAccessor(buffer: ByteBuffer, dataType: StructType) - extends BasicColumnAccessor[InternalRow](buffer, STRUCT(dataType)) + extends BasicColumnAccessor[UnsafeRow](buffer, STRUCT(dataType)) with NullableColumnAccessor private[sql] class ArrayColumnAccessor(buffer: ByteBuffer, dataType: ArrayType) - extends BasicColumnAccessor[ArrayData](buffer, ARRAY(dataType)) + extends BasicColumnAccessor[UnsafeArrayData](buffer, ARRAY(dataType)) with NullableColumnAccessor private[sql] class MapColumnAccessor(buffer: ByteBuffer, dataType: MapType) - extends BasicColumnAccessor[MapData](buffer, MAP(dataType)) + extends BasicColumnAccessor[UnsafeMapData](buffer, MAP(dataType)) with NullableColumnAccessor private[sql] object ColumnAccessor { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index 3563eacb3a3e9..2bc2c96b61634 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.columnar import java.math.{BigDecimal, BigInteger} -import java.nio.{ByteOrder, ByteBuffer} +import java.nio.ByteBuffer import scala.reflect.runtime.universe.TypeTag @@ -92,7 +92,7 @@ private[sql] sealed abstract class ColumnType[JvmType] { * boxing/unboxing costs whenever possible. */ def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { - to.update(toOrdinal, from.get(fromOrdinal, dataType)) + setField(to, toOrdinal, getField(from, fromOrdinal)) } /** @@ -147,6 +147,7 @@ private[sql] object INT extends NativeColumnType(IntegerType, 4) { override def getField(row: InternalRow, ordinal: Int): Int = row.getInt(ordinal) + override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { to.setInt(toOrdinal, from.getInt(fromOrdinal)) } @@ -324,15 +325,18 @@ private[sql] object STRING extends NativeColumnType(StringType, 8) { } override def append(v: UTF8String, buffer: ByteBuffer): Unit = { - val stringBytes = v.getBytes - buffer.putInt(stringBytes.length).put(stringBytes, 0, stringBytes.length) + buffer.putInt(v.numBytes()) + v.writeTo(buffer) } override def extract(buffer: ByteBuffer): UTF8String = { val length = buffer.getInt() - val stringBytes = new Array[Byte](length) - buffer.get(stringBytes, 0, length) - UTF8String.fromBytes(stringBytes) + assert(buffer.hasArray) + val base = buffer.array() + val offset = buffer.arrayOffset() + val cursor = buffer.position() + buffer.position(cursor + length) + UTF8String.fromBytes(base, offset + cursor, length) } override def setField(row: MutableRow, ordinal: Int, value: UTF8String): Unit = { @@ -386,11 +390,6 @@ private[sql] sealed abstract class ByteArrayColumnType[JvmType](val defaultSize: def serialize(value: JvmType): Array[Byte] def deserialize(bytes: Array[Byte]): JvmType - override def actualSize(row: InternalRow, ordinal: Int): Int = { - // TODO: grow the buffer in append(), so serialize() will not be called twice - serialize(getField(row, ordinal)).length + 4 - } - override def append(v: JvmType, buffer: ByteBuffer): Unit = { val bytes = serialize(v) buffer.putInt(bytes.length).put(bytes, 0, bytes.length) @@ -416,6 +415,10 @@ private[sql] object BINARY extends ByteArrayColumnType[Array[Byte]](16) { row.getBinary(ordinal) } + override def actualSize(row: InternalRow, ordinal: Int): Int = { + row.getBinary(ordinal).length + 4 + } + def serialize(value: Array[Byte]): Array[Byte] = value def deserialize(bytes: Array[Byte]): Array[Byte] = bytes } @@ -433,6 +436,10 @@ private[sql] case class LARGE_DECIMAL(precision: Int, scale: Int) row.setDecimal(ordinal, value, precision) } + override def actualSize(row: InternalRow, ordinal: Int): Int = { + 4 + getField(row, ordinal).toJavaBigDecimal.unscaledValue().bitLength() / 8 + 1 + } + override def serialize(value: Decimal): Array[Byte] = { value.toJavaBigDecimal.unscaledValue().toByteArray } @@ -449,124 +456,118 @@ private[sql] object LARGE_DECIMAL { } } -private[sql] case class STRUCT(dataType: StructType) - extends ByteArrayColumnType[InternalRow](20) { +private[sql] case class STRUCT(dataType: StructType) extends ColumnType[UnsafeRow] { - private val projection: UnsafeProjection = - UnsafeProjection.create(dataType) private val numOfFields: Int = dataType.fields.size - override def setField(row: MutableRow, ordinal: Int, value: InternalRow): Unit = { + override def defaultSize: Int = 20 + + override def setField(row: MutableRow, ordinal: Int, value: UnsafeRow): Unit = { row.update(ordinal, value) } - override def getField(row: InternalRow, ordinal: Int): InternalRow = { - row.getStruct(ordinal, numOfFields) + override def getField(row: InternalRow, ordinal: Int): UnsafeRow = { + row.getStruct(ordinal, numOfFields).asInstanceOf[UnsafeRow] } - override def serialize(value: InternalRow): Array[Byte] = { - val unsafeRow = if (value.isInstanceOf[UnsafeRow]) { - value.asInstanceOf[UnsafeRow] - } else { - projection(value) - } - unsafeRow.getBytes + override def actualSize(row: InternalRow, ordinal: Int): Int = { + 4 + getField(row, ordinal).getSizeInBytes } - override def deserialize(bytes: Array[Byte]): InternalRow = { + override def append(value: UnsafeRow, buffer: ByteBuffer): Unit = { + buffer.putInt(value.getSizeInBytes) + value.writeTo(buffer) + } + + override def extract(buffer: ByteBuffer): UnsafeRow = { + val sizeInBytes = buffer.getInt() + assert(buffer.hasArray) + val base = buffer.array() + val offset = buffer.arrayOffset() + val cursor = buffer.position() + buffer.position(cursor + sizeInBytes) val unsafeRow = new UnsafeRow - unsafeRow.pointTo(bytes, numOfFields, bytes.length) + unsafeRow.pointTo(base, Platform.BYTE_ARRAY_OFFSET + offset + cursor, numOfFields, sizeInBytes) unsafeRow } - override def clone(v: InternalRow): InternalRow = v.copy() + override def clone(v: UnsafeRow): UnsafeRow = v.copy() } -private[sql] case class ARRAY(dataType: ArrayType) - extends ByteArrayColumnType[ArrayData](16) { +private[sql] case class ARRAY(dataType: ArrayType) extends ColumnType[UnsafeArrayData] { - private lazy val projection = UnsafeProjection.create(Array[DataType](dataType)) - private val mutableRow = new GenericMutableRow(new Array[Any](1)) + override def defaultSize: Int = 16 - override def setField(row: MutableRow, ordinal: Int, value: ArrayData): Unit = { + override def setField(row: MutableRow, ordinal: Int, value: UnsafeArrayData): Unit = { row.update(ordinal, value) } - override def getField(row: InternalRow, ordinal: Int): ArrayData = { - row.getArray(ordinal) + override def getField(row: InternalRow, ordinal: Int): UnsafeArrayData = { + row.getArray(ordinal).asInstanceOf[UnsafeArrayData] } - override def serialize(value: ArrayData): Array[Byte] = { - val unsafeArray = if (value.isInstanceOf[UnsafeArrayData]) { - value.asInstanceOf[UnsafeArrayData] - } else { - mutableRow(0) = value - projection(mutableRow).getArray(0) - } - val outputBuffer = - ByteBuffer.allocate(4 + unsafeArray.getSizeInBytes).order(ByteOrder.nativeOrder()) - outputBuffer.putInt(unsafeArray.numElements()) - val underlying = outputBuffer.array() - unsafeArray.writeToMemory(underlying, Platform.BYTE_ARRAY_OFFSET + 4) - underlying + override def actualSize(row: InternalRow, ordinal: Int): Int = { + val unsafeArray = getField(row, ordinal) + 4 + 4 + unsafeArray.getSizeInBytes } - override def deserialize(bytes: Array[Byte]): ArrayData = { - val buffer = ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder()) - val numElements = buffer.getInt - val array = new UnsafeArrayData - array.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + 4, numElements, bytes.length - 4) - array + override def append(value: UnsafeArrayData, buffer: ByteBuffer): Unit = { + buffer.putInt(4 + value.getSizeInBytes) + buffer.putInt(value.numElements()) + value.writeTo(buffer) } - override def clone(v: ArrayData): ArrayData = v.copy() + override def extract(buffer: ByteBuffer): UnsafeArrayData = { + val numBytes = buffer.getInt + assert(buffer.hasArray) + val cursor = buffer.position() + buffer.position(cursor + numBytes) + UnsafeReaders.readArray( + buffer.array(), + Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset() + cursor, + numBytes) + } + + override def clone(v: UnsafeArrayData): UnsafeArrayData = v.copy() } -private[sql] case class MAP(dataType: MapType) extends ByteArrayColumnType[MapData](32) { +private[sql] case class MAP(dataType: MapType) extends ColumnType[UnsafeMapData] { - private lazy val projection: UnsafeProjection = UnsafeProjection.create(Array[DataType](dataType)) - private val mutableRow = new GenericMutableRow(new Array[Any](1)) + override def defaultSize: Int = 32 - override def setField(row: MutableRow, ordinal: Int, value: MapData): Unit = { + override def setField(row: MutableRow, ordinal: Int, value: UnsafeMapData): Unit = { row.update(ordinal, value) } - override def getField(row: InternalRow, ordinal: Int): MapData = { - row.getMap(ordinal) + override def getField(row: InternalRow, ordinal: Int): UnsafeMapData = { + row.getMap(ordinal).asInstanceOf[UnsafeMapData] } - override def serialize(value: MapData): Array[Byte] = { - val unsafeMap = if (value.isInstanceOf[UnsafeMapData]) { - value.asInstanceOf[UnsafeMapData] - } else { - mutableRow(0) = value - projection(mutableRow).getMap(0) - } + override def actualSize(row: InternalRow, ordinal: Int): Int = { + val unsafeMap = getField(row, ordinal) + 12 + unsafeMap.keyArray().getSizeInBytes + unsafeMap.valueArray().getSizeInBytes + } + + override def append(value: UnsafeMapData, buffer: ByteBuffer): Unit = { + buffer.putInt(8 + value.keyArray().getSizeInBytes + value.valueArray().getSizeInBytes) + buffer.putInt(value.numElements()) + buffer.putInt(value.keyArray().getSizeInBytes) + value.keyArray().writeTo(buffer) + value.valueArray().writeTo(buffer) + } + + override def extract(buffer: ByteBuffer): UnsafeMapData = { + val numBytes = buffer.getInt + assert(buffer.hasArray) + val cursor = buffer.position() + buffer.position(cursor + numBytes) + UnsafeReaders.readMap( + buffer.array(), + Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset() + cursor, + numBytes) + } - val outputBuffer = - ByteBuffer.allocate(8 + unsafeMap.getSizeInBytes).order(ByteOrder.nativeOrder()) - outputBuffer.putInt(unsafeMap.numElements()) - val keyBytes = unsafeMap.keyArray().getSizeInBytes - outputBuffer.putInt(keyBytes) - val underlying = outputBuffer.array() - unsafeMap.keyArray().writeToMemory(underlying, Platform.BYTE_ARRAY_OFFSET + 8) - unsafeMap.valueArray().writeToMemory(underlying, Platform.BYTE_ARRAY_OFFSET + 8 + keyBytes) - underlying - } - - override def deserialize(bytes: Array[Byte]): MapData = { - val buffer = ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder()) - val numElements = buffer.getInt - val keyArraySize = buffer.getInt - val keyArray = new UnsafeArrayData - val valueArray = new UnsafeArrayData - keyArray.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + 8, numElements, keyArraySize) - valueArray.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + 8 + keyArraySize, numElements, - bytes.length - 8 - keyArraySize) - new UnsafeMapData(keyArray, valueArray) - } - - override def clone(v: MapData): MapData = v.copy() + override def clone(v: UnsafeMapData): UnsafeMapData = v.copy() } private[sql] object ColumnType { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index d7e145f9c2bb8..d967814f627cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} -import org.apache.spark.sql.execution.{LeafNode, SparkPlan} +import org.apache.spark.sql.execution.{ConvertToUnsafe, LeafNode, SparkPlan} import org.apache.spark.storage.StorageLevel import org.apache.spark.{Accumulable, Accumulator, Accumulators} @@ -38,7 +38,9 @@ private[sql] object InMemoryRelation { storageLevel: StorageLevel, child: SparkPlan, tableName: Option[String]): InMemoryRelation = - new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child, tableName)() + new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, + if (child.outputsUnsafeRows) child else ConvertToUnsafe(child), + tableName)() } private[sql] case class CachedBatch(buffers: Array[Array[Byte]], stats: InternalRow) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index ceb8ad97bb320..0e6e1bcf72896 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.columnar -import java.nio.ByteBuffer +import java.nio.{ByteOrder, ByteBuffer} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.GenericMutableRow +import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, GenericMutableRow} import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types._ import org.apache.spark.{Logging, SparkFunSuite} @@ -55,7 +55,8 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { assertResult(expected, s"Wrong actualSize for $columnType") { val row = new GenericMutableRow(1) row.update(0, CatalystTypeConverters.convertToCatalyst(value)) - columnType.actualSize(row, 0) + val proj = UnsafeProjection.create(Array[DataType](columnType.dataType)) + columnType.actualSize(proj(row), 0) } } @@ -99,35 +100,27 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { def testColumnType[JvmType](columnType: ColumnType[JvmType]): Unit = { - val buffer = ByteBuffer.allocate(DEFAULT_BUFFER_SIZE) - val seq = (0 until 4).map(_ => makeRandomValue(columnType)) + val buffer = ByteBuffer.allocate(DEFAULT_BUFFER_SIZE).order(ByteOrder.nativeOrder()) + val proj = UnsafeProjection.create(Array[DataType](columnType.dataType)) val converter = CatalystTypeConverters.createToScalaConverter(columnType.dataType) + val seq = (0 until 4).map(_ => proj(makeRandomRow(columnType)).copy()) test(s"$columnType append/extract") { buffer.rewind() - seq.foreach(columnType.append(_, buffer)) + seq.foreach(columnType.append(_, 0, buffer)) buffer.rewind() - seq.foreach { expected => - logInfo("buffer = " + buffer + ", expected = " + expected) - val extracted = columnType.extract(buffer) - assert( - converter(expected) === converter(extracted), - "Extracted value didn't equal to the original one. " + - hexDump(expected) + " != " + hexDump(extracted) + - ", buffer = " + dumpBuffer(buffer.duplicate().rewind().asInstanceOf[ByteBuffer])) + seq.foreach { row => + logInfo("buffer = " + buffer + ", expected = " + row) + val expected = converter(row.get(0, columnType.dataType)) + val extracted = converter(columnType.extract(buffer)) + assert(expected === extracted, + s"Extracted value didn't equal to the original one. $expected != $extracted, buffer =" + + dumpBuffer(buffer.duplicate().rewind().asInstanceOf[ByteBuffer])) } } } - private def hexDump(value: Any): String = { - if (value == null) { - "" - } else { - value.toString.map(ch => Integer.toHexString(ch & 0xffff)).mkString(" ") - } - } - private def dumpBuffer(buff: ByteBuffer): Any = { val sb = new StringBuilder() while (buff.hasRemaining) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala index 78cebbf3cc934..aa1605fee8c73 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala @@ -21,7 +21,7 @@ import java.nio.ByteBuffer import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.GenericMutableRow +import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, GenericMutableRow} import org.apache.spark.sql.types._ class TestNullableColumnAccessor[JvmType]( @@ -64,10 +64,11 @@ class NullableColumnAccessorSuite extends SparkFunSuite { test(s"Nullable $typeName column accessor: access null values") { val builder = TestNullableColumnBuilder(columnType) val randomRow = makeRandomRow(columnType) + val proj = UnsafeProjection.create(Array[DataType](columnType.dataType)) (0 until 4).foreach { _ => - builder.appendFrom(randomRow, 0) - builder.appendFrom(nullRow, 0) + builder.appendFrom(proj(randomRow), 0) + builder.appendFrom(proj(nullRow), 0) } val accessor = TestNullableColumnAccessor(builder.build(), columnType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala index fba08e626d720..91404577832a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.columnar import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.GenericMutableRow +import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, GenericMutableRow} import org.apache.spark.sql.types._ class TestNullableColumnBuilder[JvmType](columnType: ColumnType[JvmType]) @@ -51,6 +51,9 @@ class NullableColumnBuilderSuite extends SparkFunSuite { columnType: ColumnType[JvmType]): Unit = { val typeName = columnType.getClass.getSimpleName.stripSuffix("$") + val dataType = columnType.dataType + val proj = UnsafeProjection.create(Array[DataType](dataType)) + val converter = CatalystTypeConverters.createToScalaConverter(dataType) test(s"$typeName column builder: empty column") { val columnBuilder = TestNullableColumnBuilder(columnType) @@ -65,7 +68,7 @@ class NullableColumnBuilderSuite extends SparkFunSuite { val randomRow = makeRandomRow(columnType) (0 until 4).foreach { _ => - columnBuilder.appendFrom(randomRow, 0) + columnBuilder.appendFrom(proj(randomRow), 0) } val buffer = columnBuilder.build() @@ -77,12 +80,10 @@ class NullableColumnBuilderSuite extends SparkFunSuite { val columnBuilder = TestNullableColumnBuilder(columnType) val randomRow = makeRandomRow(columnType) val nullRow = makeNullRow(1) - val dataType = columnType.dataType - val converter = CatalystTypeConverters.createToScalaConverter(dataType) (0 until 4).foreach { _ => - columnBuilder.appendFrom(randomRow, 0) - columnBuilder.appendFrom(nullRow, 0) + columnBuilder.appendFrom(proj(randomRow), 0) + columnBuilder.appendFrom(proj(nullRow), 0) } val buffer = columnBuilder.build() diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 216aeea60d1c8..b7aecb5102ba6 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -19,6 +19,7 @@ import javax.annotation.Nonnull; import java.io.*; +import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.Arrays; import java.util.Map; @@ -137,6 +138,15 @@ public void writeToMemory(Object target, long targetOffset) { Platform.copyMemory(base, offset, target, targetOffset, numBytes); } + public void writeTo(ByteBuffer buffer) { + assert(buffer.hasArray()); + byte[] target = buffer.array(); + int offset = buffer.arrayOffset(); + int pos = buffer.position(); + writeToMemory(target, Platform.BYTE_ARRAY_OFFSET + offset + pos); + buffer.position(pos + numBytes); + } + /** * Returns the number of bytes for a code point with the first byte as `b` * @param b The first byte of a code point From 626aab79c9b4d4ac9d65bf5fa45b81dd9cbc609c Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Tue, 13 Oct 2015 08:29:47 -0500 Subject: [PATCH 030/139] [SPARK-11026] [YARN] spark.yarn.user.classpath.first does work for 'spark-submit --jars hdfs://user/foo.jar' when spark.yarn.user.classpath.first=true and using 'spark-submit --jars hdfs://user/foo.jar', it can not put foo.jar to system classpath. so we need to put yarn's linkNames of jars to the system classpath. vanzin tgravescs Author: Lianhui Wang Closes #9045 from lianhuiwang/spark-11026. --- .../org/apache/spark/deploy/yarn/Client.scala | 23 ++++++++++++------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index d25d830fd4349..9fcfe362a3ba2 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -1212,7 +1212,7 @@ object Client extends Logging { } else { getMainJarUri(sparkConf.getOption(CONF_SPARK_USER_JAR)) } - mainJar.foreach(addFileToClasspath(sparkConf, _, APP_JAR, env)) + mainJar.foreach(addFileToClasspath(sparkConf, conf, _, APP_JAR, env)) val secondaryJars = if (args != null) { @@ -1221,10 +1221,10 @@ object Client extends Logging { getSecondaryJarUris(sparkConf.getOption(CONF_SPARK_YARN_SECONDARY_JARS)) } secondaryJars.foreach { x => - addFileToClasspath(sparkConf, x, null, env) + addFileToClasspath(sparkConf, conf, x, null, env) } } - addFileToClasspath(sparkConf, new URI(sparkJar(sparkConf)), SPARK_JAR, env) + addFileToClasspath(sparkConf, conf, new URI(sparkJar(sparkConf)), SPARK_JAR, env) populateHadoopClasspath(conf, env) sys.env.get(ENV_DIST_CLASSPATH).foreach { cp => addClasspathEntry(getClusterPath(sparkConf, cp), env) @@ -1259,15 +1259,17 @@ object Client extends Logging { * If an alternate name for the file is given, and it's not a "local:" file, the alternate * name will be added to the classpath (relative to the job's work directory). * - * If not a "local:" file and no alternate name, the environment is not modified. + * If not a "local:" file and no alternate name, the linkName will be added to the classpath. * - * @param conf Spark configuration. - * @param uri URI to add to classpath (optional). - * @param fileName Alternate name for the file (optional). - * @param env Map holding the environment variables. + * @param conf Spark configuration. + * @param hadoopConf Hadoop configuration. + * @param uri URI to add to classpath (optional). + * @param fileName Alternate name for the file (optional). + * @param env Map holding the environment variables. */ private def addFileToClasspath( conf: SparkConf, + hadoopConf: Configuration, uri: URI, fileName: String, env: HashMap[String, String]): Unit = { @@ -1276,6 +1278,11 @@ object Client extends Logging { } else if (fileName != null) { addClasspathEntry(buildPath( YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), fileName), env) + } else if (uri != null) { + val localPath = getQualifiedLocalPath(uri, hadoopConf) + val linkName = Option(uri.getFragment()).getOrElse(localPath.getName()) + addClasspathEntry(buildPath( + YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), linkName), env) } } From 6987c067937a50867b4d5788f5bf496ecdfdb62c Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 13 Oct 2015 09:40:36 -0700 Subject: [PATCH 031/139] [SPARK-11009] [SQL] fix wrong result of Window function in cluster mode Currently, All windows function could generate wrong result in cluster sometimes. The root cause is that AttributeReference is called in executor, then id of it may not be unique than others created in driver. Here is the script that could reproduce the problem (run in local cluster): ``` from pyspark import SparkContext, HiveContext from pyspark.sql.window import Window from pyspark.sql.functions import rowNumber sqlContext = HiveContext(SparkContext()) sqlContext.setConf("spark.sql.shuffle.partitions", "3") df = sqlContext.range(1<<20) df2 = df.select((df.id % 1000).alias("A"), (df.id / 1000).alias('B')) ws = Window.partitionBy(df2.A).orderBy(df2.B) df3 = df2.select("client", "date", rowNumber().over(ws).alias("rn")).filter("rn < 0") assert df3.count() == 0 ``` Author: Davies Liu Author: Yin Huai Closes #9050 from davies/wrong_window. --- .../apache/spark/sql/execution/Window.scala | 20 ++++----- .../spark/sql/hive/HiveSparkSubmitSuite.scala | 41 +++++++++++++++++++ 2 files changed, 51 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index f8929530c5036..55035f4bc5f2a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -145,11 +145,10 @@ case class Window( // Construct the ordering. This is used to compare the result of current value projection // to the result of bound value projection. This is done manually because we want to use // Code Generation (if it is enabled). - val (sortExprs, schema) = exprs.map { case e => - val ref = AttributeReference("ordExpr", e.dataType, e.nullable)() - (SortOrder(ref, e.direction), ref) - }.unzip - val ordering = newOrdering(sortExprs, schema) + val sortExprs = exprs.zipWithIndex.map { case (e, i) => + SortOrder(BoundReference(i, e.dataType, e.nullable), e.direction) + } + val ordering = newOrdering(sortExprs, Nil) RangeBoundOrdering(ordering, current, bound) case RowFrame => RowBoundOrdering(offset) } @@ -205,14 +204,15 @@ case class Window( */ private[this] def createResultProjection( expressions: Seq[Expression]): MutableProjection = { - val unboundToAttr = expressions.map { - e => (e, AttributeReference("windowResult", e.dataType, e.nullable)()) + val references = expressions.zipWithIndex.map{ case (e, i) => + // Results of window expressions will be on the right side of child's output + BoundReference(child.output.size + i, e.dataType, e.nullable) } - val unboundToAttrMap = unboundToAttr.toMap - val patchedWindowExpression = windowExpression.map(_.transform(unboundToAttrMap)) + val unboundToRefMap = expressions.zip(references).toMap + val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap)) newMutableProjection( projectList ++ patchedWindowExpression, - child.output ++ unboundToAttr.map(_._2))() + child.output)() } protected override def doExecute(): RDD[InternalRow] = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 5f1660b62d418..10e4ae2c50308 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -30,6 +30,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark._ import org.apache.spark.sql.{SQLContext, QueryTest} +import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext} import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer import org.apache.spark.sql.types.DecimalType @@ -107,6 +108,16 @@ class HiveSparkSubmitSuite runSparkSubmit(args) } + test("SPARK-11009 fix wrong result of Window function in cluster mode") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val args = Seq( + "--class", SPARK_11009.getClass.getName.stripSuffix("$"), + "--name", "SparkSQLConfTest", + "--master", "local-cluster[2,1,1024]", + unusedJar.toString) + runSparkSubmit(args) + } + // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. // This is copied from org.apache.spark.deploy.SparkSubmitSuite private def runSparkSubmit(args: Seq[String]): Unit = { @@ -320,3 +331,33 @@ object SPARK_9757 extends QueryTest { } } } + +object SPARK_11009 extends QueryTest { + import org.apache.spark.sql.functions._ + + protected var sqlContext: SQLContext = _ + + def main(args: Array[String]): Unit = { + Utils.configTestLog4j("INFO") + + val sparkContext = new SparkContext( + new SparkConf() + .set("spark.ui.enabled", "false") + .set("spark.sql.shuffle.partitions", "100")) + + val hiveContext = new TestHiveContext(sparkContext) + sqlContext = hiveContext + + try { + val df = sqlContext.range(1 << 20) + val df2 = df.select((df("id") % 1000).alias("A"), (df("id") / 1000).alias("B")) + val ws = Window.partitionBy(df2("A")).orderBy(df2("B")) + val df3 = df2.select(df2("A"), df2("B"), rowNumber().over(ws).alias("rn")).filter("rn < 0") + if (df3.rdd.count() != 0) { + throw new Exception("df3 should have 0 output row.") + } + } finally { + sparkContext.stop() + } + } +} From 1797055dbf1d2fd7714d7c65c8d2efde2f15efc1 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 13 Oct 2015 09:51:20 -0700 Subject: [PATCH 032/139] [SPARK-11079] Post-hoc review Netty-based RPC - round 1 I'm going through the implementation right now for post-doc review. Adding more comments and renaming things as I go through them. I also want to write higher level documentation about how the whole thing works -- but those will come in other pull requests. Author: Reynold Xin Closes #9091 from rxin/rpc-review. --- .../org/apache/spark/MapOutputTracker.scala | 2 +- .../org/apache/spark/rpc/RpcAddress.scala | 50 ++++++ .../org/apache/spark/rpc/RpcEndpoint.scala | 3 +- .../scala/org/apache/spark/rpc/RpcEnv.scala | 153 +----------------- .../org/apache/spark/rpc/RpcTimeout.scala | 131 +++++++++++++++ .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 4 - .../apache/spark/rpc/netty/Dispatcher.scala | 108 +++++++------ .../apache/spark/rpc/netty/IDVerifier.scala | 4 +- .../org/apache/spark/rpc/netty/Inbox.scala | 119 ++++++-------- .../spark/rpc/netty/NettyRpcCallContext.scala | 11 +- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 38 +++-- .../org/apache/spark/util/ThreadUtils.scala | 1 - .../scala/org/apache/spark/util/Utils.scala | 1 + .../apache/spark/rpc/netty/InboxSuite.scala | 6 +- .../rpc/netty/NettyRpcHandlerSuite.scala | 7 +- 15 files changed, 336 insertions(+), 302 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/rpc/RpcAddress.scala create mode 100644 core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 45e12e40c837f..72355cdfa68b3 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -48,7 +48,7 @@ private[spark] class MapOutputTrackerMasterEndpoint( val hostPort = context.senderAddress.hostPort logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort) val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId) - val serializedSize = mapOutputStatuses.size + val serializedSize = mapOutputStatuses.length if (serializedSize > maxAkkaFrameSize) { val msg = s"Map output statuses were $serializedSize bytes which " + s"exceeds spark.akka.frameSize ($maxAkkaFrameSize bytes)." diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcAddress.scala b/core/src/main/scala/org/apache/spark/rpc/RpcAddress.scala new file mode 100644 index 0000000000000..eb0b26947f504 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/RpcAddress.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc + +import org.apache.spark.util.Utils + + +/** + * Address for an RPC environment, with hostname and port. + */ +private[spark] case class RpcAddress(host: String, port: Int) { + + def hostPort: String = host + ":" + port + + /** Returns a string in the form of "spark://host:port". */ + def toSparkURL: String = "spark://" + hostPort + + override def toString: String = hostPort +} + + +private[spark] object RpcAddress { + + /** Return the [[RpcAddress]] represented by `uri`. */ + def fromURIString(uri: String): RpcAddress = { + val uriObj = new java.net.URI(uri) + RpcAddress(uriObj.getHost, uriObj.getPort) + } + + /** Returns the [[RpcAddress]] encoded in the form of "spark://host:port" */ + def fromSparkURL(sparkUrl: String): RpcAddress = { + val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl) + RpcAddress(host, port) + } +} diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala index f1ddc6d2cd438..0ba95169529e6 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala @@ -145,5 +145,4 @@ private[spark] trait RpcEndpoint { * However, there is no guarantee that the same thread will be executing the same * [[ThreadSafeRpcEndpoint]] for different messages. */ -private[spark] trait ThreadSafeRpcEndpoint extends RpcEndpoint { -} +private[spark] trait ThreadSafeRpcEndpoint extends RpcEndpoint diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 35e402c725331..ef491a0ae4f09 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -17,12 +17,7 @@ package org.apache.spark.rpc -import java.net.URI -import java.util.concurrent.TimeoutException - -import scala.concurrent.{Awaitable, Await, Future} -import scala.concurrent.duration._ -import scala.language.postfixOps +import scala.concurrent.Future import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.util.{RpcUtils, Utils} @@ -35,8 +30,8 @@ import org.apache.spark.util.{RpcUtils, Utils} private[spark] object RpcEnv { private def getRpcEnvFactory(conf: SparkConf): RpcEnvFactory = { - // Add more RpcEnv implementations here - val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory", + val rpcEnvNames = Map( + "akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory", "netty" -> "org.apache.spark.rpc.netty.NettyRpcEnvFactory") val rpcEnvName = conf.get("spark.rpc", "netty") val rpcEnvFactoryClassName = rpcEnvNames.getOrElse(rpcEnvName.toLowerCase, rpcEnvName) @@ -53,7 +48,6 @@ private[spark] object RpcEnv { val config = RpcEnvConfig(conf, name, host, port, securityManager) getRpcEnvFactory(conf).create(config) } - } @@ -155,144 +149,3 @@ private[spark] case class RpcEnvConfig( host: String, port: Int, securityManager: SecurityManager) - - -/** - * Represents a host and port. - */ -private[spark] case class RpcAddress(host: String, port: Int) { - // TODO do we need to add the type of RpcEnv in the address? - - val hostPort: String = host + ":" + port - - override val toString: String = hostPort - - def toSparkURL: String = "spark://" + hostPort -} - - -private[spark] object RpcAddress { - - /** - * Return the [[RpcAddress]] represented by `uri`. - */ - def fromURI(uri: URI): RpcAddress = { - RpcAddress(uri.getHost, uri.getPort) - } - - /** - * Return the [[RpcAddress]] represented by `uri`. - */ - def fromURIString(uri: String): RpcAddress = { - fromURI(new java.net.URI(uri)) - } - - def fromSparkURL(sparkUrl: String): RpcAddress = { - val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl) - RpcAddress(host, port) - } -} - - -/** - * An exception thrown if RpcTimeout modifies a [[TimeoutException]]. - */ -private[rpc] class RpcTimeoutException(message: String, cause: TimeoutException) - extends TimeoutException(message) { initCause(cause) } - - -/** - * Associates a timeout with a description so that a when a TimeoutException occurs, additional - * context about the timeout can be amended to the exception message. - * @param duration timeout duration in seconds - * @param timeoutProp the configuration property that controls this timeout - */ -private[spark] class RpcTimeout(val duration: FiniteDuration, val timeoutProp: String) - extends Serializable { - - /** Amends the standard message of TimeoutException to include the description */ - private def createRpcTimeoutException(te: TimeoutException): RpcTimeoutException = { - new RpcTimeoutException(te.getMessage() + ". This timeout is controlled by " + timeoutProp, te) - } - - /** - * PartialFunction to match a TimeoutException and add the timeout description to the message - * - * @note This can be used in the recover callback of a Future to add to a TimeoutException - * Example: - * val timeout = new RpcTimeout(5 millis, "short timeout") - * Future(throw new TimeoutException).recover(timeout.addMessageIfTimeout) - */ - def addMessageIfTimeout[T]: PartialFunction[Throwable, T] = { - // The exception has already been converted to a RpcTimeoutException so just raise it - case rte: RpcTimeoutException => throw rte - // Any other TimeoutException get converted to a RpcTimeoutException with modified message - case te: TimeoutException => throw createRpcTimeoutException(te) - } - - /** - * Wait for the completed result and return it. If the result is not available within this - * timeout, throw a [[RpcTimeoutException]] to indicate which configuration controls the timeout. - * @param awaitable the `Awaitable` to be awaited - * @throws RpcTimeoutException if after waiting for the specified time `awaitable` - * is still not ready - */ - def awaitResult[T](awaitable: Awaitable[T]): T = { - try { - Await.result(awaitable, duration) - } catch addMessageIfTimeout - } -} - - -private[spark] object RpcTimeout { - - /** - * Lookup the timeout property in the configuration and create - * a RpcTimeout with the property key in the description. - * @param conf configuration properties containing the timeout - * @param timeoutProp property key for the timeout in seconds - * @throws NoSuchElementException if property is not set - */ - def apply(conf: SparkConf, timeoutProp: String): RpcTimeout = { - val timeout = { conf.getTimeAsSeconds(timeoutProp) seconds } - new RpcTimeout(timeout, timeoutProp) - } - - /** - * Lookup the timeout property in the configuration and create - * a RpcTimeout with the property key in the description. - * Uses the given default value if property is not set - * @param conf configuration properties containing the timeout - * @param timeoutProp property key for the timeout in seconds - * @param defaultValue default timeout value in seconds if property not found - */ - def apply(conf: SparkConf, timeoutProp: String, defaultValue: String): RpcTimeout = { - val timeout = { conf.getTimeAsSeconds(timeoutProp, defaultValue) seconds } - new RpcTimeout(timeout, timeoutProp) - } - - /** - * Lookup prioritized list of timeout properties in the configuration - * and create a RpcTimeout with the first set property key in the - * description. - * Uses the given default value if property is not set - * @param conf configuration properties containing the timeout - * @param timeoutPropList prioritized list of property keys for the timeout in seconds - * @param defaultValue default timeout value in seconds if no properties found - */ - def apply(conf: SparkConf, timeoutPropList: Seq[String], defaultValue: String): RpcTimeout = { - require(timeoutPropList.nonEmpty) - - // Find the first set property or use the default value with the first property - val itr = timeoutPropList.iterator - var foundProp: Option[(String, String)] = None - while (itr.hasNext && foundProp.isEmpty){ - val propKey = itr.next() - conf.getOption(propKey).foreach { prop => foundProp = Some(propKey, prop) } - } - val finalProp = foundProp.getOrElse(timeoutPropList.head, defaultValue) - val timeout = { Utils.timeStringAsSeconds(finalProp._2) seconds } - new RpcTimeout(timeout, finalProp._1) - } -} diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala b/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala new file mode 100644 index 0000000000000..285786ebf9f1b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc + +import java.util.concurrent.TimeoutException + +import scala.concurrent.{Awaitable, Await} +import scala.concurrent.duration._ + +import org.apache.spark.SparkConf +import org.apache.spark.util.Utils + + +/** + * An exception thrown if RpcTimeout modifies a [[TimeoutException]]. + */ +private[rpc] class RpcTimeoutException(message: String, cause: TimeoutException) + extends TimeoutException(message) { initCause(cause) } + + +/** + * Associates a timeout with a description so that a when a TimeoutException occurs, additional + * context about the timeout can be amended to the exception message. + * + * @param duration timeout duration in seconds + * @param timeoutProp the configuration property that controls this timeout + */ +private[spark] class RpcTimeout(val duration: FiniteDuration, val timeoutProp: String) + extends Serializable { + + /** Amends the standard message of TimeoutException to include the description */ + private def createRpcTimeoutException(te: TimeoutException): RpcTimeoutException = { + new RpcTimeoutException(te.getMessage + ". This timeout is controlled by " + timeoutProp, te) + } + + /** + * PartialFunction to match a TimeoutException and add the timeout description to the message + * + * @note This can be used in the recover callback of a Future to add to a TimeoutException + * Example: + * val timeout = new RpcTimeout(5 millis, "short timeout") + * Future(throw new TimeoutException).recover(timeout.addMessageIfTimeout) + */ + def addMessageIfTimeout[T]: PartialFunction[Throwable, T] = { + // The exception has already been converted to a RpcTimeoutException so just raise it + case rte: RpcTimeoutException => throw rte + // Any other TimeoutException get converted to a RpcTimeoutException with modified message + case te: TimeoutException => throw createRpcTimeoutException(te) + } + + /** + * Wait for the completed result and return it. If the result is not available within this + * timeout, throw a [[RpcTimeoutException]] to indicate which configuration controls the timeout. + * @param awaitable the `Awaitable` to be awaited + * @throws RpcTimeoutException if after waiting for the specified time `awaitable` + * is still not ready + */ + def awaitResult[T](awaitable: Awaitable[T]): T = { + try { + Await.result(awaitable, duration) + } catch addMessageIfTimeout + } +} + + +private[spark] object RpcTimeout { + + /** + * Lookup the timeout property in the configuration and create + * a RpcTimeout with the property key in the description. + * @param conf configuration properties containing the timeout + * @param timeoutProp property key for the timeout in seconds + * @throws NoSuchElementException if property is not set + */ + def apply(conf: SparkConf, timeoutProp: String): RpcTimeout = { + val timeout = { conf.getTimeAsSeconds(timeoutProp).seconds } + new RpcTimeout(timeout, timeoutProp) + } + + /** + * Lookup the timeout property in the configuration and create + * a RpcTimeout with the property key in the description. + * Uses the given default value if property is not set + * @param conf configuration properties containing the timeout + * @param timeoutProp property key for the timeout in seconds + * @param defaultValue default timeout value in seconds if property not found + */ + def apply(conf: SparkConf, timeoutProp: String, defaultValue: String): RpcTimeout = { + val timeout = { conf.getTimeAsSeconds(timeoutProp, defaultValue).seconds } + new RpcTimeout(timeout, timeoutProp) + } + + /** + * Lookup prioritized list of timeout properties in the configuration + * and create a RpcTimeout with the first set property key in the + * description. + * Uses the given default value if property is not set + * @param conf configuration properties containing the timeout + * @param timeoutPropList prioritized list of property keys for the timeout in seconds + * @param defaultValue default timeout value in seconds if no properties found + */ + def apply(conf: SparkConf, timeoutPropList: Seq[String], defaultValue: String): RpcTimeout = { + require(timeoutPropList.nonEmpty) + + // Find the first set property or use the default value with the first property + val itr = timeoutPropList.iterator + var foundProp: Option[(String, String)] = None + while (itr.hasNext && foundProp.isEmpty){ + val propKey = itr.next() + conf.getOption(propKey).foreach { prop => foundProp = Some(propKey, prop) } + } + val finalProp = foundProp.getOrElse(timeoutPropList.head, defaultValue) + val timeout = { Utils.timeStringAsSeconds(finalProp._2).seconds } + new RpcTimeout(timeout, finalProp._1) + } +} diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 95132a4e4a0bf..3fad595a0d0b0 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -39,10 +39,6 @@ import org.apache.spark.util.{ActorLogReceive, AkkaUtils, ThreadUtils} * * TODO Once we remove all usages of Akka in other place, we can move this file to a new project and * remove Akka from the dependencies. - * - * @param actorSystem - * @param conf - * @param boundPort */ private[spark] class AkkaRpcEnv private[akka] ( val actorSystem: ActorSystem, conf: SparkConf, boundPort: Int) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala index d71e6f01dbb29..398e9eafc1444 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala @@ -17,7 +17,7 @@ package org.apache.spark.rpc.netty -import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, TimeUnit} +import java.util.concurrent.{ThreadPoolExecutor, ConcurrentHashMap, LinkedBlockingQueue, TimeUnit} import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ @@ -38,12 +38,16 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { val inbox = new Inbox(ref, endpoint) } - private val endpoints = new ConcurrentHashMap[String, EndpointData]() - private val endpointRefs = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]() + private val endpoints = new ConcurrentHashMap[String, EndpointData] + private val endpointRefs = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef] // Track the receivers whose inboxes may contain messages. private val receivers = new LinkedBlockingQueue[EndpointData]() + /** + * True if the dispatcher has been stopped. Once stopped, all messages posted will be bounced + * immediately. + */ @GuardedBy("this") private var stopped = false @@ -59,7 +63,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { } val data = endpoints.get(name) endpointRefs.put(data.endpoint, data.ref) - receivers.put(data) + receivers.put(data) // for the OnStart message } endpointRef } @@ -73,7 +77,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { val data = endpoints.remove(name) if (data != null) { data.inbox.stop() - receivers.put(data) + receivers.put(data) // for the OnStop message } // Don't clean `endpointRefs` here because it's possible that some messages are being processed // now and they can use `getRpcEndpointRef`. So `endpointRefs` will be cleaned in Inbox via @@ -91,19 +95,23 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { } /** - * Send a message to all registered [[RpcEndpoint]]s. - * @param message + * Send a message to all registered [[RpcEndpoint]]s in this process. + * + * This can be used to make network events known to all end points (e.g. "a new node connected"). */ - def broadcastMessage(message: InboxMessage): Unit = { + def postToAll(message: InboxMessage): Unit = { val iter = endpoints.keySet().iterator() while (iter.hasNext) { val name = iter.next - postMessageToInbox(name, (_) => message, - () => { logWarning(s"Drop ${message} because ${name} has been stopped") }) + postMessage( + name, + _ => message, + () => { logWarning(s"Drop $message because $name has been stopped") }) } } - def postMessage(message: RequestMessage, callback: RpcResponseCallback): Unit = { + /** Posts a message sent by a remote endpoint. */ + def postRemoteMessage(message: RequestMessage, callback: RpcResponseCallback): Unit = { def createMessage(sender: NettyRpcEndpointRef): InboxMessage = { val rpcCallContext = new RemoteNettyRpcCallContext( @@ -116,10 +124,11 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { new SparkException(s"Could not find ${message.receiver.name} or it has been stopped")) } - postMessageToInbox(message.receiver.name, createMessage, onEndpointStopped) + postMessage(message.receiver.name, createMessage, onEndpointStopped) } - def postMessage(message: RequestMessage, p: Promise[Any]): Unit = { + /** Posts a message sent by a local endpoint. */ + def postLocalMessage(message: RequestMessage, p: Promise[Any]): Unit = { def createMessage(sender: NettyRpcEndpointRef): InboxMessage = { val rpcCallContext = new LocalNettyRpcCallContext(sender, message.senderAddress, message.needReply, p) @@ -131,39 +140,36 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { new SparkException(s"Could not find ${message.receiver.name} or it has been stopped")) } - postMessageToInbox(message.receiver.name, createMessage, onEndpointStopped) + postMessage(message.receiver.name, createMessage, onEndpointStopped) } - private def postMessageToInbox( + /** + * Posts a message to a specific endpoint. + * + * @param endpointName name of the endpoint. + * @param createMessageFn function to create the message. + * @param callbackIfStopped callback function if the endpoint is stopped. + */ + private def postMessage( endpointName: String, createMessageFn: NettyRpcEndpointRef => InboxMessage, - onStopped: () => Unit): Unit = { - val shouldCallOnStop = - synchronized { - val data = endpoints.get(endpointName) - if (stopped || data == null) { - true - } else { - data.inbox.post(createMessageFn(data.ref)) - receivers.put(data) - false - } + callbackIfStopped: () => Unit): Unit = { + val shouldCallOnStop = synchronized { + val data = endpoints.get(endpointName) + if (stopped || data == null) { + true + } else { + data.inbox.post(createMessageFn(data.ref)) + receivers.put(data) + false } + } if (shouldCallOnStop) { // We don't need to call `onStop` in the `synchronized` block - onStopped() + callbackIfStopped() } } - private val parallelism = nettyEnv.conf.getInt("spark.rpc.netty.dispatcher.parallelism", - Runtime.getRuntime.availableProcessors()) - - private val executor = ThreadUtils.newDaemonFixedThreadPool(parallelism, "dispatcher-event-loop") - - (0 until parallelism) foreach { _ => - executor.execute(new MessageLoop) - } - def stop(): Unit = { synchronized { if (stopped) { @@ -174,12 +180,12 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { // Stop all endpoints. This will queue all endpoints for processing by the message loops. endpoints.keySet().asScala.foreach(unregisterRpcEndpoint) // Enqueue a message that tells the message loops to stop. - receivers.put(PoisonEndpoint) - executor.shutdown() + receivers.put(PoisonPill) + threadpool.shutdown() } def awaitTermination(): Unit = { - executor.awaitTermination(Long.MaxValue, TimeUnit.MILLISECONDS) + threadpool.awaitTermination(Long.MaxValue, TimeUnit.MILLISECONDS) } /** @@ -189,15 +195,27 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { endpoints.containsKey(name) } + /** Thread pool used for dispatching messages. */ + private val threadpool: ThreadPoolExecutor = { + val numThreads = nettyEnv.conf.getInt("spark.rpc.netty.dispatcher.numThreads", + Runtime.getRuntime.availableProcessors()) + val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "dispatcher-event-loop") + for (i <- 0 until numThreads) { + pool.execute(new MessageLoop) + } + pool + } + + /** Message loop used for dispatching messages. */ private class MessageLoop extends Runnable { override def run(): Unit = { try { while (true) { try { val data = receivers.take() - if (data == PoisonEndpoint) { - // Put PoisonEndpoint back so that other MessageLoops can see it. - receivers.put(PoisonEndpoint) + if (data == PoisonPill) { + // Put PoisonPill back so that other MessageLoops can see it. + receivers.put(PoisonPill) return } data.inbox.process(Dispatcher.this) @@ -211,8 +229,6 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { } } - /** - * A poison endpoint that indicates MessageLoop should exit its loop. - */ - private val PoisonEndpoint = new EndpointData(null, null, null) + /** A poison endpoint that indicates MessageLoop should exit its message loop. */ + private val PoisonPill = new EndpointData(null, null, null) } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala b/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala index 6061c9b8de944..fa9a3eb99b02a 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala @@ -26,8 +26,8 @@ private[netty] case class ID(name: String) /** * An [[RpcEndpoint]] for remote [[RpcEnv]]s to query if a [[RpcEndpoint]] exists in this [[RpcEnv]] */ -private[netty] class IDVerifier( - override val rpcEnv: RpcEnv, dispatcher: Dispatcher) extends RpcEndpoint { +private[netty] class IDVerifier(override val rpcEnv: RpcEnv, dispatcher: Dispatcher) + extends RpcEndpoint { override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case ID(name) => context.reply(dispatcher.verify(name)) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala index b669f59a2884e..c72b588db57fe 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala @@ -17,14 +17,16 @@ package org.apache.spark.rpc.netty -import java.util.LinkedList import javax.annotation.concurrent.GuardedBy import scala.util.control.NonFatal +import com.google.common.annotations.VisibleForTesting + import org.apache.spark.{Logging, SparkException} import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, ThreadSafeRpcEndpoint} + private[netty] sealed trait InboxMessage private[netty] case class ContentMessage( @@ -37,44 +39,40 @@ private[netty] case object OnStart extends InboxMessage private[netty] case object OnStop extends InboxMessage -/** - * A broadcast message that indicates connecting to a remote node. - */ -private[netty] case class Associated(remoteAddress: RpcAddress) extends InboxMessage +/** A message to tell all endpoints that a remote process has connected. */ +private[netty] case class RemoteProcessConnected(remoteAddress: RpcAddress) extends InboxMessage -/** - * A broadcast message that indicates a remote connection is lost. - */ -private[netty] case class Disassociated(remoteAddress: RpcAddress) extends InboxMessage +/** A message to tell all endpoints that a remote process has disconnected. */ +private[netty] case class RemoteProcessDisconnected(remoteAddress: RpcAddress) extends InboxMessage -/** - * A broadcast message that indicates a network error - */ -private[netty] case class AssociationError(cause: Throwable, remoteAddress: RpcAddress) +/** A message to tell all endpoints that a network error has happened. */ +private[netty] case class RemoteProcessConnectionError(cause: Throwable, remoteAddress: RpcAddress) extends InboxMessage /** * A inbox that stores messages for an [[RpcEndpoint]] and posts messages to it thread-safely. - * @param endpointRef - * @param endpoint */ private[netty] class Inbox( val endpointRef: NettyRpcEndpointRef, - val endpoint: RpcEndpoint) extends Logging { + val endpoint: RpcEndpoint) + extends Logging { - inbox => + inbox => // Give this an alias so we can use it more clearly in closures. @GuardedBy("this") - protected val messages = new LinkedList[InboxMessage]() + protected val messages = new java.util.LinkedList[InboxMessage]() + /** True if the inbox (and its associated endpoint) is stopped. */ @GuardedBy("this") private var stopped = false + /** Allow multiple threads to process messages at the same time. */ @GuardedBy("this") private var enableConcurrent = false + /** The number of threads processing messages for this inbox. */ @GuardedBy("this") - private var workerCount = 0 + private var numActiveThreads = 0 // OnStart should be the first message to process inbox.synchronized { @@ -87,12 +85,12 @@ private[netty] class Inbox( def process(dispatcher: Dispatcher): Unit = { var message: InboxMessage = null inbox.synchronized { - if (!enableConcurrent && workerCount != 0) { + if (!enableConcurrent && numActiveThreads != 0) { return } message = messages.poll() if (message != null) { - workerCount += 1 + numActiveThreads += 1 } else { return } @@ -101,15 +99,11 @@ private[netty] class Inbox( safelyCall(endpoint) { message match { case ContentMessage(_sender, content, needReply, context) => - val pf: PartialFunction[Any, Unit] = - if (needReply) { - endpoint.receiveAndReply(context) - } else { - endpoint.receive - } + // The partial function to call + val pf = if (needReply) endpoint.receiveAndReply(context) else endpoint.receive try { pf.applyOrElse[Any, Unit](content, { msg => - throw new SparkException(s"Unmatched message $message from ${_sender}") + throw new SparkException(s"Unsupported message $message from ${_sender}") }) if (!needReply) { context.finish() @@ -121,11 +115,13 @@ private[netty] class Inbox( context.sendFailure(e) } else { context.finish() - throw e } + // Throw the exception -- this exception will be caught by the safelyCall function. + // The endpoint's onError function will be called. + throw e } - case OnStart => { + case OnStart => endpoint.onStart() if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) { inbox.synchronized { @@ -134,24 +130,22 @@ private[netty] class Inbox( } } } - } case OnStop => - val _workCount = inbox.synchronized { - workerCount - } - assert(_workCount == 1, s"There should be only one worker but was ${_workCount}") + val activeThreads = inbox.synchronized { inbox.numActiveThreads } + assert(activeThreads == 1, + s"There should be only a single active thread but found $activeThreads threads.") dispatcher.removeRpcEndpointRef(endpoint) endpoint.onStop() assert(isEmpty, "OnStop should be the last message") - case Associated(remoteAddress) => + case RemoteProcessConnected(remoteAddress) => endpoint.onConnected(remoteAddress) - case Disassociated(remoteAddress) => + case RemoteProcessDisconnected(remoteAddress) => endpoint.onDisconnected(remoteAddress) - case AssociationError(cause, remoteAddress) => + case RemoteProcessConnectionError(cause, remoteAddress) => endpoint.onNetworkError(cause, remoteAddress) } } @@ -159,33 +153,27 @@ private[netty] class Inbox( inbox.synchronized { // "enableConcurrent" will be set to false after `onStop` is called, so we should check it // every time. - if (!enableConcurrent && workerCount != 1) { + if (!enableConcurrent && numActiveThreads != 1) { // If we are not the only one worker, exit - workerCount -= 1 + numActiveThreads -= 1 return } message = messages.poll() if (message == null) { - workerCount -= 1 + numActiveThreads -= 1 return } } } } - def post(message: InboxMessage): Unit = { - val dropped = - inbox.synchronized { - if (stopped) { - // We already put "OnStop" into "messages", so we should drop further messages - true - } else { - messages.add(message) - false - } - } - if (dropped) { + def post(message: InboxMessage): Unit = inbox.synchronized { + if (stopped) { + // We already put "OnStop" into "messages", so we should drop further messages onDrop(message) + } else { + messages.add(message) + false } } @@ -203,24 +191,23 @@ private[netty] class Inbox( } } - // Visible for testing. + def isEmpty: Boolean = inbox.synchronized { messages.isEmpty } + + /** Called when we are dropping a message. Test cases override this to test message dropping. */ + @VisibleForTesting protected def onDrop(message: InboxMessage): Unit = { - logWarning(s"Drop ${message} because $endpointRef is stopped") + logWarning(s"Drop $message because $endpointRef is stopped") } - def isEmpty: Boolean = inbox.synchronized { messages.isEmpty } - + /** + * Calls action closure, and calls the endpoint's onError function in the case of exceptions. + */ private def safelyCall(endpoint: RpcEndpoint)(action: => Unit): Unit = { - try { - action - } catch { - case NonFatal(e) => { - try { - endpoint.onError(e) - } catch { - case NonFatal(e) => logWarning(s"Ignore error", e) + try action catch { + case NonFatal(e) => + try endpoint.onError(e) catch { + case NonFatal(ee) => logError(s"Ignoring error", ee) } - } } } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala index 75dcc02a0c5a9..21d5bb4923d1b 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala @@ -26,7 +26,8 @@ import org.apache.spark.rpc.{RpcAddress, RpcCallContext} private[netty] abstract class NettyRpcCallContext( endpointRef: NettyRpcEndpointRef, override val senderAddress: RpcAddress, - needReply: Boolean) extends RpcCallContext with Logging { + needReply: Boolean) + extends RpcCallContext with Logging { protected def send(message: Any): Unit @@ -35,7 +36,7 @@ private[netty] abstract class NettyRpcCallContext( send(AskResponse(endpointRef, response)) } else { throw new IllegalStateException( - s"Cannot send $response to the sender because the sender won't handle it") + s"Cannot send $response to the sender because the sender does not expect a reply") } } @@ -63,7 +64,8 @@ private[netty] class LocalNettyRpcCallContext( endpointRef: NettyRpcEndpointRef, senderAddress: RpcAddress, needReply: Boolean, - p: Promise[Any]) extends NettyRpcCallContext(endpointRef, senderAddress, needReply) { + p: Promise[Any]) + extends NettyRpcCallContext(endpointRef, senderAddress, needReply) { override protected def send(message: Any): Unit = { p.success(message) @@ -78,7 +80,8 @@ private[netty] class RemoteNettyRpcCallContext( endpointRef: NettyRpcEndpointRef, callback: RpcResponseCallback, senderAddress: RpcAddress, - needReply: Boolean) extends NettyRpcCallContext(endpointRef, senderAddress, needReply) { + needReply: Boolean) + extends NettyRpcCallContext(endpointRef, senderAddress, needReply) { override protected def send(message: Any): Unit = { val reply = nettyEnv.serialize(message) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 5522b40782d9e..89b6df76c2707 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -19,7 +19,6 @@ package org.apache.spark.rpc.netty import java.io._ import java.net.{InetSocketAddress, URI} import java.nio.ByteBuffer -import java.util.Arrays import java.util.concurrent._ import javax.annotation.concurrent.GuardedBy @@ -77,19 +76,19 @@ private[netty] class NettyRpcEnv( @volatile private var server: TransportServer = _ def start(port: Int): Unit = { - val bootstraps: Seq[TransportServerBootstrap] = + val bootstraps: java.util.List[TransportServerBootstrap] = if (securityManager.isAuthenticationEnabled()) { - Seq(new SaslServerBootstrap(transportConf, securityManager)) + java.util.Arrays.asList(new SaslServerBootstrap(transportConf, securityManager)) } else { - Nil + java.util.Collections.emptyList() } - server = transportContext.createServer(port, bootstraps.asJava) + server = transportContext.createServer(port, bootstraps) dispatcher.registerRpcEndpoint(IDVerifier.NAME, new IDVerifier(this, dispatcher)) } override lazy val address: RpcAddress = { require(server != null, "NettyRpcEnv has not yet started") - RpcAddress(host, server.getPort()) + RpcAddress(host, server.getPort) } override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = { @@ -119,7 +118,7 @@ private[netty] class NettyRpcEnv( val remoteAddr = message.receiver.address if (remoteAddr == address) { val promise = Promise[Any]() - dispatcher.postMessage(message, promise) + dispatcher.postLocalMessage(message, promise) promise.future.onComplete { case Success(response) => val ack = response.asInstanceOf[Ack] @@ -148,10 +147,9 @@ private[netty] class NettyRpcEnv( } }) } catch { - case e: RejectedExecutionException => { + case e: RejectedExecutionException => // `send` after shutting clientConnectionExecutor down, ignore it - logWarning(s"Cannot send ${message} because RpcEnv is stopped") - } + logWarning(s"Cannot send $message because RpcEnv is stopped") } } } @@ -161,7 +159,7 @@ private[netty] class NettyRpcEnv( val remoteAddr = message.receiver.address if (remoteAddr == address) { val p = Promise[Any]() - dispatcher.postMessage(message, p) + dispatcher.postLocalMessage(message, p) p.future.onComplete { case Success(response) => val reply = response.asInstanceOf[AskResponse] @@ -218,7 +216,7 @@ private[netty] class NettyRpcEnv( private[netty] def serialize(content: Any): Array[Byte] = { val buffer = javaSerializerInstance.serialize(content) - Arrays.copyOfRange( + java.util.Arrays.copyOfRange( buffer.array(), buffer.arrayOffset + buffer.position, buffer.arrayOffset + buffer.limit) } @@ -425,7 +423,7 @@ private[netty] class NettyRpcHandler( assert(addr != null) val remoteEnvAddress = requestMessage.senderAddress val clientAddr = RpcAddress(addr.getHostName, addr.getPort) - val broadcastMessage = + val broadcastMessage: Option[RemoteProcessConnected] = synchronized { // If the first connection to a remote RpcEnv is found, we should broadcast "Associated" if (remoteAddresses.put(clientAddr, remoteEnvAddress).isEmpty) { @@ -435,7 +433,7 @@ private[netty] class NettyRpcHandler( remoteConnectionCount.put(remoteEnvAddress, count + 1) if (count == 0) { // This is the first connection, so fire "Associated" - Some(Associated(remoteEnvAddress)) + Some(RemoteProcessConnected(remoteEnvAddress)) } else { None } @@ -443,8 +441,8 @@ private[netty] class NettyRpcHandler( None } } - broadcastMessage.foreach(dispatcher.broadcastMessage) - dispatcher.postMessage(requestMessage, callback) + broadcastMessage.foreach(dispatcher.postToAll) + dispatcher.postRemoteMessage(requestMessage, callback) } override def getStreamManager: StreamManager = new OneForOneStreamManager @@ -455,12 +453,12 @@ private[netty] class NettyRpcHandler( val clientAddr = RpcAddress(addr.getHostName, addr.getPort) val broadcastMessage = synchronized { - remoteAddresses.get(clientAddr).map(AssociationError(cause, _)) + remoteAddresses.get(clientAddr).map(RemoteProcessConnectionError(cause, _)) } if (broadcastMessage.isEmpty) { logError(cause.getMessage, cause) } else { - dispatcher.broadcastMessage(broadcastMessage.get) + dispatcher.postToAll(broadcastMessage.get) } } else { // If the channel is closed before connecting, its remoteAddress will be null. @@ -485,7 +483,7 @@ private[netty] class NettyRpcHandler( if (count - 1 == 0) { // We lost all clients, so clean up and fire "Disassociated" remoteConnectionCount.remove(remoteEnvAddress) - Some(Disassociated(remoteEnvAddress)) + Some(RemoteProcessDisconnected(remoteEnvAddress)) } else { // Decrease the connection number of remoteEnvAddress remoteConnectionCount.put(remoteEnvAddress, count - 1) @@ -493,7 +491,7 @@ private[netty] class NettyRpcHandler( } } } - broadcastMessage.foreach(dispatcher.broadcastMessage) + broadcastMessage.foreach(dispatcher.postToAll) } else { // If the channel is closed before connecting, its remoteAddress will be null. In this case, // we can ignore it since we don't fire "Associated". diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala index 1ed098379e299..15e7519d708c6 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -15,7 +15,6 @@ * limitations under the License. */ - package org.apache.spark.util import java.util.concurrent._ diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index e60c1b355a73e..bd7e51c3b5100 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1895,6 +1895,7 @@ private[spark] object Utils extends Logging { * This is expected to throw java.net.BindException on port collision. * @param conf A SparkConf used to get the maximum number of retries when binding to a port. * @param serviceName Name of the service. + * @return (service: T, port: Int) */ def startServiceOnPort[T]( startPort: Int, diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala index 120cf1b6fa9dc..276c077b3d13e 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala @@ -113,7 +113,7 @@ class InboxSuite extends SparkFunSuite { val remoteAddress = RpcAddress("localhost", 11111) val inbox = new Inbox(endpointRef, endpoint) - inbox.post(Associated(remoteAddress)) + inbox.post(RemoteProcessConnected(remoteAddress)) inbox.process(dispatcher) endpoint.verifySingleOnConnectedMessage(remoteAddress) @@ -127,7 +127,7 @@ class InboxSuite extends SparkFunSuite { val remoteAddress = RpcAddress("localhost", 11111) val inbox = new Inbox(endpointRef, endpoint) - inbox.post(Disassociated(remoteAddress)) + inbox.post(RemoteProcessDisconnected(remoteAddress)) inbox.process(dispatcher) endpoint.verifySingleOnDisconnectedMessage(remoteAddress) @@ -142,7 +142,7 @@ class InboxSuite extends SparkFunSuite { val cause = new RuntimeException("Oops") val inbox = new Inbox(endpointRef, endpoint) - inbox.post(AssociationError(cause, remoteAddress)) + inbox.post(RemoteProcessConnectionError(cause, remoteAddress)) inbox.process(dispatcher) endpoint.verifySingleOnNetworkErrorMessage(cause, remoteAddress) diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala index 06ca035d199e8..f24f78b8c4542 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala @@ -45,7 +45,7 @@ class NettyRpcHandlerSuite extends SparkFunSuite { when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40001)) nettyRpcHandler.receive(client, null, null) - verify(dispatcher, times(1)).broadcastMessage(Associated(RpcAddress("localhost", 12345))) + verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 12345))) } test("connectionTerminated") { @@ -60,8 +60,9 @@ class NettyRpcHandlerSuite extends SparkFunSuite { when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000)) nettyRpcHandler.connectionTerminated(client) - verify(dispatcher, times(1)).broadcastMessage(Associated(RpcAddress("localhost", 12345))) - verify(dispatcher, times(1)).broadcastMessage(Disassociated(RpcAddress("localhost", 12345))) + verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 12345))) + verify(dispatcher, times(1)).postToAll( + RemoteProcessDisconnected(RpcAddress("localhost", 12345))) } } From d0cc79ccd0b4500bd6b18184a723dabc164e8abd Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 13 Oct 2015 09:57:53 -0700 Subject: [PATCH 033/139] [SPARK-11030] [SQL] share the SQLTab across sessions The SQLTab will be shared by multiple sessions. If we create multiple independent SQLContexts (not using newSession()), will still see multiple SQLTabs in the Spark UI. Author: Davies Liu Closes #9048 from davies/sqlui. --- .../org/apache/spark/sql/SQLContext.scala | 23 +++++++++++++------ .../spark/sql/execution/ui/SQLListener.scala | 10 +++----- .../spark/sql/execution/ui/SQLTab.scala | 4 +--- .../sql/execution/ui/SQLListenerSuite.scala | 8 +++---- .../apache/spark/sql/hive/HiveContext.scala | 12 +++++++--- 5 files changed, 33 insertions(+), 24 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 1bd291389241a..cd937257d31a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -65,12 +65,15 @@ import org.apache.spark.util.Utils class SQLContext private[sql]( @transient val sparkContext: SparkContext, @transient protected[sql] val cacheManager: CacheManager, + @transient private[sql] val listener: SQLListener, val isRootContext: Boolean) extends org.apache.spark.Logging with Serializable { self => - def this(sparkContext: SparkContext) = this(sparkContext, new CacheManager, true) + def this(sparkContext: SparkContext) = { + this(sparkContext, new CacheManager, SQLContext.createListenerAndUI(sparkContext), true) + } def this(sparkContext: JavaSparkContext) = this(sparkContext.sc) // If spark.sql.allowMultipleContexts is true, we will throw an exception if a user @@ -97,7 +100,7 @@ class SQLContext private[sql]( /** * Returns a SQLContext as new session, with separated SQL configurations, temporary tables, - * registered functions, but sharing the same SparkContext and CacheManager. + * registered functions, but sharing the same SparkContext, CacheManager, SQLListener and SQLTab. * * @since 1.6.0 */ @@ -105,6 +108,7 @@ class SQLContext private[sql]( new SQLContext( sparkContext = sparkContext, cacheManager = cacheManager, + listener = listener, isRootContext = false) } @@ -113,11 +117,6 @@ class SQLContext private[sql]( */ protected[sql] lazy val conf = new SQLConf - // `listener` should be only used in the driver - @transient private[sql] val listener = new SQLListener(this) - sparkContext.addSparkListener(listener) - sparkContext.ui.foreach(new SQLTab(this, _)) - /** * Set Spark SQL configuration properties. * @@ -1312,4 +1311,14 @@ object SQLContext { ): InternalRow } } + + /** + * Create a SQLListener then add it into SparkContext, and create an SQLTab if there is SparkUI. + */ + private[sql] def createListenerAndUI(sc: SparkContext): SQLListener = { + val listener = new SQLListener(sc.conf) + sc.addSparkListener(listener) + sc.ui.foreach(new SQLTab(listener, _)) + listener + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index 5779c71f64e9e..d6472400a6a21 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -19,19 +19,15 @@ package org.apache.spark.sql.execution.ui import scala.collection.mutable -import com.google.common.annotations.VisibleForTesting - -import org.apache.spark.{JobExecutionStatus, Logging} import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ -import org.apache.spark.sql.SQLContext import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.metric.{SQLMetricParam, SQLMetricValue} +import org.apache.spark.{JobExecutionStatus, Logging, SparkConf} -private[sql] class SQLListener(sqlContext: SQLContext) extends SparkListener with Logging { +private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Logging { - private val retainedExecutions = - sqlContext.sparkContext.conf.getInt("spark.sql.ui.retainedExecutions", 1000) + private val retainedExecutions = conf.getInt("spark.sql.ui.retainedExecutions", 1000) private val activeExecutions = mutable.HashMap[Long, SQLExecutionUIData]() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala index 0b0867f67eb6e..9c27944d42fc6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala @@ -20,14 +20,12 @@ package org.apache.spark.sql.execution.ui import java.util.concurrent.atomic.AtomicInteger import org.apache.spark.Logging -import org.apache.spark.sql.SQLContext import org.apache.spark.ui.{SparkUI, SparkUITab} -private[sql] class SQLTab(sqlContext: SQLContext, sparkUI: SparkUI) +private[sql] class SQLTab(val listener: SQLListener, sparkUI: SparkUI) extends SparkUITab(sparkUI, SQLTab.nextTabName) with Logging { val parent = sparkUI - val listener = sqlContext.listener attachPage(new AllExecutionsPage(this)) attachPage(new ExecutionPage(this)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index 7a46c69a056b1..727cf3665a871 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -74,7 +74,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { } test("basic") { - val listener = new SQLListener(sqlContext) + val listener = new SQLListener(sqlContext.sparkContext.conf) val executionId = 0 val df = createTestDataFrame val accumulatorIds = @@ -212,7 +212,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { } test("onExecutionEnd happens before onJobEnd(JobSucceeded)") { - val listener = new SQLListener(sqlContext) + val listener = new SQLListener(sqlContext.sparkContext.conf) val executionId = 0 val df = createTestDataFrame listener.onExecutionStart( @@ -241,7 +241,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { } test("onExecutionEnd happens before multiple onJobEnd(JobSucceeded)s") { - val listener = new SQLListener(sqlContext) + val listener = new SQLListener(sqlContext.sparkContext.conf) val executionId = 0 val df = createTestDataFrame listener.onExecutionStart( @@ -281,7 +281,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { } test("onExecutionEnd happens before onJobEnd(JobFailed)") { - val listener = new SQLListener(sqlContext) + val listener = new SQLListener(sqlContext.sparkContext.conf) val executionId = 0 val df = createTestDataFrame listener.onExecutionStart( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index ddeadd3eb737d..e620d7fb82af9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -40,12 +40,13 @@ import org.apache.spark.api.java.JavaSparkContext import org.apache.spark.sql.SQLConf.SQLConfEntry import org.apache.spark.sql.SQLConf.SQLConfEntry._ import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression} -import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, SqlParser} import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PreInsertCastAndRename, PreWriteCheck} +import org.apache.spark.sql.execution.ui.SQLListener import org.apache.spark.sql.execution.{CacheManager, ExecutedCommand, ExtractPythonUDFs, SetCommand} import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNativeCommand} @@ -88,12 +89,16 @@ private[hive] case class CurrentDatabase(ctx: HiveContext) class HiveContext private[hive]( sc: SparkContext, cacheManager: CacheManager, + @transient listener: SQLListener, @transient execHive: ClientWrapper, @transient metaHive: ClientInterface, - isRootContext: Boolean) extends SQLContext(sc, cacheManager, isRootContext) with Logging { + isRootContext: Boolean) + extends SQLContext(sc, cacheManager, listener, isRootContext) with Logging { self => - def this(sc: SparkContext) = this(sc, new CacheManager, null, null, true) + def this(sc: SparkContext) = { + this(sc, new CacheManager, SQLContext.createListenerAndUI(sc), null, null, true) + } def this(sc: JavaSparkContext) = this(sc.sc) import org.apache.spark.sql.hive.HiveContext._ @@ -109,6 +114,7 @@ class HiveContext private[hive]( new HiveContext( sc = sc, cacheManager = cacheManager, + listener = listener, execHive = executionHive.newSession(), metaHive = metadataHive.newSession(), isRootContext = false) From 5e3868ba139f5f0b3a33361c6b884594a3ab6421 Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Tue, 13 Oct 2015 10:02:21 -0700 Subject: [PATCH 034/139] [SPARK-10051] [SPARKR] Support collecting data of StructType in DataFrame Two points in this PR: 1. Originally thought was that a named R list is assumed to be a struct in SerDe. But this is problematic because some R functions will implicitly generate named lists that are not intended to be a struct when transferred by SerDe. So SerDe clients have to explicitly mark a names list as struct by changing its class from "list" to "struct". 2. SerDe is in the Spark Core module, and data of StructType is represented as GenricRow which is defined in Spark SQL module. SerDe can't import GenricRow as in maven build Spark SQL module depends on Spark Core module. So this PR adds a registration hook in SerDe to allow SQLUtils in Spark SQL module to register its functions for serialization and deserialization of StructType. Author: Sun Rui Closes #8794 from sun-rui/SPARK-10051. --- R/pkg/R/SQLContext.R | 22 +++--- R/pkg/R/deserialize.R | 10 +++ R/pkg/R/schema.R | 28 +++++++- R/pkg/R/serialize.R | 43 +++++++---- R/pkg/R/sparkR.R | 4 +- R/pkg/R/utils.R | 17 +++++ R/pkg/inst/tests/test_sparkSQL.R | 51 +++++++------ .../scala/org/apache/spark/api/r/SerDe.scala | 71 ++++++++++++++----- .../org/apache/spark/sql/api/r/SQLUtils.scala | 47 ++++++++++-- 9 files changed, 224 insertions(+), 69 deletions(-) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 1c58fd96d750a..66c7e307212c3 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -32,6 +32,7 @@ infer_type <- function(x) { numeric = "double", raw = "binary", list = "array", + struct = "struct", environment = "map", Date = "date", POSIXlt = "timestamp", @@ -44,17 +45,18 @@ infer_type <- function(x) { paste0("map") } else if (type == "array") { stopifnot(length(x) > 0) + + paste0("array<", infer_type(x[[1]]), ">") + } else if (type == "struct") { + stopifnot(length(x) > 0) names <- names(x) - if (is.null(names)) { - paste0("array<", infer_type(x[[1]]), ">") - } else { - # StructType - types <- lapply(x, infer_type) - fields <- lapply(1:length(x), function(i) { - structField(names[[i]], types[[i]], TRUE) - }) - do.call(structType, fields) - } + stopifnot(!is.null(names)) + + type <- lapply(seq_along(x), function(i) { + paste0(names[[i]], ":", infer_type(x[[i]]), ",") + }) + type <- Reduce(paste0, type) + type <- paste0("struct<", substr(type, 1, nchar(type) - 1), ">") } else if (length(x) > 1) { paste0("array<", infer_type(x[[1]]), ">") } else { diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R index ce88d0b071b72..f7e56e43016ea 100644 --- a/R/pkg/R/deserialize.R +++ b/R/pkg/R/deserialize.R @@ -51,6 +51,7 @@ readTypedObject <- function(con, type) { "a" = readArray(con), "l" = readList(con), "e" = readEnv(con), + "s" = readStruct(con), "n" = NULL, "j" = getJobj(readString(con)), stop(paste("Unsupported type for deserialization", type))) @@ -135,6 +136,15 @@ readEnv <- function(con) { env } +# Read a field of StructType from DataFrame +# into a named list in R whose class is "struct" +readStruct <- function(con) { + names <- readObject(con) + fields <- readObject(con) + names(fields) <- names + listToStruct(fields) +} + readRaw <- function(con) { dataLen <- readInt(con) readBin(con, raw(), as.integer(dataLen), endian = "big") diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R index 8df1563f8ebc0..6f0e9a94e9bfa 100644 --- a/R/pkg/R/schema.R +++ b/R/pkg/R/schema.R @@ -136,7 +136,7 @@ checkType <- function(type) { switch (firstChar, a = { # Array type - m <- regexec("^array<(.*)>$", type) + m <- regexec("^array<(.+)>$", type) matchedStrings <- regmatches(type, m) if (length(matchedStrings[[1]]) >= 2) { elemType <- matchedStrings[[1]][2] @@ -146,7 +146,7 @@ checkType <- function(type) { }, m = { # Map type - m <- regexec("^map<(.*),(.*)>$", type) + m <- regexec("^map<(.+),(.+)>$", type) matchedStrings <- regmatches(type, m) if (length(matchedStrings[[1]]) >= 3) { keyType <- matchedStrings[[1]][2] @@ -157,6 +157,30 @@ checkType <- function(type) { checkType(valueType) return() } + }, + s = { + # Struct type + m <- regexec("^struct<(.+)>$", type) + matchedStrings <- regmatches(type, m) + if (length(matchedStrings[[1]]) >= 2) { + fieldsString <- matchedStrings[[1]][2] + # strsplit does not return the final empty string, so check if + # the final char is "," + if (substr(fieldsString, nchar(fieldsString), nchar(fieldsString)) != ",") { + fields <- strsplit(fieldsString, ",")[[1]] + for (field in fields) { + m <- regexec("^(.+):(.+)$", field) + matchedStrings <- regmatches(field, m) + if (length(matchedStrings[[1]]) >= 3) { + fieldType <- matchedStrings[[1]][3] + checkType(fieldType) + } else { + break + } + } + return() + } + } }) } diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R index 91e6b3e5609b5..17082b4e52fcf 100644 --- a/R/pkg/R/serialize.R +++ b/R/pkg/R/serialize.R @@ -32,6 +32,21 @@ # environment -> Map[String, T], where T is a native type # jobj -> Object, where jobj is an object created in the backend +getSerdeType <- function(object) { + type <- class(object)[[1]] + if (type != "list") { + type + } else { + # Check if all elements are of same type + elemType <- unique(sapply(object, function(elem) { getSerdeType(elem) })) + if (length(elemType) <= 1) { + "array" + } else { + "list" + } + } +} + writeObject <- function(con, object, writeType = TRUE) { # NOTE: In R vectors have same type as objects. So we don't support # passing in vectors as arrays and instead require arrays to be passed @@ -45,10 +60,12 @@ writeObject <- function(con, object, writeType = TRUE) { type <- "NULL" } } + + serdeType <- getSerdeType(object) if (writeType) { - writeType(con, type) + writeType(con, serdeType) } - switch(type, + switch(serdeType, NULL = writeVoid(con), integer = writeInt(con, object), character = writeString(con, object), @@ -56,7 +73,9 @@ writeObject <- function(con, object, writeType = TRUE) { double = writeDouble(con, object), numeric = writeDouble(con, object), raw = writeRaw(con, object), + array = writeArray(con, object), list = writeList(con, object), + struct = writeList(con, object), jobj = writeJobj(con, object), environment = writeEnv(con, object), Date = writeDate(con, object), @@ -110,7 +129,7 @@ writeRowSerialize <- function(outputCon, rows) { serializeRow <- function(row) { rawObj <- rawConnection(raw(0), "wb") on.exit(close(rawObj)) - writeGenericList(rawObj, row) + writeList(rawObj, row) rawConnectionValue(rawObj) } @@ -128,7 +147,9 @@ writeType <- function(con, class) { double = "d", numeric = "d", raw = "r", + array = "a", list = "l", + struct = "s", jobj = "j", environment = "e", Date = "D", @@ -139,15 +160,13 @@ writeType <- function(con, class) { } # Used to pass arrays where all the elements are of the same type -writeList <- function(con, arr) { - # All elements should be of same type - elemType <- unique(sapply(arr, function(elem) { class(elem) })) - stopifnot(length(elemType) <= 1) - +writeArray <- function(con, arr) { # TODO: Empty lists are given type "character" right now. # This may not work if the Java side expects array of any other type. - if (length(elemType) == 0) { + if (length(arr) == 0) { elemType <- class("somestring") + } else { + elemType <- getSerdeType(arr[[1]]) } writeType(con, elemType) @@ -161,7 +180,7 @@ writeList <- function(con, arr) { } # Used to pass arrays where the elements can be of different types -writeGenericList <- function(con, list) { +writeList <- function(con, list) { writeInt(con, length(list)) for (elem in list) { writeObject(con, elem) @@ -174,9 +193,9 @@ writeEnv <- function(con, env) { writeInt(con, len) if (len > 0) { - writeList(con, as.list(ls(env))) + writeArray(con, as.list(ls(env))) vals <- lapply(ls(env), function(x) { env[[x]] }) - writeGenericList(con, as.list(vals)) + writeList(con, as.list(vals)) } } diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 3c57a44db257d..cc47110f54732 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -178,7 +178,7 @@ sparkR.init <- function( } nonEmptyJars <- Filter(function(x) { x != "" }, jars) - localJarPaths <- sapply(nonEmptyJars, + localJarPaths <- lapply(nonEmptyJars, function(j) { utils::URLencode(paste("file:", uriSep, j, sep = "")) }) # Set the start time to identify jobjs @@ -193,7 +193,7 @@ sparkR.init <- function( master, appName, as.character(sparkHome), - as.list(localJarPaths), + localJarPaths, sparkEnvirMap, sparkExecutorEnvMap), envir = .sparkREnv diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 69a2bc728f842..94f16c7ac52cc 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -588,3 +588,20 @@ mergePartitions <- function(rdd, zip) { PipelinedRDD(rdd, partitionFunc) } + +# Convert a named list to struct so that +# SerDe won't confuse between a normal named list and struct +listToStruct <- function(list) { + stopifnot(class(list) == "list") + stopifnot(!is.null(names(list))) + class(list) <- "struct" + list +} + +# Convert a struct to a named list +structToList <- function(struct) { + stopifnot(class(list) == "struct") + + class(struct) <- "list" + struct +} diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 3a04edbb4c116..af6efa40fb2f6 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -66,10 +66,7 @@ test_that("infer types and check types", { expect_equal(infer_type(as.POSIXlt("2015-03-11 12:13:04.043")), "timestamp") expect_equal(infer_type(c(1L, 2L)), "array") expect_equal(infer_type(list(1L, 2L)), "array") - testStruct <- infer_type(list(a = 1L, b = "2")) - expect_equal(class(testStruct), "structType") - checkStructField(testStruct$fields()[[1]], "a", "IntegerType", TRUE) - checkStructField(testStruct$fields()[[2]], "b", "StringType", TRUE) + expect_equal(infer_type(listToStruct(list(a = 1L, b = "2"))), "struct") e <- new.env() assign("a", 1L, envir = e) expect_equal(infer_type(e), "map") @@ -242,38 +239,36 @@ test_that("create DataFrame with different data types", { expect_equal(collect(df), data.frame(l, stringsAsFactors = FALSE)) }) -test_that("create DataFrame with nested array and map", { -# e <- new.env() -# assign("n", 3L, envir = e) -# l <- list(1:10, list("a", "b"), e, list(a="aa", b=3L)) -# df <- createDataFrame(sqlContext, list(l), c("a", "b", "c", "d")) -# expect_equal(dtypes(df), list(c("a", "array"), c("b", "array"), -# c("c", "map"), c("d", "struct"))) -# expect_equal(count(df), 1) -# ldf <- collect(df) -# expect_equal(ldf[1,], l[[1]]) - - # ArrayType and MapType +test_that("create DataFrame with complex types", { e <- new.env() assign("n", 3L, envir = e) - l <- list(as.list(1:10), list("a", "b"), e) - df <- createDataFrame(sqlContext, list(l), c("a", "b", "c")) + s <- listToStruct(list(a = "aa", b = 3L)) + + l <- list(as.list(1:10), list("a", "b"), e, s) + df <- createDataFrame(sqlContext, list(l), c("a", "b", "c", "d")) expect_equal(dtypes(df), list(c("a", "array"), c("b", "array"), - c("c", "map"))) + c("c", "map"), + c("d", "struct"))) expect_equal(count(df), 1) ldf <- collect(df) - expect_equal(names(ldf), c("a", "b", "c")) + expect_equal(names(ldf), c("a", "b", "c", "d")) expect_equal(ldf[1, 1][[1]], l[[1]]) expect_equal(ldf[1, 2][[1]], l[[2]]) + e <- ldf$c[[1]] expect_equal(class(e), "environment") expect_equal(ls(e), "n") expect_equal(e$n, 3L) + + s <- ldf$d[[1]] + expect_equal(class(s), "struct") + expect_equal(s$a, "aa") + expect_equal(s$b, 3L) }) -# For test map type in DataFrame +# For test map type and struct type in DataFrame mockLinesMapType <- c("{\"name\":\"Bob\",\"info\":{\"age\":16,\"height\":176.5}}", "{\"name\":\"Alice\",\"info\":{\"age\":20,\"height\":164.3}}", "{\"name\":\"David\",\"info\":{\"age\":60,\"height\":180}}") @@ -308,7 +303,19 @@ test_that("Collect DataFrame with complex types", { expect_equal(bob$age, 16) expect_equal(bob$height, 176.5) - # TODO: tests for StructType after it is supported + # StructType + df <- jsonFile(sqlContext, mapTypeJsonPath) + expect_equal(dtypes(df), list(c("info", "struct"), + c("name", "string"))) + ldf <- collect(df) + expect_equal(nrow(ldf), 3) + expect_equal(ncol(ldf), 2) + expect_equal(names(ldf), c("info", "name")) + expect_equal(ldf$name, c("Bob", "Alice", "David")) + bob <- ldf$info[[1]] + expect_equal(class(bob), "struct") + expect_equal(bob$age, 16) + expect_equal(bob$height, 176.5) }) test_that("jsonFile() on a local file returns a DataFrame", { diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index 0c78613e406e1..da126bac7ad1f 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -27,6 +27,14 @@ import scala.collection.mutable.WrappedArray * Utility functions to serialize, deserialize objects to / from R */ private[spark] object SerDe { + type ReadObject = (DataInputStream, Char) => Object + type WriteObject = (DataOutputStream, Object) => Boolean + + var sqlSerDe: (ReadObject, WriteObject) = _ + + def registerSqlSerDe(sqlSerDe: (ReadObject, WriteObject)): Unit = { + this.sqlSerDe = sqlSerDe + } // Type mapping from R to Java // @@ -63,11 +71,22 @@ private[spark] object SerDe { case 'c' => readString(dis) case 'e' => readMap(dis) case 'r' => readBytes(dis) + case 'a' => readArray(dis) case 'l' => readList(dis) case 'D' => readDate(dis) case 't' => readTime(dis) case 'j' => JVMObjectTracker.getObject(readString(dis)) - case _ => throw new IllegalArgumentException(s"Invalid type $dataType") + case _ => + if (sqlSerDe == null || sqlSerDe._1 == null) { + throw new IllegalArgumentException (s"Invalid type $dataType") + } else { + val obj = (sqlSerDe._1)(dis, dataType) + if (obj == null) { + throw new IllegalArgumentException (s"Invalid type $dataType") + } else { + obj + } + } } } @@ -141,7 +160,8 @@ private[spark] object SerDe { (0 until len).map(_ => readString(in)).toArray } - def readList(dis: DataInputStream): Array[_] = { + // All elements of an array must be of the same type + def readArray(dis: DataInputStream): Array[_] = { val arrType = readObjectType(dis) arrType match { case 'i' => readIntArr(dis) @@ -150,26 +170,43 @@ private[spark] object SerDe { case 'b' => readBooleanArr(dis) case 'j' => readStringArr(dis).map(x => JVMObjectTracker.getObject(x)) case 'r' => readBytesArr(dis) - case 'l' => { + case 'a' => + val len = readInt(dis) + (0 until len).map(_ => readArray(dis)).toArray + case 'l' => val len = readInt(dis) (0 until len).map(_ => readList(dis)).toArray - } - case _ => throw new IllegalArgumentException(s"Invalid array type $arrType") + case _ => + if (sqlSerDe == null || sqlSerDe._1 == null) { + throw new IllegalArgumentException (s"Invalid array type $arrType") + } else { + val len = readInt(dis) + (0 until len).map { _ => + val obj = (sqlSerDe._1)(dis, arrType) + if (obj == null) { + throw new IllegalArgumentException (s"Invalid array type $arrType") + } else { + obj + } + }.toArray + } } } + // Each element of a list can be of different type. They are all represented + // as Object on JVM side + def readList(dis: DataInputStream): Array[Object] = { + val len = readInt(dis) + (0 until len).map(_ => readObject(dis)).toArray + } + def readMap(in: DataInputStream): java.util.Map[Object, Object] = { val len = readInt(in) if (len > 0) { - val keysType = readObjectType(in) - val keysLen = readInt(in) - val keys = (0 until keysLen).map(_ => readTypedObject(in, keysType)) - - val valuesLen = readInt(in) - val values = (0 until valuesLen).map(_ => { - val valueType = readObjectType(in) - readTypedObject(in, valueType) - }) + // Keys is an array of String + val keys = readArray(in).asInstanceOf[Array[Object]] + val values = readList(in) + keys.zip(values).toMap.asJava } else { new java.util.HashMap[Object, Object]() @@ -338,8 +375,10 @@ private[spark] object SerDe { } case _ => - writeType(dos, "jobj") - writeJObj(dos, value) + if (sqlSerDe == null || sqlSerDe._2 == null || !(sqlSerDe._2)(dos, value)) { + writeType(dos, "jobj") + writeJObj(dos, value) + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index f45d119c8cfdf..b0120a8d0dc4f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -22,13 +22,15 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, Da import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.api.r.SerDe import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression, GenericRowWithSchema} import org.apache.spark.sql.types._ import org.apache.spark.sql.{Column, DataFrame, GroupedData, Row, SQLContext, SaveMode} import scala.util.matching.Regex private[r] object SQLUtils { + SerDe.registerSqlSerDe((readSqlObject, writeSqlObject)) + def createSQLContext(jsc: JavaSparkContext): SQLContext = { new SQLContext(jsc) } @@ -61,15 +63,27 @@ private[r] object SQLUtils { case "boolean" => org.apache.spark.sql.types.BooleanType case "timestamp" => org.apache.spark.sql.types.TimestampType case "date" => org.apache.spark.sql.types.DateType - case r"\Aarray<(.*)${elemType}>\Z" => { + case r"\Aarray<(.+)${elemType}>\Z" => org.apache.spark.sql.types.ArrayType(getSQLDataType(elemType)) - } - case r"\Amap<(.*)${keyType},(.*)${valueType}>\Z" => { + case r"\Amap<(.+)${keyType},(.+)${valueType}>\Z" => if (keyType != "string" && keyType != "character") { throw new IllegalArgumentException("Key type of a map must be string or character") } org.apache.spark.sql.types.MapType(getSQLDataType(keyType), getSQLDataType(valueType)) - } + case r"\Astruct<(.+)${fieldsStr}>\Z" => + if (fieldsStr(fieldsStr.length - 1) == ',') { + throw new IllegalArgumentException(s"Invaid type $dataType") + } + val fields = fieldsStr.split(",") + val structFields = fields.map { field => + field match { + case r"\A(.+)${fieldName}:(.+)${fieldType}\Z" => + createStructField(fieldName, fieldType, true) + + case _ => throw new IllegalArgumentException(s"Invaid type $dataType") + } + } + createStructType(structFields) case _ => throw new IllegalArgumentException(s"Invaid type $dataType") } } @@ -151,4 +165,27 @@ private[r] object SQLUtils { options: java.util.Map[String, String]): DataFrame = { sqlContext.read.format(source).schema(schema).options(options).load() } + + def readSqlObject(dis: DataInputStream, dataType: Char): Object = { + dataType match { + case 's' => + // Read StructType for DataFrame + val fields = SerDe.readList(dis).asInstanceOf[Array[Object]] + Row.fromSeq(fields) + case _ => null + } + } + + def writeSqlObject(dos: DataOutputStream, obj: Object): Boolean = { + obj match { + // Handle struct type in DataFrame + case v: GenericRowWithSchema => + dos.writeByte('s') + SerDe.writeObject(dos, v.schema.fieldNames) + SerDe.writeObject(dos, v.values) + true + case _ => + false + } + } } From 1e0aba90b9e73834af70d196f7f869b062d98d94 Mon Sep 17 00:00:00 2001 From: Narine Kokhlikyan Date: Tue, 13 Oct 2015 10:09:05 -0700 Subject: [PATCH 035/139] [SPARK-10888] [SPARKR] Added as.DataFrame as a synonym to createDataFrame as.DataFrame is more a R-style like signature. Also, I'd like to know if we could make the context, e.g. sqlContext global, so that we do not have to specify it as an argument, when we each time create a dataframe. Author: Narine Kokhlikyan Closes #8952 from NarineK/sparkrasDataFrame. --- R/pkg/NAMESPACE | 3 ++- R/pkg/R/SQLContext.R | 17 +++++++++++++---- R/pkg/inst/tests/test_sparkSQL.R | 15 +++++++++++++++ 3 files changed, 30 insertions(+), 5 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 95d949ee3e5a4..41986a5e7ab7d 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -228,7 +228,8 @@ exportMethods("agg") export("sparkRSQL.init", "sparkRHive.init") -export("cacheTable", +export("as.DataFrame", + "cacheTable", "clearCache", "createDataFrame", "createExternalTable", diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 66c7e307212c3..399f53657a68c 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -64,21 +64,23 @@ infer_type <- function(x) { } } -#' Create a DataFrame from an RDD +#' Create a DataFrame #' -#' Converts an RDD to a DataFrame by infer the types. +#' Converts R data.frame or list into DataFrame. #' #' @param sqlContext A SQLContext #' @param data An RDD or list or data.frame #' @param schema a list of column names or named list (StructType), optional #' @return an DataFrame +#' @rdname createDataFrame #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) -#' rdd <- lapply(parallelize(sc, 1:10), function(x) list(a=x, b=as.character(x))) -#' df <- createDataFrame(sqlContext, rdd) +#' df1 <- as.DataFrame(sqlContext, iris) +#' df2 <- as.DataFrame(sqlContext, list(3,4,5,6)) +#' df3 <- createDataFrame(sqlContext, iris) #' } # TODO(davies): support sampling and infer type from NA @@ -151,6 +153,13 @@ createDataFrame <- function(sqlContext, data, schema = NULL, samplingRatio = 1.0 dataFrame(sdf) } +#' @rdname createDataFrame +#' @aliases createDataFrame +#' @export +as.DataFrame <- function(sqlContext, data, schema = NULL, samplingRatio = 1.0) { + createDataFrame(sqlContext, data, schema, samplingRatio) +} + # toDF # # Converts an RDD to a DataFrame by infer the types. diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index af6efa40fb2f6..b599994854670 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -89,17 +89,28 @@ test_that("structType and structField", { test_that("create DataFrame from RDD", { rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) df <- createDataFrame(sqlContext, rdd, list("a", "b")) + dfAsDF <- as.DataFrame(sqlContext, rdd, list("a", "b")) expect_is(df, "DataFrame") + expect_is(dfAsDF, "DataFrame") expect_equal(count(df), 10) + expect_equal(count(dfAsDF), 10) expect_equal(nrow(df), 10) + expect_equal(nrow(dfAsDF), 10) expect_equal(ncol(df), 2) + expect_equal(ncol(dfAsDF), 2) expect_equal(dim(df), c(10, 2)) + expect_equal(dim(dfAsDF), c(10, 2)) expect_equal(columns(df), c("a", "b")) + expect_equal(columns(dfAsDF), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + expect_equal(dtypes(dfAsDF), list(c("a", "int"), c("b", "string"))) df <- createDataFrame(sqlContext, rdd) + dfAsDF <- as.DataFrame(sqlContext, rdd) expect_is(df, "DataFrame") + expect_is(dfAsDF, "DataFrame") expect_equal(columns(df), c("_1", "_2")) + expect_equal(columns(dfAsDF), c("_1", "_2")) schema <- structType(structField(x = "a", type = "integer", nullable = TRUE), structField(x = "b", type = "string", nullable = TRUE)) @@ -130,9 +141,13 @@ test_that("create DataFrame from RDD", { schema <- structType(structField("name", "string"), structField("age", "integer"), structField("height", "float")) df2 <- createDataFrame(sqlContext, df.toRDD, schema) + df2AsDF <- as.DataFrame(sqlContext, df.toRDD, schema) expect_equal(columns(df2), c("name", "age", "height")) + expect_equal(columns(df2AsDF), c("name", "age", "height")) expect_equal(dtypes(df2), list(c("name", "string"), c("age", "int"), c("height", "float"))) + expect_equal(dtypes(df2AsDF), list(c("name", "string"), c("age", "int"), c("height", "float"))) expect_equal(collect(where(df2, df2$name == "Bob")), c("Bob", 16, 176.5)) + expect_equal(collect(where(df2AsDF, df2$name == "Bob")), c("Bob", 16, 176.5)) localDF <- data.frame(name=c("John", "Smith", "Sarah"), age=c(19, 23, 18), From f7f28ee7a513c262d52cf433d25fbf06df9bd1f1 Mon Sep 17 00:00:00 2001 From: Adrian Zhuang Date: Tue, 13 Oct 2015 10:21:07 -0700 Subject: [PATCH 036/139] [SPARK-10913] [SPARKR] attach() function support Bring the change code up to date. Author: Adrian Zhuang Author: adrian555 Closes #9031 from adrian555/attach2. --- R/pkg/NAMESPACE | 1 + R/pkg/R/DataFrame.R | 30 ++++++++++++++++++++++++++++++ R/pkg/R/generics.R | 4 ++++ R/pkg/inst/tests/test_sparkSQL.R | 20 ++++++++++++++++++++ 4 files changed, 55 insertions(+) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 41986a5e7ab7d..ed9cd94e03b13 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -23,6 +23,7 @@ export("setJobGroup", exportClasses("DataFrame") exportMethods("arrange", + "attach", "cache", "collect", "columns", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 1b9137e6c7934..e0ce056243585 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1881,3 +1881,33 @@ setMethod("as.data.frame", } collect(x) }) + +#' The specified DataFrame is attached to the R search path. This means that +#' the DataFrame is searched by R when evaluating a variable, so columns in +#' the DataFrame can be accessed by simply giving their names. +#' +#' @rdname attach +#' @title Attach DataFrame to R search path +#' @param what (DataFrame) The DataFrame to attach +#' @param pos (integer) Specify position in search() where to attach. +#' @param name (character) Name to use for the attached DataFrame. Names +#' starting with package: are reserved for library. +#' @param warn.conflicts (logical) If TRUE, warnings are printed about conflicts +#' from attaching the database, unless that DataFrame contains an object +#' @examples +#' \dontrun{ +#' attach(irisDf) +#' summary(Sepal_Width) +#' } +#' @seealso \link{detach} +setMethod("attach", + signature(what = "DataFrame"), + function(what, pos = 2, name = deparse(substitute(what)), warn.conflicts = TRUE) { + cols <- columns(what) + stopifnot(length(cols) > 0) + newEnv <- new.env() + for (i in 1:length(cols)) { + assign(x = cols[i], value = what[, cols[i]], envir = newEnv) + } + attach(newEnv, pos = pos, name = name, warn.conflicts = warn.conflicts) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 8fad17026c06f..c106a0024583e 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1003,3 +1003,7 @@ setGeneric("rbind", signature = "...") #' @rdname as.data.frame #' @export setGeneric("as.data.frame") + +#' @rdname attach +#' @export +setGeneric("attach") \ No newline at end of file diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index b599994854670..d5509e475de05 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -1405,6 +1405,26 @@ test_that("Method as.data.frame as a synonym for collect()", { expect_equal(as.data.frame(irisDF2), collect(irisDF2)) }) +test_that("attach() on a DataFrame", { + df <- jsonFile(sqlContext, jsonPath) + expect_error(age) + attach(df) + expect_is(age, "DataFrame") + expected_age <- data.frame(age = c(NA, 30, 19)) + expect_equal(head(age), expected_age) + stat <- summary(age) + expect_equal(collect(stat)[5, "age"], "30") + age <- age$age + 1 + expect_is(age, "Column") + rm(age) + stat2 <- summary(age) + expect_equal(collect(stat2)[5, "age"], "30") + detach("df") + stat3 <- summary(df[, "age"]) + expect_equal(collect(stat3)[5, "age"], "30") + expect_error(age) +}) + unlink(parquetPath) unlink(jsonPath) unlink(jsonPathNa) From c75f058b72d492d6de898957b3058f242d70dd8a Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 13 Oct 2015 12:03:46 -0700 Subject: [PATCH 037/139] [PYTHON] [MINOR] List modules in PySpark tests when given bad name Output list of supported modules for python tests in error message when given bad module name. CC: davies Author: Joseph K. Bradley Closes #9088 from jkbradley/python-tests-modules. --- python/run-tests.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/run-tests.py b/python/run-tests.py index fd56c7ab6e0e2..152f5cc98d0fd 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -167,7 +167,8 @@ def main(): if module_name in python_modules: modules_to_test.append(python_modules[module_name]) else: - print("Error: unrecognized module %s" % module_name) + print("Error: unrecognized module '%s'. Supported modules: %s" % + (module_name, ", ".join(python_modules))) sys.exit(-1) LOGGER.info("Will test against the following Python executables: %s", python_execs) LOGGER.info("Will test the following Python modules: %s", [x.name for x in modules_to_test]) From 2b574f52d7bf51b1fe2a73086a3735b633e9083f Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 13 Oct 2015 13:24:10 -0700 Subject: [PATCH 038/139] [SPARK-7402] [ML] JSON SerDe for standard param types This PR implements the JSON SerDe for the following param types: `Boolean`, `Int`, `Long`, `Float`, `Double`, `String`, `Array[Int]`, `Array[Double]`, and `Array[String]`. The implementation of `Float`, `Double`, and `Array[Double]` are specialized to handle `NaN` and `Inf`s. This will be used in pipeline persistence. jkbradley Author: Xiangrui Meng Closes #9090 from mengxr/SPARK-7402. --- .../org/apache/spark/ml/param/params.scala | 169 ++++++++++++++++++ .../apache/spark/ml/param/ParamsSuite.scala | 114 ++++++++++++ 2 files changed, 283 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index ec98b05e13b89..8361406f87299 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -24,6 +24,9 @@ import scala.annotation.varargs import scala.collection.mutable import scala.collection.JavaConverters._ +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.ml.util.Identifiable @@ -80,6 +83,30 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali /** Creates a param pair with the given value (for Scala). */ def ->(value: T): ParamPair[T] = ParamPair(this, value) + /** Encodes a param value into JSON, which can be decoded by [[jsonDecode()]]. */ + def jsonEncode(value: T): String = { + value match { + case x: String => + compact(render(JString(x))) + case _ => + throw new NotImplementedError( + "The default jsonEncode only supports string. " + + s"${this.getClass.getName} must override jsonEncode for ${value.getClass.getName}.") + } + } + + /** Decodes a param value from JSON. */ + def jsonDecode(json: String): T = { + parse(json) match { + case JString(x) => + x.asInstanceOf[T] + case _ => + throw new NotImplementedError( + "The default jsonDecode only supports string. " + + s"${this.getClass.getName} must override jsonDecode to support its value type.") + } + } + override final def toString: String = s"${parent}__$name" override final def hashCode: Int = toString.## @@ -198,6 +225,46 @@ class DoubleParam(parent: String, name: String, doc: String, isValid: Double => /** Creates a param pair with the given value (for Java). */ override def w(value: Double): ParamPair[Double] = super.w(value) + + override def jsonEncode(value: Double): String = { + compact(render(DoubleParam.jValueEncode(value))) + } + + override def jsonDecode(json: String): Double = { + DoubleParam.jValueDecode(parse(json)) + } +} + +private[param] object DoubleParam { + /** Encodes a param value into JValue. */ + def jValueEncode(value: Double): JValue = { + value match { + case _ if value.isNaN => + JString("NaN") + case Double.NegativeInfinity => + JString("-Inf") + case Double.PositiveInfinity => + JString("Inf") + case _ => + JDouble(value) + } + } + + /** Decodes a param value from JValue. */ + def jValueDecode(jValue: JValue): Double = { + jValue match { + case JString("NaN") => + Double.NaN + case JString("-Inf") => + Double.NegativeInfinity + case JString("Inf") => + Double.PositiveInfinity + case JDouble(x) => + x + case _ => + throw new IllegalArgumentException(s"Cannot decode $jValue to Double.") + } + } } /** @@ -218,6 +285,15 @@ class IntParam(parent: String, name: String, doc: String, isValid: Int => Boolea /** Creates a param pair with the given value (for Java). */ override def w(value: Int): ParamPair[Int] = super.w(value) + + override def jsonEncode(value: Int): String = { + compact(render(JInt(value))) + } + + override def jsonDecode(json: String): Int = { + implicit val formats = DefaultFormats + parse(json).extract[Int] + } } /** @@ -238,6 +314,47 @@ class FloatParam(parent: String, name: String, doc: String, isValid: Float => Bo /** Creates a param pair with the given value (for Java). */ override def w(value: Float): ParamPair[Float] = super.w(value) + + override def jsonEncode(value: Float): String = { + compact(render(FloatParam.jValueEncode(value))) + } + + override def jsonDecode(json: String): Float = { + FloatParam.jValueDecode(parse(json)) + } +} + +private object FloatParam { + + /** Encodes a param value into JValue. */ + def jValueEncode(value: Float): JValue = { + value match { + case _ if value.isNaN => + JString("NaN") + case Float.NegativeInfinity => + JString("-Inf") + case Float.PositiveInfinity => + JString("Inf") + case _ => + JDouble(value) + } + } + + /** Decodes a param value from JValue. */ + def jValueDecode(jValue: JValue): Float = { + jValue match { + case JString("NaN") => + Float.NaN + case JString("-Inf") => + Float.NegativeInfinity + case JString("Inf") => + Float.PositiveInfinity + case JDouble(x) => + x.toFloat + case _ => + throw new IllegalArgumentException(s"Cannot decode $jValue to Float.") + } + } } /** @@ -258,6 +375,15 @@ class LongParam(parent: String, name: String, doc: String, isValid: Long => Bool /** Creates a param pair with the given value (for Java). */ override def w(value: Long): ParamPair[Long] = super.w(value) + + override def jsonEncode(value: Long): String = { + compact(render(JInt(value))) + } + + override def jsonDecode(json: String): Long = { + implicit val formats = DefaultFormats + parse(json).extract[Long] + } } /** @@ -272,6 +398,15 @@ class BooleanParam(parent: String, name: String, doc: String) // No need for isV /** Creates a param pair with the given value (for Java). */ override def w(value: Boolean): ParamPair[Boolean] = super.w(value) + + override def jsonEncode(value: Boolean): String = { + compact(render(JBool(value))) + } + + override def jsonDecode(json: String): Boolean = { + implicit val formats = DefaultFormats + parse(json).extract[Boolean] + } } /** @@ -287,6 +422,16 @@ class StringArrayParam(parent: Params, name: String, doc: String, isValid: Array /** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */ def w(value: java.util.List[String]): ParamPair[Array[String]] = w(value.asScala.toArray) + + override def jsonEncode(value: Array[String]): String = { + import org.json4s.JsonDSL._ + compact(render(value.toSeq)) + } + + override def jsonDecode(json: String): Array[String] = { + implicit val formats = DefaultFormats + parse(json).extract[Seq[String]].toArray + } } /** @@ -303,6 +448,20 @@ class DoubleArrayParam(parent: Params, name: String, doc: String, isValid: Array /** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */ def w(value: java.util.List[java.lang.Double]): ParamPair[Array[Double]] = w(value.asScala.map(_.asInstanceOf[Double]).toArray) + + override def jsonEncode(value: Array[Double]): String = { + import org.json4s.JsonDSL._ + compact(render(value.toSeq.map(DoubleParam.jValueEncode))) + } + + override def jsonDecode(json: String): Array[Double] = { + parse(json) match { + case JArray(values) => + values.map(DoubleParam.jValueDecode).toArray + case _ => + throw new IllegalArgumentException(s"Cannot decode $json to Array[Double].") + } + } } /** @@ -319,6 +478,16 @@ class IntArrayParam(parent: Params, name: String, doc: String, isValid: Array[In /** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */ def w(value: java.util.List[java.lang.Integer]): ParamPair[Array[Int]] = w(value.asScala.map(_.asInstanceOf[Int]).toArray) + + override def jsonEncode(value: Array[Int]): String = { + import org.json4s.JsonDSL._ + compact(render(value.toSeq)) + } + + override def jsonDecode(json: String): Array[Int] = { + implicit val formats = DefaultFormats + parse(json).extract[Seq[Int]].toArray + } } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index a2ea279f5d5e4..eeb03dba2f825 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -21,6 +21,120 @@ import org.apache.spark.SparkFunSuite class ParamsSuite extends SparkFunSuite { + test("json encode/decode") { + val dummy = new Params { + override def copy(extra: ParamMap): Params = defaultCopy(extra) + + override val uid: String = "dummy" + } + + { // BooleanParam + val param = new BooleanParam(dummy, "name", "doc") + for (value <- Seq(true, false)) { + val json = param.jsonEncode(value) + assert(param.jsonDecode(json) === value) + } + } + + { // IntParam + val param = new IntParam(dummy, "name", "doc") + for (value <- Seq(Int.MinValue, -1, 0, 1, Int.MaxValue)) { + val json = param.jsonEncode(value) + assert(param.jsonDecode(json) === value) + } + } + + { // LongParam + val param = new LongParam(dummy, "name", "doc") + for (value <- Seq(Long.MinValue, -1L, 0L, 1L, Long.MaxValue)) { + val json = param.jsonEncode(value) + assert(param.jsonDecode(json) === value) + } + } + + { // FloatParam + val param = new FloatParam(dummy, "name", "doc") + for (value <- Seq(Float.NaN, Float.NegativeInfinity, Float.MinValue, -1.0f, -0.5f, 0.0f, + Float.MinPositiveValue, 0.5f, 1.0f, Float.MaxValue, Float.PositiveInfinity)) { + val json = param.jsonEncode(value) + val decoded = param.jsonDecode(json) + if (value.isNaN) { + assert(decoded.isNaN) + } else { + assert(decoded === value) + } + } + } + + { // DoubleParam + val param = new DoubleParam(dummy, "name", "doc") + for (value <- Seq(Double.NaN, Double.NegativeInfinity, Double.MinValue, -1.0, -0.5, 0.0, + Double.MinPositiveValue, 0.5, 1.0, Double.MaxValue, Double.PositiveInfinity)) { + val json = param.jsonEncode(value) + val decoded = param.jsonDecode(json) + if (value.isNaN) { + assert(decoded.isNaN) + } else { + assert(decoded === value) + } + } + } + + { // StringParam + val param = new Param[String](dummy, "name", "doc") + // Currently we do not support null. + for (value <- Seq("", "1", "abc", "quote\"", "newline\n")) { + val json = param.jsonEncode(value) + assert(param.jsonDecode(json) === value) + } + } + + { // IntArrayParam + val param = new IntArrayParam(dummy, "name", "doc") + val values: Seq[Array[Int]] = Seq( + Array(), + Array(1), + Array(Int.MinValue, 0, Int.MaxValue)) + for (value <- values) { + val json = param.jsonEncode(value) + assert(param.jsonDecode(json) === value) + } + } + + { // DoubleArrayParam + val param = new DoubleArrayParam(dummy, "name", "doc") + val values: Seq[Array[Double]] = Seq( + Array(), + Array(1.0), + Array(Double.NaN, Double.NegativeInfinity, Double.MinValue, -1.0, 0.0, + Double.MinPositiveValue, 1.0, Double.MaxValue, Double.PositiveInfinity)) + for (value <- values) { + val json = param.jsonEncode(value) + val decoded = param.jsonDecode(json) + assert(decoded.length === value.length) + decoded.zip(value).foreach { case (actual, expected) => + if (expected.isNaN) { + assert(actual.isNaN) + } else { + assert(actual === expected) + } + } + } + } + + { // StringArrayParam + val param = new StringArrayParam(dummy, "name", "doc") + val values: Seq[Array[String]] = Seq( + Array(), + Array(""), + Array("", "1", "abc", "quote\"", "newline\n")) + for (value <- values) { + val json = param.jsonEncode(value) + assert(param.jsonDecode(json) === value) + } + } + } + test("param") { val solver = new TestParams() val uid = solver.uid From b3ffac5178795f2d8e7908b3e77e8e89f50b5f6f Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 13 Oct 2015 13:49:59 -0700 Subject: [PATCH 039/139] [SPARK-10983] Unified memory manager MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This patch unifies the memory management of the storage and execution regions such that either side can borrow memory from each other. When memory pressure arises, storage will be evicted in favor of execution. To avoid regressions in cases where storage is crucial, we dynamically allocate a fraction of space for storage that execution cannot evict. Several configurations are introduced: - **spark.memory.fraction (default 0.75)**: ​fraction of the heap space used for execution and storage. The lower this is, the more frequently spills and cached data eviction occur. The purpose of this config is to set aside memory for internal metadata, user data structures, and imprecise size estimation in the case of sparse, unusually large records. - **spark.memory.storageFraction (default 0.5)**: size of the storage region within the space set aside by `s​park.memory.fraction`. ​Cached data may only be evicted if total storage exceeds this region. - **spark.memory.useLegacyMode (default false)**: whether to use the memory management that existed in Spark 1.5 and before. This is mainly for backward compatibility. For a detailed description of the design, see [SPARK-10000](https://issues.apache.org/jira/browse/SPARK-10000). This patch builds on top of the `MemoryManager` interface introduced in #9000. Author: Andrew Or Closes #9084 from andrewor14/unified-memory-manager. --- .../scala/org/apache/spark/SparkConf.scala | 23 +- .../scala/org/apache/spark/SparkEnv.scala | 11 +- .../apache/spark/memory/MemoryManager.scala | 83 +++++-- .../spark/memory/StaticMemoryManager.scala | 105 +++------ .../spark/memory/UnifiedMemoryManager.scala | 141 ++++++++++++ .../spark/shuffle/ShuffleMemoryManager.scala | 38 ++-- .../apache/spark/storage/BlockManager.scala | 4 + .../apache/spark/storage/MemoryStore.scala | 121 ++++++---- .../collection/ExternalAppendOnlyMap.scala | 10 - .../org/apache/spark/DistributedSuite.scala | 7 +- .../scala/org/apache/spark/ShuffleSuite.scala | 6 +- .../spark/memory/MemoryManagerSuite.scala | 133 +++++++++++ .../memory/StaticMemoryManagerSuite.scala | 105 ++++----- .../memory/UnifiedMemoryManagerSuite.scala | 208 ++++++++++++++++++ .../shuffle/ShuffleMemoryManagerSuite.scala | 5 +- .../shuffle/unsafe/UnsafeShuffleSuite.scala | 3 - .../ExternalAppendOnlyMapSuite.scala | 9 +- .../util/collection/ExternalSorterSuite.scala | 23 +- docs/configuration.md | 99 ++++++--- .../execution/TestShuffleMemoryManager.scala | 10 +- .../execution/UnsafeRowSerializerSuite.scala | 2 +- 21 files changed, 840 insertions(+), 306 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala create mode 100644 core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index b344b5e173d67..1a0ac3d01759c 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -418,16 +418,35 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { } // Validate memory fractions - val memoryKeys = Seq( + val deprecatedMemoryKeys = Seq( "spark.storage.memoryFraction", "spark.shuffle.memoryFraction", "spark.shuffle.safetyFraction", "spark.storage.unrollFraction", "spark.storage.safetyFraction") + val memoryKeys = Seq( + "spark.memory.fraction", + "spark.memory.storageFraction") ++ + deprecatedMemoryKeys for (key <- memoryKeys) { val value = getDouble(key, 0.5) if (value > 1 || value < 0) { - throw new IllegalArgumentException("$key should be between 0 and 1 (was '$value').") + throw new IllegalArgumentException(s"$key should be between 0 and 1 (was '$value').") + } + } + + // Warn against deprecated memory fractions (unless legacy memory management mode is enabled) + val legacyMemoryManagementKey = "spark.memory.useLegacyMode" + val legacyMemoryManagement = getBoolean(legacyMemoryManagementKey, false) + if (!legacyMemoryManagement) { + val keyset = deprecatedMemoryKeys.toSet + val detected = settings.keys().asScala.filter(keyset.contains) + if (detected.nonEmpty) { + logWarning("Detected deprecated memory fraction settings: " + + detected.mkString("[", ", ", "]") + ". As of Spark 1.6, execution and storage " + + "memory management are unified. All memory fractions used in the old model are " + + "now deprecated and no longer read. If you wish to use the old memory management, " + + s"you may explicitly enable `$legacyMemoryManagementKey` (not recommended).") } } diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index df3d84a1f08e9..c32998345145a 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -30,7 +30,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.python.PythonWorkerFactory import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.metrics.MetricsSystem -import org.apache.spark.memory.{MemoryManager, StaticMemoryManager} +import org.apache.spark.memory.{MemoryManager, StaticMemoryManager, UnifiedMemoryManager} import org.apache.spark.network.BlockTransferService import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.rpc.{RpcEndpointRef, RpcEndpoint, RpcEnv} @@ -335,7 +335,14 @@ object SparkEnv extends Logging { val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName) val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass) - val memoryManager = new StaticMemoryManager(conf) + val useLegacyMemoryManager = conf.getBoolean("spark.memory.useLegacyMode", false) + val memoryManager: MemoryManager = + if (useLegacyMemoryManager) { + new StaticMemoryManager(conf) + } else { + new UnifiedMemoryManager(conf) + } + val shuffleMemoryManager = ShuffleMemoryManager.create(conf, memoryManager, numUsableCores) val blockTransferService = new NettyBlockTransferService(conf, securityManager, numUsableCores) diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala index 4bf73b696920d..7168ac549106f 100644 --- a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala @@ -19,6 +19,7 @@ package org.apache.spark.memory import scala.collection.mutable +import org.apache.spark.Logging import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore} @@ -29,7 +30,7 @@ import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore} * sorts and aggregations, while storage memory refers to that used for caching and propagating * internal data across the cluster. There exists one of these per JVM. */ -private[spark] abstract class MemoryManager { +private[spark] abstract class MemoryManager extends Logging { // The memory store used to evict cached blocks private var _memoryStore: MemoryStore = _ @@ -40,19 +41,38 @@ private[spark] abstract class MemoryManager { _memoryStore } + // Amount of execution/storage memory in use, accesses must be synchronized on `this` + protected var _executionMemoryUsed: Long = 0 + protected var _storageMemoryUsed: Long = 0 + /** * Set the [[MemoryStore]] used by this manager to evict cached blocks. * This must be set after construction due to initialization ordering constraints. */ - def setMemoryStore(store: MemoryStore): Unit = { + final def setMemoryStore(store: MemoryStore): Unit = { _memoryStore = store } /** - * Acquire N bytes of memory for execution. + * Total available memory for execution, in bytes. + */ + def maxExecutionMemory: Long + + /** + * Total available memory for storage, in bytes. + */ + def maxStorageMemory: Long + + // TODO: avoid passing evicted blocks around to simplify method signatures (SPARK-10985) + + /** + * Acquire N bytes of memory for execution, evicting cached blocks if necessary. + * Blocks evicted in the process, if any, are added to `evictedBlocks`. * @return number of bytes successfully granted (<= N). */ - def acquireExecutionMemory(numBytes: Long): Long + def acquireExecutionMemory( + numBytes: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long /** * Acquire N bytes of memory to cache the given block, evicting existing ones if necessary. @@ -66,52 +86,73 @@ private[spark] abstract class MemoryManager { /** * Acquire N bytes of memory to unroll the given block, evicting existing ones if necessary. + * + * This extra method allows subclasses to differentiate behavior between acquiring storage + * memory and acquiring unroll memory. For instance, the memory management model in Spark + * 1.5 and before places a limit on the amount of space that can be freed from unrolling. * Blocks evicted in the process, if any, are added to `evictedBlocks`. + * * @return whether all N bytes were successfully granted. */ def acquireUnrollMemory( blockId: BlockId, numBytes: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { + acquireStorageMemory(blockId, numBytes, evictedBlocks) + } /** * Release N bytes of execution memory. */ - def releaseExecutionMemory(numBytes: Long): Unit + def releaseExecutionMemory(numBytes: Long): Unit = synchronized { + if (numBytes > _executionMemoryUsed) { + logWarning(s"Attempted to release $numBytes bytes of execution " + + s"memory when we only have ${_executionMemoryUsed} bytes") + _executionMemoryUsed = 0 + } else { + _executionMemoryUsed -= numBytes + } + } /** * Release N bytes of storage memory. */ - def releaseStorageMemory(numBytes: Long): Unit + def releaseStorageMemory(numBytes: Long): Unit = synchronized { + if (numBytes > _storageMemoryUsed) { + logWarning(s"Attempted to release $numBytes bytes of storage " + + s"memory when we only have ${_storageMemoryUsed} bytes") + _storageMemoryUsed = 0 + } else { + _storageMemoryUsed -= numBytes + } + } /** * Release all storage memory acquired. */ - def releaseStorageMemory(): Unit + def releaseAllStorageMemory(): Unit = synchronized { + _storageMemoryUsed = 0 + } /** * Release N bytes of unroll memory. */ - def releaseUnrollMemory(numBytes: Long): Unit - - /** - * Total available memory for execution, in bytes. - */ - def maxExecutionMemory: Long - - /** - * Total available memory for storage, in bytes. - */ - def maxStorageMemory: Long + def releaseUnrollMemory(numBytes: Long): Unit = synchronized { + releaseStorageMemory(numBytes) + } /** * Execution memory currently in use, in bytes. */ - def executionMemoryUsed: Long + final def executionMemoryUsed: Long = synchronized { + _executionMemoryUsed + } /** * Storage memory currently in use, in bytes. */ - def storageMemoryUsed: Long + final def storageMemoryUsed: Long = synchronized { + _storageMemoryUsed + } } diff --git a/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala index 150445edb9578..fa44f3723415d 100644 --- a/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala @@ -19,7 +19,7 @@ package org.apache.spark.memory import scala.collection.mutable -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.SparkConf import org.apache.spark.storage.{BlockId, BlockStatus} @@ -34,17 +34,7 @@ private[spark] class StaticMemoryManager( conf: SparkConf, override val maxExecutionMemory: Long, override val maxStorageMemory: Long) - extends MemoryManager with Logging { - - // Max number of bytes worth of blocks to evict when unrolling - private val maxMemoryToEvictForUnroll: Long = { - (maxStorageMemory * conf.getDouble("spark.storage.unrollFraction", 0.2)).toLong - } - - // Amount of execution / storage memory in use - // Accesses must be synchronized on `this` - private var _executionMemoryUsed: Long = 0 - private var _storageMemoryUsed: Long = 0 + extends MemoryManager { def this(conf: SparkConf) { this( @@ -53,11 +43,19 @@ private[spark] class StaticMemoryManager( StaticMemoryManager.getMaxStorageMemory(conf)) } + // Max number of bytes worth of blocks to evict when unrolling + private val maxMemoryToEvictForUnroll: Long = { + (maxStorageMemory * conf.getDouble("spark.storage.unrollFraction", 0.2)).toLong + } + /** * Acquire N bytes of memory for execution. * @return number of bytes successfully granted (<= N). */ - override def acquireExecutionMemory(numBytes: Long): Long = synchronized { + override def acquireExecutionMemory( + numBytes: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long = synchronized { + assert(numBytes >= 0) assert(_executionMemoryUsed <= maxExecutionMemory) val bytesToGrant = math.min(numBytes, maxExecutionMemory - _executionMemoryUsed) _executionMemoryUsed += bytesToGrant @@ -72,7 +70,7 @@ private[spark] class StaticMemoryManager( override def acquireStorageMemory( blockId: BlockId, numBytes: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = { + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { acquireStorageMemory(blockId, numBytes, numBytes, evictedBlocks) } @@ -88,7 +86,7 @@ private[spark] class StaticMemoryManager( override def acquireUnrollMemory( blockId: BlockId, numBytes: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = { + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { val currentUnrollMemory = memoryStore.currentUnrollMemory val maxNumBytesToFree = math.max(0, maxMemoryToEvictForUnroll - currentUnrollMemory) val numBytesToFree = math.min(numBytes, maxNumBytesToFree) @@ -108,71 +106,16 @@ private[spark] class StaticMemoryManager( blockId: BlockId, numBytesToAcquire: Long, numBytesToFree: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = { - // Note: Keep this outside synchronized block to avoid potential deadlocks! + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { + assert(numBytesToAcquire >= 0) + assert(numBytesToFree >= 0) memoryStore.ensureFreeSpace(blockId, numBytesToFree, evictedBlocks) - synchronized { - assert(_storageMemoryUsed <= maxStorageMemory) - val enoughMemory = _storageMemoryUsed + numBytesToAcquire <= maxStorageMemory - if (enoughMemory) { - _storageMemoryUsed += numBytesToAcquire - } - enoughMemory - } - } - - /** - * Release N bytes of execution memory. - */ - override def releaseExecutionMemory(numBytes: Long): Unit = synchronized { - if (numBytes > _executionMemoryUsed) { - logWarning(s"Attempted to release $numBytes bytes of execution " + - s"memory when we only have ${_executionMemoryUsed} bytes") - _executionMemoryUsed = 0 - } else { - _executionMemoryUsed -= numBytes - } - } - - /** - * Release N bytes of storage memory. - */ - override def releaseStorageMemory(numBytes: Long): Unit = synchronized { - if (numBytes > _storageMemoryUsed) { - logWarning(s"Attempted to release $numBytes bytes of storage " + - s"memory when we only have ${_storageMemoryUsed} bytes") - _storageMemoryUsed = 0 - } else { - _storageMemoryUsed -= numBytes + assert(_storageMemoryUsed <= maxStorageMemory) + val enoughMemory = _storageMemoryUsed + numBytesToAcquire <= maxStorageMemory + if (enoughMemory) { + _storageMemoryUsed += numBytesToAcquire } - } - - /** - * Release all storage memory acquired. - */ - override def releaseStorageMemory(): Unit = synchronized { - _storageMemoryUsed = 0 - } - - /** - * Release N bytes of unroll memory. - */ - override def releaseUnrollMemory(numBytes: Long): Unit = { - releaseStorageMemory(numBytes) - } - - /** - * Amount of execution memory currently in use, in bytes. - */ - override def executionMemoryUsed: Long = synchronized { - _executionMemoryUsed - } - - /** - * Amount of storage memory currently in use, in bytes. - */ - override def storageMemoryUsed: Long = synchronized { - _storageMemoryUsed + enoughMemory } } @@ -184,9 +127,10 @@ private[spark] object StaticMemoryManager { * Return the total amount of memory available for the storage region, in bytes. */ private def getMaxStorageMemory(conf: SparkConf): Long = { + val systemMaxMemory = conf.getLong("spark.testing.memory", Runtime.getRuntime.maxMemory) val memoryFraction = conf.getDouble("spark.storage.memoryFraction", 0.6) val safetyFraction = conf.getDouble("spark.storage.safetyFraction", 0.9) - (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong + (systemMaxMemory * memoryFraction * safetyFraction).toLong } @@ -194,9 +138,10 @@ private[spark] object StaticMemoryManager { * Return the total amount of memory available for the execution region, in bytes. */ private def getMaxExecutionMemory(conf: SparkConf): Long = { + val systemMaxMemory = conf.getLong("spark.testing.memory", Runtime.getRuntime.maxMemory) val memoryFraction = conf.getDouble("spark.shuffle.memoryFraction", 0.2) val safetyFraction = conf.getDouble("spark.shuffle.safetyFraction", 0.8) - (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong + (systemMaxMemory * memoryFraction * safetyFraction).toLong } } diff --git a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala new file mode 100644 index 0000000000000..5bf78d5b674b3 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.memory + +import scala.collection.mutable + +import org.apache.spark.SparkConf +import org.apache.spark.storage.{BlockStatus, BlockId} + + +/** + * A [[MemoryManager]] that enforces a soft boundary between execution and storage such that + * either side can borrow memory from the other. + * + * The region shared between execution and storage is a fraction of the total heap space + * configurable through `spark.memory.fraction` (default 0.75). The position of the boundary + * within this space is further determined by `spark.memory.storageFraction` (default 0.5). + * This means the size of the storage region is 0.75 * 0.5 = 0.375 of the heap space by default. + * + * Storage can borrow as much execution memory as is free until execution reclaims its space. + * When this happens, cached blocks will be evicted from memory until sufficient borrowed + * memory is released to satisfy the execution memory request. + * + * Similarly, execution can borrow as much storage memory as is free. However, execution + * memory is *never* evicted by storage due to the complexities involved in implementing this. + * The implication is that attempts to cache blocks may fail if execution has already eaten + * up most of the storage space, in which case the new blocks will be evicted immediately + * according to their respective storage levels. + */ +private[spark] class UnifiedMemoryManager(conf: SparkConf, maxMemory: Long) extends MemoryManager { + + def this(conf: SparkConf) { + this(conf, UnifiedMemoryManager.getMaxMemory(conf)) + } + + /** + * Size of the storage region, in bytes. + * + * This region is not statically reserved; execution can borrow from it if necessary. + * Cached blocks can be evicted only if actual storage memory usage exceeds this region. + */ + private val storageRegionSize: Long = { + (maxMemory * conf.getDouble("spark.memory.storageFraction", 0.5)).toLong + } + + /** + * Total amount of memory, in bytes, not currently occupied by either execution or storage. + */ + private def totalFreeMemory: Long = synchronized { + assert(_executionMemoryUsed <= maxMemory) + assert(_storageMemoryUsed <= maxMemory) + assert(_executionMemoryUsed + _storageMemoryUsed <= maxMemory) + maxMemory - _executionMemoryUsed - _storageMemoryUsed + } + + /** + * Total available memory for execution, in bytes. + * In this model, this is equivalent to the amount of memory not occupied by storage. + */ + override def maxExecutionMemory: Long = synchronized { + maxMemory - _storageMemoryUsed + } + + /** + * Total available memory for storage, in bytes. + * In this model, this is equivalent to the amount of memory not occupied by execution. + */ + override def maxStorageMemory: Long = synchronized { + maxMemory - _executionMemoryUsed + } + + /** + * Acquire N bytes of memory for execution, evicting cached blocks if necessary. + * + * This method evicts blocks only up to the amount of memory borrowed by storage. + * Blocks evicted in the process, if any, are added to `evictedBlocks`. + * @return number of bytes successfully granted (<= N). + */ + override def acquireExecutionMemory( + numBytes: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long = synchronized { + assert(numBytes >= 0) + val memoryBorrowedByStorage = math.max(0, _storageMemoryUsed - storageRegionSize) + // If there is not enough free memory AND storage has borrowed some execution memory, + // then evict as much memory borrowed by storage as needed to grant this request + val shouldEvictStorage = totalFreeMemory < numBytes && memoryBorrowedByStorage > 0 + if (shouldEvictStorage) { + val spaceToEnsure = math.min(numBytes, memoryBorrowedByStorage) + memoryStore.ensureFreeSpace(spaceToEnsure, evictedBlocks) + } + val bytesToGrant = math.min(numBytes, totalFreeMemory) + _executionMemoryUsed += bytesToGrant + bytesToGrant + } + + /** + * Acquire N bytes of memory to cache the given block, evicting existing ones if necessary. + * Blocks evicted in the process, if any, are added to `evictedBlocks`. + * @return whether all N bytes were successfully granted. + */ + override def acquireStorageMemory( + blockId: BlockId, + numBytes: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { + assert(numBytes >= 0) + memoryStore.ensureFreeSpace(blockId, numBytes, evictedBlocks) + val enoughMemory = totalFreeMemory >= numBytes + if (enoughMemory) { + _storageMemoryUsed += numBytes + } + enoughMemory + } + +} + +private object UnifiedMemoryManager { + + /** + * Return the total amount of memory shared between execution and storage, in bytes. + */ + private def getMaxMemory(conf: SparkConf): Long = { + val systemMaxMemory = conf.getLong("spark.testing.memory", Runtime.getRuntime.maxMemory) + val memoryFraction = conf.getDouble("spark.memory.fraction", 0.75) + (systemMaxMemory * memoryFraction).toLong + } +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala index bb64bb3f35df0..aaf543ce9232a 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala @@ -18,11 +18,13 @@ package org.apache.spark.shuffle import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import com.google.common.annotations.VisibleForTesting import org.apache.spark._ import org.apache.spark.memory.{StaticMemoryManager, MemoryManager} +import org.apache.spark.storage.{BlockId, BlockStatus} import org.apache.spark.unsafe.array.ByteArrayMethods /** @@ -36,8 +38,8 @@ import org.apache.spark.unsafe.array.ByteArrayMethods * If there are N tasks, it ensures that each tasks can acquire at least 1 / 2N of the memory * before it has to spill, and at most 1 / N. Because N varies dynamically, we keep track of the * set of active tasks and redo the calculations of 1 / 2N and 1 / N in waiting tasks whenever - * this set changes. This is all done by synchronizing access on "this" to mutate state and using - * wait() and notifyAll() to signal changes. + * this set changes. This is all done by synchronizing access to `memoryManager` to mutate state + * and using wait() and notifyAll() to signal changes. * * Use `ShuffleMemoryManager.create()` factory method to create a new instance. * @@ -51,7 +53,6 @@ class ShuffleMemoryManager protected ( extends Logging { private val taskMemory = new mutable.HashMap[Long, Long]() // taskAttemptId -> memory bytes - private val maxMemory = memoryManager.maxExecutionMemory private def currentTaskAttemptId(): Long = { // In case this is called on the driver, return an invalid task attempt id. @@ -65,7 +66,7 @@ class ShuffleMemoryManager protected ( * total memory pool (where N is the # of active tasks) before it is forced to spill. This can * happen if the number of tasks increases but an older task had a lot of memory already. */ - def tryToAcquire(numBytes: Long): Long = synchronized { + def tryToAcquire(numBytes: Long): Long = memoryManager.synchronized { val taskAttemptId = currentTaskAttemptId() assert(numBytes > 0, "invalid number of bytes requested: " + numBytes) @@ -73,15 +74,18 @@ class ShuffleMemoryManager protected ( // of active tasks, to let other tasks ramp down their memory in calls to tryToAcquire if (!taskMemory.contains(taskAttemptId)) { taskMemory(taskAttemptId) = 0L - notifyAll() // Will later cause waiting tasks to wake up and check numTasks again + // This will later cause waiting tasks to wake up and check numTasks again + memoryManager.notifyAll() } // Keep looping until we're either sure that we don't want to grant this request (because this // task would have more than 1 / numActiveTasks of the memory) or we have enough free // memory to give it (we always let each task get at least 1 / (2 * numActiveTasks)). + // TODO: simplify this to limit each task to its own slot while (true) { val numActiveTasks = taskMemory.keys.size val curMem = taskMemory(taskAttemptId) + val maxMemory = memoryManager.maxExecutionMemory val freeMemory = maxMemory - taskMemory.values.sum // How much we can grant this task; don't let it grow to more than 1 / numActiveTasks; @@ -99,7 +103,7 @@ class ShuffleMemoryManager protected ( } else { logInfo( s"TID $taskAttemptId waiting for at least 1/2N of shuffle memory pool to be free") - wait() + memoryManager.wait() } } else { return acquire(toGrant) @@ -112,15 +116,23 @@ class ShuffleMemoryManager protected ( * Acquire N bytes of execution memory from the memory manager for the current task. * @return number of bytes actually acquired (<= N). */ - private def acquire(numBytes: Long): Long = synchronized { + private def acquire(numBytes: Long): Long = memoryManager.synchronized { val taskAttemptId = currentTaskAttemptId() - val acquired = memoryManager.acquireExecutionMemory(numBytes) + val evictedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] + val acquired = memoryManager.acquireExecutionMemory(numBytes, evictedBlocks) + // Register evicted blocks, if any, with the active task metrics + // TODO: just do this in `acquireExecutionMemory` (SPARK-10985) + Option(TaskContext.get()).foreach { tc => + val metrics = tc.taskMetrics() + val lastUpdatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]()) + metrics.updatedBlocks = Some(lastUpdatedBlocks ++ evictedBlocks.toSeq) + } taskMemory(taskAttemptId) += acquired acquired } /** Release numBytes bytes for the current task. */ - def release(numBytes: Long): Unit = synchronized { + def release(numBytes: Long): Unit = memoryManager.synchronized { val taskAttemptId = currentTaskAttemptId() val curMem = taskMemory.getOrElse(taskAttemptId, 0L) if (curMem < numBytes) { @@ -129,20 +141,20 @@ class ShuffleMemoryManager protected ( } taskMemory(taskAttemptId) -= numBytes memoryManager.releaseExecutionMemory(numBytes) - notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed + memoryManager.notifyAll() // Notify waiters in tryToAcquire that memory has been freed } /** Release all memory for the current task and mark it as inactive (e.g. when a task ends). */ - def releaseMemoryForThisTask(): Unit = synchronized { + def releaseMemoryForThisTask(): Unit = memoryManager.synchronized { val taskAttemptId = currentTaskAttemptId() taskMemory.remove(taskAttemptId).foreach { numBytes => memoryManager.releaseExecutionMemory(numBytes) } - notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed + memoryManager.notifyAll() // Notify waiters in tryToAcquire that memory has been freed } /** Returns the memory consumption, in bytes, for the current task */ - def getMemoryConsumptionForThisTask(): Long = synchronized { + def getMemoryConsumptionForThisTask(): Long = memoryManager.synchronized { val taskAttemptId = currentTaskAttemptId() taskMemory.getOrElse(taskAttemptId, 0L) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 9f5bd2abbdc5d..c374b93766225 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -91,6 +91,10 @@ private[spark] class BlockManager( } memoryManager.setMemoryStore(memoryStore) + // Note: depending on the memory manager, `maxStorageMemory` may actually vary over time. + // However, since we use this only for reporting and logging, what we actually want here is + // the absolute maximum value that `maxStorageMemory` can ever possibly reach. We may need + // to revisit whether reporting this value as the "max" is intuitive to the user. private val maxMemory = memoryManager.maxStorageMemory private[spark] diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index 35c57b923c43a..4dbac388e098b 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -37,15 +37,14 @@ private case class MemoryEntry(value: Any, size: Long, deserialized: Boolean) private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: MemoryManager) extends BlockStore(blockManager) { + // Note: all changes to memory allocations, notably putting blocks, evicting blocks, and + // acquiring or releasing unroll memory, must be synchronized on `memoryManager`! + private val conf = blockManager.conf private val entries = new LinkedHashMap[BlockId, MemoryEntry](32, 0.75f, true) - private val maxMemory = memoryManager.maxStorageMemory - - // Ensure only one thread is putting, and if necessary, dropping blocks at any given time - private val accountingLock = new Object // A mapping from taskAttemptId to amount of memory used for unrolling a block (in bytes) - // All accesses of this map are assumed to have manually synchronized on `accountingLock` + // All accesses of this map are assumed to have manually synchronized on `memoryManager` private val unrollMemoryMap = mutable.HashMap[Long, Long]() // Same as `unrollMemoryMap`, but for pending unroll memory as defined below. // Pending unroll memory refers to the intermediate memory occupied by a task @@ -60,6 +59,9 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo private val unrollMemoryThreshold: Long = conf.getLong("spark.storage.unrollMemoryThreshold", 1024 * 1024) + /** Total amount of memory available for storage, in bytes. */ + private def maxMemory: Long = memoryManager.maxStorageMemory + if (maxMemory < unrollMemoryThreshold) { logWarning(s"Max memory ${Utils.bytesToString(maxMemory)} is less than the initial memory " + s"threshold ${Utils.bytesToString(unrollMemoryThreshold)} needed to store a block in " + @@ -75,7 +77,9 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo * Amount of storage memory, in bytes, used for caching blocks. * This does not include memory used for unrolling. */ - private def blocksMemoryUsed: Long = memoryUsed - currentUnrollMemory + private def blocksMemoryUsed: Long = memoryManager.synchronized { + memoryUsed - currentUnrollMemory + } override def getSize(blockId: BlockId): Long = { entries.synchronized { @@ -208,7 +212,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo } } - override def remove(blockId: BlockId): Boolean = { + override def remove(blockId: BlockId): Boolean = memoryManager.synchronized { val entry = entries.synchronized { entries.remove(blockId) } if (entry != null) { memoryManager.releaseStorageMemory(entry.size) @@ -220,11 +224,13 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo } } - override def clear() { + override def clear(): Unit = memoryManager.synchronized { entries.synchronized { entries.clear() } - memoryManager.releaseStorageMemory() + unrollMemoryMap.clear() + pendingUnrollMemoryMap.clear() + memoryManager.releaseAllStorageMemory() logInfo("MemoryStore cleared") } @@ -299,22 +305,23 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo } } finally { - // If we return an array, the values returned will later be cached in `tryToPut`. - // In this case, we should release the memory after we cache the block there. - // Otherwise, if we return an iterator, we release the memory reserved here - // later when the task finishes. + // If we return an array, the values returned here will be cached in `tryToPut` later. + // In this case, we should release the memory only after we cache the block there. if (keepUnrolling) { val taskAttemptId = currentTaskAttemptId() - accountingLock.synchronized { - // Here, we transfer memory from unroll to pending unroll because we expect to cache this - // block in `tryToPut`. We do not release and re-acquire memory from the MemoryManager in - // order to avoid race conditions where another component steals the memory that we're - // trying to transfer. + memoryManager.synchronized { + // Since we continue to hold onto the array until we actually cache it, we cannot + // release the unroll memory yet. Instead, we transfer it to pending unroll memory + // so `tryToPut` can further transfer it to normal storage memory later. + // TODO: we can probably express this without pending unroll memory (SPARK-10907) val amountToTransferToPending = currentUnrollMemoryForThisTask - previousMemoryReserved unrollMemoryMap(taskAttemptId) -= amountToTransferToPending pendingUnrollMemoryMap(taskAttemptId) = pendingUnrollMemoryMap.getOrElse(taskAttemptId, 0L) + amountToTransferToPending } + } else { + // Otherwise, if we return an iterator, we can only release the unroll memory when + // the task finishes since we don't know when the iterator will be consumed. } } } @@ -343,7 +350,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo * `value` will be lazily created. If it cannot be put into MemoryStore or disk, `value` won't be * created to avoid OOM since it may be a big ByteBuffer. * - * Synchronize on `accountingLock` to ensure that all the put requests and its associated block + * Synchronize on `memoryManager` to ensure that all the put requests and its associated block * dropping is done by only on thread at a time. Otherwise while one thread is dropping * blocks to free memory for one block, another thread may use up the freed space for * another block. @@ -365,16 +372,13 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo * for freeing up more space for another block that needs to be put. Only then the actually * dropping of blocks (and writing to disk if necessary) can proceed in parallel. */ - accountingLock.synchronized { + memoryManager.synchronized { // Note: if we have previously unrolled this block successfully, then pending unroll // memory should be non-zero. This is the amount that we already reserved during the // unrolling process. In this case, we can just reuse this space to cache our block. - // - // Note: the StaticMemoryManager counts unroll memory as storage memory. Here, the - // synchronization on `accountingLock` guarantees that the release of unroll memory and - // acquisition of storage memory happens atomically. However, if storage memory is acquired - // outside of MemoryStore or if unroll memory is counted as execution memory, then we will - // have to revisit this assumption. See SPARK-10983 for more context. + // The synchronization on `memoryManager` here guarantees that the release and acquire + // happen atomically. This relies on the assumption that all memory acquisitions are + // synchronized on the same lock. releasePendingUnrollMemoryForThisTask() val enoughMemory = memoryManager.acquireStorageMemory(blockId, size, droppedBlocks) if (enoughMemory) { @@ -401,34 +405,62 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo } } + /** + * Try to free up a given amount of space by evicting existing blocks. + * + * @param space the amount of memory to free, in bytes + * @param droppedBlocks a holder for blocks evicted in the process + * @return whether the requested free space is freed. + */ + private[spark] def ensureFreeSpace( + space: Long, + droppedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = { + ensureFreeSpace(None, space, droppedBlocks) + } + + /** + * Try to free up a given amount of space to store a block by evicting existing ones. + * + * @param space the amount of memory to free, in bytes + * @param droppedBlocks a holder for blocks evicted in the process + * @return whether the requested free space is freed. + */ + private[spark] def ensureFreeSpace( + blockId: BlockId, + space: Long, + droppedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = { + ensureFreeSpace(Some(blockId), space, droppedBlocks) + } + /** * Try to free up a given amount of space to store a particular block, but can fail if * either the block is bigger than our memory or it would require replacing another block * from the same RDD (which leads to a wasteful cyclic replacement pattern for RDDs that * don't fit into memory that we want to avoid). * - * @param blockId the ID of the block we are freeing space for + * @param blockId the ID of the block we are freeing space for, if any * @param space the size of this block * @param droppedBlocks a holder for blocks evicted in the process - * @return whether there is enough free space. + * @return whether the requested free space is freed. */ - private[spark] def ensureFreeSpace( - blockId: BlockId, + private def ensureFreeSpace( + blockId: Option[BlockId], space: Long, droppedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = { - accountingLock.synchronized { + memoryManager.synchronized { val freeMemory = maxMemory - memoryUsed - val rddToAdd = getRddId(blockId) + val rddToAdd = blockId.flatMap(getRddId) val selectedBlocks = new ArrayBuffer[BlockId] var selectedMemory = 0L - logInfo(s"Ensuring $space bytes of free space for block $blockId " + + logInfo(s"Ensuring $space bytes of free space " + + blockId.map { id => s"for block $id" }.getOrElse("") + s"(free: $freeMemory, max: $maxMemory)") // Fail fast if the block simply won't fit if (space > maxMemory) { - logInfo(s"Will not store $blockId as the required space " + - s"($space bytes) than our memory limit ($maxMemory bytes)") + logInfo("Will not " + blockId.map { id => s"store $id" }.getOrElse("free memory") + + s" as the required space ($space bytes) exceeds our memory limit ($maxMemory bytes)") return false } @@ -471,8 +503,10 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo } true } else { - logInfo(s"Will not store $blockId as it would require dropping another block " + - "from the same RDD") + blockId.foreach { id => + logInfo(s"Will not store $id as it would require dropping another block " + + "from the same RDD") + } false } } @@ -495,8 +529,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo blockId: BlockId, memory: Long, droppedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = { - accountingLock.synchronized { - // Note: all acquisitions of unroll memory must be synchronized on `accountingLock` + memoryManager.synchronized { val success = memoryManager.acquireUnrollMemory(blockId, memory, droppedBlocks) if (success) { val taskAttemptId = currentTaskAttemptId() @@ -512,7 +545,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo */ def releaseUnrollMemoryForThisTask(memory: Long = Long.MaxValue): Unit = { val taskAttemptId = currentTaskAttemptId() - accountingLock.synchronized { + memoryManager.synchronized { if (unrollMemoryMap.contains(taskAttemptId)) { val memoryToRelease = math.min(memory, unrollMemoryMap(taskAttemptId)) if (memoryToRelease > 0) { @@ -531,7 +564,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo */ def releasePendingUnrollMemoryForThisTask(memory: Long = Long.MaxValue): Unit = { val taskAttemptId = currentTaskAttemptId() - accountingLock.synchronized { + memoryManager.synchronized { if (pendingUnrollMemoryMap.contains(taskAttemptId)) { val memoryToRelease = math.min(memory, pendingUnrollMemoryMap(taskAttemptId)) if (memoryToRelease > 0) { @@ -548,21 +581,21 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo /** * Return the amount of memory currently occupied for unrolling blocks across all tasks. */ - def currentUnrollMemory: Long = accountingLock.synchronized { + def currentUnrollMemory: Long = memoryManager.synchronized { unrollMemoryMap.values.sum + pendingUnrollMemoryMap.values.sum } /** * Return the amount of memory currently occupied for unrolling blocks by this task. */ - def currentUnrollMemoryForThisTask: Long = accountingLock.synchronized { + def currentUnrollMemoryForThisTask: Long = memoryManager.synchronized { unrollMemoryMap.getOrElse(currentTaskAttemptId(), 0L) } /** * Return the number of tasks currently unrolling blocks. */ - private def numTasksUnrolling: Int = accountingLock.synchronized { unrollMemoryMap.keys.size } + private def numTasksUnrolling: Int = memoryManager.synchronized { unrollMemoryMap.keys.size } /** * Log information about current memory usage. diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 29c5732f5a8c1..6a96b5dc12684 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -48,16 +48,6 @@ import org.apache.spark.executor.ShuffleWriteMetrics * However, if the spill threshold is too low, we spill frequently and incur unnecessary disk * writes. This may lead to a performance regression compared to the normal case of using the * non-spilling AppendOnlyMap. - * - * Two parameters control the memory threshold: - * - * `spark.shuffle.memoryFraction` specifies the collective amount of memory used for storing - * these maps as a fraction of the executor's total memory. Since each concurrently running - * task maintains one map, the actual threshold for each map is this quantity divided by the - * number of running tasks. - * - * `spark.shuffle.safetyFraction` specifies an additional margin of safety as a fraction of - * this threshold, in case map size estimation is not sufficiently accurate. */ @DeveloperApi class ExternalAppendOnlyMap[K, V, C]( diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 600c1403b0344..34a4bb968e732 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -213,11 +213,8 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex } test("compute when only some partitions fit in memory") { - val conf = new SparkConf().set("spark.storage.memoryFraction", "0.01") - sc = new SparkContext(clusterUrl, "test", conf) - // data will be 4 million * 4 bytes = 16 MB in size, but our memoryFraction set the cache - // to only 5 MB (0.01 of 512 MB), so not all of it will fit in memory; we use 20 partitions - // to make sure that *some* of them do fit though + sc = new SparkContext(clusterUrl, "test", new SparkConf) + // TODO: verify that only a subset of partitions fit in memory (SPARK-11078) val data = sc.parallelize(1 to 4000000, 20).persist(StorageLevel.MEMORY_ONLY_SER) assert(data.count() === 4000000) assert(data.count() === 4000000) diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index d91b799ecfc08..4a0877d86f2c6 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -247,11 +247,13 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC .setMaster("local") .set("spark.shuffle.spill.compress", shuffleSpillCompress.toString) .set("spark.shuffle.compress", shuffleCompress.toString) - .set("spark.shuffle.memoryFraction", "0.001") resetSparkContext() sc = new SparkContext(myConf) + val diskBlockManager = sc.env.blockManager.diskBlockManager try { - sc.parallelize(0 until 100000).map(i => (i / 4, i)).groupByKey().collect() + assert(diskBlockManager.getAllFiles().isEmpty) + sc.parallelize(0 until 10).map(i => (i / 4, i)).groupByKey().collect() + assert(diskBlockManager.getAllFiles().nonEmpty) } catch { case e: Exception => val errMsg = s"Failed with spark.shuffle.spill.compress=$shuffleSpillCompress," + diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala new file mode 100644 index 0000000000000..36e4566310715 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.memory + +import java.util.concurrent.atomic.AtomicLong + +import org.mockito.Matchers.{any, anyLong} +import org.mockito.Mockito.{mock, when} +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer + +import org.apache.spark.SparkFunSuite +import org.apache.spark.storage.MemoryStore + + +/** + * Helper trait for sharing code among [[MemoryManager]] tests. + */ +private[memory] trait MemoryManagerSuite extends SparkFunSuite { + + import MemoryManagerSuite.DEFAULT_ENSURE_FREE_SPACE_CALLED + + // Note: Mockito's verify mechanism does not provide a way to reset method call counts + // without also resetting stubbed methods. Since our test code relies on the latter, + // we need to use our own variable to track invocations of `ensureFreeSpace`. + + /** + * The amount of free space requested in the last call to [[MemoryStore.ensureFreeSpace]] + * + * This set whenever [[MemoryStore.ensureFreeSpace]] is called, and cleared when the test + * code makes explicit assertions on this variable through [[assertEnsureFreeSpaceCalled]]. + */ + private val ensureFreeSpaceCalled = new AtomicLong(DEFAULT_ENSURE_FREE_SPACE_CALLED) + + /** + * Make a mocked [[MemoryStore]] whose [[MemoryStore.ensureFreeSpace]] method is stubbed. + * + * This allows our test code to release storage memory when [[MemoryStore.ensureFreeSpace]] + * is called without relying on [[org.apache.spark.storage.BlockManager]] and all of its + * dependencies. + */ + protected def makeMemoryStore(mm: MemoryManager): MemoryStore = { + val ms = mock(classOf[MemoryStore]) + when(ms.ensureFreeSpace(anyLong(), any())).thenAnswer(ensureFreeSpaceAnswer(mm, 0)) + when(ms.ensureFreeSpace(any(), anyLong(), any())).thenAnswer(ensureFreeSpaceAnswer(mm, 1)) + mm.setMemoryStore(ms) + ms + } + + /** + * Make an [[Answer]] that stubs [[MemoryStore.ensureFreeSpace]] with the right arguments. + */ + private def ensureFreeSpaceAnswer(mm: MemoryManager, numBytesPos: Int): Answer[Boolean] = { + new Answer[Boolean] { + override def answer(invocation: InvocationOnMock): Boolean = { + val args = invocation.getArguments + require(args.size > numBytesPos, s"bad test: expected >$numBytesPos arguments " + + s"in ensureFreeSpace, found ${args.size}") + require(args(numBytesPos).isInstanceOf[Long], s"bad test: expected ensureFreeSpace " + + s"argument at index $numBytesPos to be a Long: ${args.mkString(", ")}") + val numBytes = args(numBytesPos).asInstanceOf[Long] + mockEnsureFreeSpace(mm, numBytes) + } + } + } + + /** + * Simulate the part of [[MemoryStore.ensureFreeSpace]] that releases storage memory. + * + * This is a significant simplification of the real method, which actually drops existing + * blocks based on the size of each block. Instead, here we simply release as many bytes + * as needed to ensure the requested amount of free space. This allows us to set up the + * test without relying on the [[org.apache.spark.storage.BlockManager]], which brings in + * many other dependencies. + * + * Every call to this method will set a global variable, [[ensureFreeSpaceCalled]], that + * records the number of bytes this is called with. This variable is expected to be cleared + * by the test code later through [[assertEnsureFreeSpaceCalled]]. + */ + private def mockEnsureFreeSpace(mm: MemoryManager, numBytes: Long): Boolean = mm.synchronized { + require(ensureFreeSpaceCalled.get() === DEFAULT_ENSURE_FREE_SPACE_CALLED, + "bad test: ensure free space variable was not reset") + // Record the number of bytes we freed this call + ensureFreeSpaceCalled.set(numBytes) + if (numBytes <= mm.maxStorageMemory) { + def freeMemory = mm.maxStorageMemory - mm.storageMemoryUsed + val spaceToRelease = numBytes - freeMemory + if (spaceToRelease > 0) { + mm.releaseStorageMemory(spaceToRelease) + } + freeMemory >= numBytes + } else { + // We attempted to free more bytes than our max allowable memory + false + } + } + + /** + * Assert that [[MemoryStore.ensureFreeSpace]] is called with the given parameters. + */ + protected def assertEnsureFreeSpaceCalled(ms: MemoryStore, numBytes: Long): Unit = { + assert(ensureFreeSpaceCalled.get() === numBytes, + s"expected ensure free space to be called with $numBytes") + ensureFreeSpaceCalled.set(DEFAULT_ENSURE_FREE_SPACE_CALLED) + } + + /** + * Assert that [[MemoryStore.ensureFreeSpace]] is NOT called. + */ + protected def assertEnsureFreeSpaceNotCalled[T](ms: MemoryStore): Unit = { + assert(ensureFreeSpaceCalled.get() === DEFAULT_ENSURE_FREE_SPACE_CALLED, + "ensure free space should not have been called!") + } +} + +private object MemoryManagerSuite { + private val DEFAULT_ENSURE_FREE_SPACE_CALLED = -1L +} diff --git a/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala index c436a8b5c9f81..6cae1f871e24b 100644 --- a/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala @@ -19,32 +19,44 @@ package org.apache.spark.memory import scala.collection.mutable.ArrayBuffer -import org.mockito.Mockito.{mock, reset, verify, when} -import org.mockito.Matchers.{any, eq => meq} +import org.mockito.Mockito.when +import org.apache.spark.SparkConf import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore, TestBlockId} -import org.apache.spark.{SparkConf, SparkFunSuite} -class StaticMemoryManagerSuite extends SparkFunSuite { +class StaticMemoryManagerSuite extends MemoryManagerSuite { private val conf = new SparkConf().set("spark.storage.unrollFraction", "0.4") + private val evictedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] + + /** + * Make a [[StaticMemoryManager]] and a [[MemoryStore]] with limited class dependencies. + */ + private def makeThings( + maxExecutionMem: Long, + maxStorageMem: Long): (StaticMemoryManager, MemoryStore) = { + val mm = new StaticMemoryManager( + conf, maxExecutionMemory = maxExecutionMem, maxStorageMemory = maxStorageMem) + val ms = makeMemoryStore(mm) + (mm, ms) + } test("basic execution memory") { val maxExecutionMem = 1000L val (mm, _) = makeThings(maxExecutionMem, Long.MaxValue) assert(mm.executionMemoryUsed === 0L) - assert(mm.acquireExecutionMemory(10L) === 10L) + assert(mm.acquireExecutionMemory(10L, evictedBlocks) === 10L) assert(mm.executionMemoryUsed === 10L) - assert(mm.acquireExecutionMemory(100L) === 100L) + assert(mm.acquireExecutionMemory(100L, evictedBlocks) === 100L) // Acquire up to the max - assert(mm.acquireExecutionMemory(1000L) === 890L) + assert(mm.acquireExecutionMemory(1000L, evictedBlocks) === 890L) assert(mm.executionMemoryUsed === maxExecutionMem) - assert(mm.acquireExecutionMemory(1L) === 0L) + assert(mm.acquireExecutionMemory(1L, evictedBlocks) === 0L) assert(mm.executionMemoryUsed === maxExecutionMem) mm.releaseExecutionMemory(800L) assert(mm.executionMemoryUsed === 200L) // Acquire after release - assert(mm.acquireExecutionMemory(1L) === 1L) + assert(mm.acquireExecutionMemory(1L, evictedBlocks) === 1L) assert(mm.executionMemoryUsed === 201L) // Release beyond what was acquired mm.releaseExecutionMemory(maxExecutionMem) @@ -54,37 +66,36 @@ class StaticMemoryManagerSuite extends SparkFunSuite { test("basic storage memory") { val maxStorageMem = 1000L val dummyBlock = TestBlockId("you can see the world you brought to live") - val evictedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] val (mm, ms) = makeThings(Long.MaxValue, maxStorageMem) assert(mm.storageMemoryUsed === 0L) assert(mm.acquireStorageMemory(dummyBlock, 10L, evictedBlocks)) // `ensureFreeSpace` should be called with the number of bytes requested - assertEnsureFreeSpaceCalled(ms, dummyBlock, 10L) + assertEnsureFreeSpaceCalled(ms, 10L) assert(mm.storageMemoryUsed === 10L) - assert(evictedBlocks.isEmpty) assert(mm.acquireStorageMemory(dummyBlock, 100L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, dummyBlock, 100L) + assertEnsureFreeSpaceCalled(ms, 100L) assert(mm.storageMemoryUsed === 110L) - // Acquire up to the max, not granted - assert(!mm.acquireStorageMemory(dummyBlock, 1000L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, dummyBlock, 1000L) + // Acquire more than the max, not granted + assert(!mm.acquireStorageMemory(dummyBlock, maxStorageMem + 1L, evictedBlocks)) + assertEnsureFreeSpaceCalled(ms, maxStorageMem + 1L) assert(mm.storageMemoryUsed === 110L) - assert(mm.acquireStorageMemory(dummyBlock, 890L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, dummyBlock, 890L) + // Acquire up to the max, requests after this are still granted due to LRU eviction + assert(mm.acquireStorageMemory(dummyBlock, maxStorageMem, evictedBlocks)) + assertEnsureFreeSpaceCalled(ms, 1000L) assert(mm.storageMemoryUsed === 1000L) - assert(!mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, dummyBlock, 1L) + assert(mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks)) + assertEnsureFreeSpaceCalled(ms, 1L) assert(mm.storageMemoryUsed === 1000L) mm.releaseStorageMemory(800L) assert(mm.storageMemoryUsed === 200L) // Acquire after release assert(mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, dummyBlock, 1L) + assertEnsureFreeSpaceCalled(ms, 1L) assert(mm.storageMemoryUsed === 201L) - mm.releaseStorageMemory() + mm.releaseAllStorageMemory() assert(mm.storageMemoryUsed === 0L) assert(mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, dummyBlock, 1L) + assertEnsureFreeSpaceCalled(ms, 1L) assert(mm.storageMemoryUsed === 1L) // Release beyond what was acquired mm.releaseStorageMemory(100L) @@ -95,18 +106,17 @@ class StaticMemoryManagerSuite extends SparkFunSuite { val maxExecutionMem = 200L val maxStorageMem = 1000L val dummyBlock = TestBlockId("ain't nobody love like you do") - val dummyBlocks = new ArrayBuffer[(BlockId, BlockStatus)] val (mm, ms) = makeThings(maxExecutionMem, maxStorageMem) // Only execution memory should increase - assert(mm.acquireExecutionMemory(100L) === 100L) + assert(mm.acquireExecutionMemory(100L, evictedBlocks) === 100L) assert(mm.storageMemoryUsed === 0L) assert(mm.executionMemoryUsed === 100L) - assert(mm.acquireExecutionMemory(1000L) === 100L) + assert(mm.acquireExecutionMemory(1000L, evictedBlocks) === 100L) assert(mm.storageMemoryUsed === 0L) assert(mm.executionMemoryUsed === 200L) // Only storage memory should increase - assert(mm.acquireStorageMemory(dummyBlock, 50L, dummyBlocks)) - assertEnsureFreeSpaceCalled(ms, dummyBlock, 50L) + assert(mm.acquireStorageMemory(dummyBlock, 50L, evictedBlocks)) + assertEnsureFreeSpaceCalled(ms, 50L) assert(mm.storageMemoryUsed === 50L) assert(mm.executionMemoryUsed === 200L) // Only execution memory should be released @@ -114,7 +124,7 @@ class StaticMemoryManagerSuite extends SparkFunSuite { assert(mm.storageMemoryUsed === 50L) assert(mm.executionMemoryUsed === 67L) // Only storage memory should be released - mm.releaseStorageMemory() + mm.releaseAllStorageMemory() assert(mm.storageMemoryUsed === 0L) assert(mm.executionMemoryUsed === 67L) } @@ -122,51 +132,26 @@ class StaticMemoryManagerSuite extends SparkFunSuite { test("unroll memory") { val maxStorageMem = 1000L val dummyBlock = TestBlockId("lonely water") - val dummyBlocks = new ArrayBuffer[(BlockId, BlockStatus)] val (mm, ms) = makeThings(Long.MaxValue, maxStorageMem) - assert(mm.acquireUnrollMemory(dummyBlock, 100L, dummyBlocks)) - assertEnsureFreeSpaceCalled(ms, dummyBlock, 100L) + assert(mm.acquireUnrollMemory(dummyBlock, 100L, evictedBlocks)) + assertEnsureFreeSpaceCalled(ms, 100L) assert(mm.storageMemoryUsed === 100L) mm.releaseUnrollMemory(40L) assert(mm.storageMemoryUsed === 60L) when(ms.currentUnrollMemory).thenReturn(60L) - assert(mm.acquireUnrollMemory(dummyBlock, 500L, dummyBlocks)) + assert(mm.acquireUnrollMemory(dummyBlock, 500L, evictedBlocks)) // `spark.storage.unrollFraction` is 0.4, so the max unroll space is 400 bytes. // Since we already occupy 60 bytes, we will try to ensure only 400 - 60 = 340 bytes. - assertEnsureFreeSpaceCalled(ms, dummyBlock, 340L) + assertEnsureFreeSpaceCalled(ms, 340L) assert(mm.storageMemoryUsed === 560L) when(ms.currentUnrollMemory).thenReturn(560L) - assert(!mm.acquireUnrollMemory(dummyBlock, 800L, dummyBlocks)) + assert(!mm.acquireUnrollMemory(dummyBlock, 800L, evictedBlocks)) assert(mm.storageMemoryUsed === 560L) // We already have 560 bytes > the max unroll space of 400 bytes, so no bytes are freed - assertEnsureFreeSpaceCalled(ms, dummyBlock, 0L) + assertEnsureFreeSpaceCalled(ms, 0L) // Release beyond what was acquired mm.releaseUnrollMemory(maxStorageMem) assert(mm.storageMemoryUsed === 0L) } - /** - * Make a [[StaticMemoryManager]] and a [[MemoryStore]] with limited class dependencies. - */ - private def makeThings( - maxExecutionMem: Long, - maxStorageMem: Long): (StaticMemoryManager, MemoryStore) = { - val mm = new StaticMemoryManager( - conf, maxExecutionMemory = maxExecutionMem, maxStorageMemory = maxStorageMem) - val ms = mock(classOf[MemoryStore]) - mm.setMemoryStore(ms) - (mm, ms) - } - - /** - * Assert that [[MemoryStore.ensureFreeSpace]] is called with the given parameters. - */ - private def assertEnsureFreeSpaceCalled( - ms: MemoryStore, - blockId: BlockId, - numBytes: Long): Unit = { - verify(ms).ensureFreeSpace(meq(blockId), meq(numBytes: java.lang.Long), any()) - reset(ms) - } - } diff --git a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala new file mode 100644 index 0000000000000..e7baa50dc2cd0 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.memory + +import scala.collection.mutable.ArrayBuffer + +import org.scalatest.PrivateMethodTester + +import org.apache.spark.SparkConf +import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore, TestBlockId} + + +class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTester { + private val conf = new SparkConf().set("spark.memory.storageFraction", "0.5") + private val dummyBlock = TestBlockId("--") + private val evictedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] + + /** + * Make a [[UnifiedMemoryManager]] and a [[MemoryStore]] with limited class dependencies. + */ + private def makeThings(maxMemory: Long): (UnifiedMemoryManager, MemoryStore) = { + val mm = new UnifiedMemoryManager(conf, maxMemory) + val ms = makeMemoryStore(mm) + (mm, ms) + } + + private def getStorageRegionSize(mm: UnifiedMemoryManager): Long = { + mm invokePrivate PrivateMethod[Long]('storageRegionSize)() + } + + test("storage region size") { + val maxMemory = 1000L + val (mm, _) = makeThings(maxMemory) + val storageFraction = conf.get("spark.memory.storageFraction").toDouble + val expectedStorageRegionSize = maxMemory * storageFraction + val actualStorageRegionSize = getStorageRegionSize(mm) + assert(expectedStorageRegionSize === actualStorageRegionSize) + } + + test("basic execution memory") { + val maxMemory = 1000L + val (mm, _) = makeThings(maxMemory) + assert(mm.executionMemoryUsed === 0L) + assert(mm.acquireExecutionMemory(10L, evictedBlocks) === 10L) + assert(mm.executionMemoryUsed === 10L) + assert(mm.acquireExecutionMemory(100L, evictedBlocks) === 100L) + // Acquire up to the max + assert(mm.acquireExecutionMemory(1000L, evictedBlocks) === 890L) + assert(mm.executionMemoryUsed === maxMemory) + assert(mm.acquireExecutionMemory(1L, evictedBlocks) === 0L) + assert(mm.executionMemoryUsed === maxMemory) + mm.releaseExecutionMemory(800L) + assert(mm.executionMemoryUsed === 200L) + // Acquire after release + assert(mm.acquireExecutionMemory(1L, evictedBlocks) === 1L) + assert(mm.executionMemoryUsed === 201L) + // Release beyond what was acquired + mm.releaseExecutionMemory(maxMemory) + assert(mm.executionMemoryUsed === 0L) + } + + test("basic storage memory") { + val maxMemory = 1000L + val (mm, ms) = makeThings(maxMemory) + assert(mm.storageMemoryUsed === 0L) + assert(mm.acquireStorageMemory(dummyBlock, 10L, evictedBlocks)) + // `ensureFreeSpace` should be called with the number of bytes requested + assertEnsureFreeSpaceCalled(ms, 10L) + assert(mm.storageMemoryUsed === 10L) + assert(mm.acquireStorageMemory(dummyBlock, 100L, evictedBlocks)) + assertEnsureFreeSpaceCalled(ms, 100L) + assert(mm.storageMemoryUsed === 110L) + // Acquire more than the max, not granted + assert(!mm.acquireStorageMemory(dummyBlock, maxMemory + 1L, evictedBlocks)) + assertEnsureFreeSpaceCalled(ms, maxMemory + 1L) + assert(mm.storageMemoryUsed === 110L) + // Acquire up to the max, requests after this are still granted due to LRU eviction + assert(mm.acquireStorageMemory(dummyBlock, maxMemory, evictedBlocks)) + assertEnsureFreeSpaceCalled(ms, 1000L) + assert(mm.storageMemoryUsed === 1000L) + assert(mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks)) + assertEnsureFreeSpaceCalled(ms, 1L) + assert(mm.storageMemoryUsed === 1000L) + mm.releaseStorageMemory(800L) + assert(mm.storageMemoryUsed === 200L) + // Acquire after release + assert(mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks)) + assertEnsureFreeSpaceCalled(ms, 1L) + assert(mm.storageMemoryUsed === 201L) + mm.releaseAllStorageMemory() + assert(mm.storageMemoryUsed === 0L) + assert(mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks)) + assertEnsureFreeSpaceCalled(ms, 1L) + assert(mm.storageMemoryUsed === 1L) + // Release beyond what was acquired + mm.releaseStorageMemory(100L) + assert(mm.storageMemoryUsed === 0L) + } + + test("execution evicts storage") { + val maxMemory = 1000L + val (mm, ms) = makeThings(maxMemory) + // First, ensure the test classes are set up as expected + val expectedStorageRegionSize = 500L + val expectedExecutionRegionSize = 500L + val storageRegionSize = getStorageRegionSize(mm) + val executionRegionSize = maxMemory - expectedStorageRegionSize + require(storageRegionSize === expectedStorageRegionSize, + "bad test: storage region size is unexpected") + require(executionRegionSize === expectedExecutionRegionSize, + "bad test: storage region size is unexpected") + // Acquire enough storage memory to exceed the storage region + assert(mm.acquireStorageMemory(dummyBlock, 750L, evictedBlocks)) + assertEnsureFreeSpaceCalled(ms, 750L) + assert(mm.executionMemoryUsed === 0L) + assert(mm.storageMemoryUsed === 750L) + require(mm.storageMemoryUsed > storageRegionSize, + s"bad test: storage memory used should exceed the storage region") + // Execution needs to request 250 bytes to evict storage memory + assert(mm.acquireExecutionMemory(100L, evictedBlocks) === 100L) + assert(mm.executionMemoryUsed === 100L) + assert(mm.storageMemoryUsed === 750L) + assertEnsureFreeSpaceNotCalled(ms) + // Execution wants 200 bytes but only 150 are free, so storage is evicted + assert(mm.acquireExecutionMemory(200L, evictedBlocks) === 200L) + assertEnsureFreeSpaceCalled(ms, 200L) + assert(mm.executionMemoryUsed === 300L) + mm.releaseAllStorageMemory() + require(mm.executionMemoryUsed < executionRegionSize, + s"bad test: execution memory used should be within the execution region") + require(mm.storageMemoryUsed === 0, "bad test: all storage memory should have been released") + // Acquire some storage memory again, but this time keep it within the storage region + assert(mm.acquireStorageMemory(dummyBlock, 400L, evictedBlocks)) + assertEnsureFreeSpaceCalled(ms, 400L) + require(mm.storageMemoryUsed < storageRegionSize, + s"bad test: storage memory used should be within the storage region") + // Execution cannot evict storage because the latter is within the storage fraction, + // so grant only what's remaining without evicting anything, i.e. 1000 - 300 - 400 = 300 + assert(mm.acquireExecutionMemory(400L, evictedBlocks) === 300L) + assert(mm.executionMemoryUsed === 600L) + assert(mm.storageMemoryUsed === 400L) + assertEnsureFreeSpaceNotCalled(ms) + } + + test("storage does not evict execution") { + val maxMemory = 1000L + val (mm, ms) = makeThings(maxMemory) + // First, ensure the test classes are set up as expected + val expectedStorageRegionSize = 500L + val expectedExecutionRegionSize = 500L + val storageRegionSize = getStorageRegionSize(mm) + val executionRegionSize = maxMemory - expectedStorageRegionSize + require(storageRegionSize === expectedStorageRegionSize, + "bad test: storage region size is unexpected") + require(executionRegionSize === expectedExecutionRegionSize, + "bad test: storage region size is unexpected") + // Acquire enough execution memory to exceed the execution region + assert(mm.acquireExecutionMemory(800L, evictedBlocks) === 800L) + assert(mm.executionMemoryUsed === 800L) + assert(mm.storageMemoryUsed === 0L) + assertEnsureFreeSpaceNotCalled(ms) + require(mm.executionMemoryUsed > executionRegionSize, + s"bad test: execution memory used should exceed the execution region") + // Storage should not be able to evict execution + assert(mm.acquireStorageMemory(dummyBlock, 100L, evictedBlocks)) + assert(mm.executionMemoryUsed === 800L) + assert(mm.storageMemoryUsed === 100L) + assertEnsureFreeSpaceCalled(ms, 100L) + assert(!mm.acquireStorageMemory(dummyBlock, 250L, evictedBlocks)) + assert(mm.executionMemoryUsed === 800L) + assert(mm.storageMemoryUsed === 100L) + assertEnsureFreeSpaceCalled(ms, 250L) + mm.releaseExecutionMemory(maxMemory) + mm.releaseStorageMemory(maxMemory) + // Acquire some execution memory again, but this time keep it within the execution region + assert(mm.acquireExecutionMemory(200L, evictedBlocks) === 200L) + assert(mm.executionMemoryUsed === 200L) + assert(mm.storageMemoryUsed === 0L) + assertEnsureFreeSpaceNotCalled(ms) + require(mm.executionMemoryUsed < executionRegionSize, + s"bad test: execution memory used should be within the execution region") + // Storage should still not be able to evict execution + assert(mm.acquireStorageMemory(dummyBlock, 750L, evictedBlocks)) + assert(mm.executionMemoryUsed === 200L) + assert(mm.storageMemoryUsed === 750L) + assertEnsureFreeSpaceCalled(ms, 750L) + assert(!mm.acquireStorageMemory(dummyBlock, 850L, evictedBlocks)) + assert(mm.executionMemoryUsed === 200L) + assert(mm.storageMemoryUsed === 750L) + assertEnsureFreeSpaceCalled(ms, 850L) + } + +} diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala index 6d45b1a101be6..5877aa042d4af 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala @@ -24,7 +24,8 @@ import org.mockito.Mockito._ import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ -import org.apache.spark.{SparkConf, SparkFunSuite, TaskContext} +import org.apache.spark.{SparkFunSuite, TaskContext} +import org.apache.spark.executor.TaskMetrics class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { @@ -37,7 +38,9 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { try { val taskAttemptId = nextTaskAttemptId.getAndIncrement val mockTaskContext = mock(classOf[TaskContext], RETURNS_SMART_NULLS) + val taskMetrics = new TaskMetrics when(mockTaskContext.taskAttemptId()).thenReturn(taskAttemptId) + when(mockTaskContext.taskMetrics()).thenReturn(taskMetrics) TaskContext.setTaskContext(mockTaskContext) body } finally { diff --git a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala index 6351539e91e97..259020a2ddc34 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala @@ -36,9 +36,6 @@ class UnsafeShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { override def beforeAll() { conf.set("spark.shuffle.manager", "tungsten-sort") - // UnsafeShuffleManager requires at least 128 MB of memory per task in order to be able to sort - // shuffle records. - conf.set("spark.shuffle.memoryFraction", "0.5") } test("UnsafeShuffleManager properly cleans up files for shuffles that use the new shuffle path") { diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index 12e9bafcc92c1..0a03c32c647ae 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -22,6 +22,8 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark._ import org.apache.spark.io.CompressionCodec +// TODO: some of these spilling tests probably aren't actually spilling (SPARK-11078) + class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { private val allCompressionCodecs = CompressionCodec.ALL_COMPRESSION_CODECS private def createCombiner[T](i: T) = ArrayBuffer[T](i) @@ -243,7 +245,6 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { */ private def testSimpleSpilling(codec: Option[String] = None): Unit = { val conf = createSparkConf(loadDefaults = true, codec) // Load defaults for Spark home - conf.set("spark.shuffle.memoryFraction", "0.001") sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) // reduceByKey - should spill ~8 times @@ -291,7 +292,6 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { test("spilling with hash collisions") { val conf = createSparkConf(loadDefaults = true) - conf.set("spark.shuffle.memoryFraction", "0.001") sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) val map = createExternalMap[String] @@ -340,7 +340,6 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { test("spilling with many hash collisions") { val conf = createSparkConf(loadDefaults = true) - conf.set("spark.shuffle.memoryFraction", "0.0001") sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) val map = new ExternalAppendOnlyMap[FixedHashObject, Int, Int](_ => 1, _ + _, _ + _) @@ -365,7 +364,6 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { test("spilling with hash collisions using the Int.MaxValue key") { val conf = createSparkConf(loadDefaults = true) - conf.set("spark.shuffle.memoryFraction", "0.001") sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) val map = createExternalMap[Int] @@ -382,7 +380,6 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { test("spilling with null keys and values") { val conf = createSparkConf(loadDefaults = true) - conf.set("spark.shuffle.memoryFraction", "0.001") sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) val map = createExternalMap[Int] @@ -401,8 +398,8 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { test("external aggregation updates peak execution memory") { val conf = createSparkConf(loadDefaults = false) - .set("spark.shuffle.memoryFraction", "0.001") .set("spark.shuffle.manager", "hash") // make sure we're not also using ExternalSorter + .set("spark.testing.memory", (10 * 1024 * 1024).toString) sc = new SparkContext("local", "test", conf) // No spilling AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external map without spilling") { diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index bdb0f4d507a7e..651c7eaa65ff5 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -24,6 +24,8 @@ import scala.util.Random import org.apache.spark._ import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} +// TODO: some of these spilling tests probably aren't actually spilling (SPARK-11078) + class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { private def createSparkConf(loadDefaults: Boolean, kryo: Boolean): SparkConf = { val conf = new SparkConf(loadDefaults) @@ -38,6 +40,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { conf.set("spark.shuffle.sort.bypassMergeThreshold", "0") // Ensure that we actually have multiple batches per spill file conf.set("spark.shuffle.spill.batchSize", "10") + conf.set("spark.testing.memory", "2000000") conf } @@ -50,7 +53,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { } def emptyDataStream(conf: SparkConf) { - conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -91,7 +93,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { } def fewElementsPerPartition(conf: SparkConf) { - conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -140,7 +141,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { } def emptyPartitionsWithSpilling(conf: SparkConf) { - conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.spill.initialMemoryThreshold", "512") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -174,7 +174,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { } def testSpillingInLocalCluster(conf: SparkConf) { - conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) @@ -252,7 +251,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { } def spillingInLocalClusterWithManyReduceTasks(conf: SparkConf) { - conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) @@ -323,7 +321,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { test("cleanup of intermediate files in sorter") { val conf = createSparkConf(true, false) // Load defaults, otherwise SPARK_HOME is not found - conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager @@ -348,7 +345,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { test("cleanup of intermediate files in sorter if there are errors") { val conf = createSparkConf(true, false) // Load defaults, otherwise SPARK_HOME is not found - conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager @@ -372,7 +368,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { test("cleanup of intermediate files in shuffle") { val conf = createSparkConf(false, false) - conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager @@ -387,7 +382,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { test("cleanup of intermediate files in shuffle with errors") { val conf = createSparkConf(false, false) - conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager @@ -416,7 +410,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { } def noPartialAggregationOrSorting(conf: SparkConf) { - conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -438,7 +431,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { } def partialAggregationWithoutSpill(conf: SparkConf) { - conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -461,7 +453,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { } def partialAggregationWIthSpillNoOrdering(conf: SparkConf) { - conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -485,7 +476,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { } def partialAggregationWithSpillWithOrdering(conf: SparkConf) { - conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -512,7 +502,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { } def sortingWithoutAggregationNoSpill(conf: SparkConf) { - conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -536,7 +525,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { } def sortingWithoutAggregationWithSpill(conf: SparkConf) { - conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -553,7 +541,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { test("spilling with hash collisions") { val conf = createSparkConf(true, false) - conf.set("spark.shuffle.memoryFraction", "0.001") sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) def createCombiner(i: String): ArrayBuffer[String] = ArrayBuffer[String](i) @@ -610,7 +597,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { test("spilling with many hash collisions") { val conf = createSparkConf(true, false) - conf.set("spark.shuffle.memoryFraction", "0.0001") sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) val agg = new Aggregator[FixedHashObject, Int, Int](_ => 1, _ + _, _ + _) @@ -633,7 +619,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { test("spilling with hash collisions using the Int.MaxValue key") { val conf = createSparkConf(true, false) - conf.set("spark.shuffle.memoryFraction", "0.001") sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) def createCombiner(i: Int): ArrayBuffer[Int] = ArrayBuffer[Int](i) @@ -657,7 +642,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { test("spilling with null keys and values") { val conf = createSparkConf(true, false) - conf.set("spark.shuffle.memoryFraction", "0.001") sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) def createCombiner(i: String): ArrayBuffer[String] = ArrayBuffer[String](i) @@ -693,7 +677,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { } private def sortWithoutBreakingSortingContracts(conf: SparkConf) { - conf.set("spark.shuffle.memoryFraction", "0.01") conf.set("spark.shuffle.manager", "sort") sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) diff --git a/docs/configuration.md b/docs/configuration.md index 154a3aee6855a..771d93be04b06 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -445,17 +445,6 @@ Apart from these, the following properties are also available, and may be useful met. - - spark.shuffle.memoryFraction - 0.2 - - Fraction of Java heap to use for aggregation and cogroups during shuffles. - At any given time, the collective size of - all in-memory maps used for shuffles is bounded by this limit, beyond which the contents will - begin to spill to disk. If spills are often, consider increasing this value at the expense of - spark.storage.memoryFraction. - - spark.shuffle.service.enabled false @@ -712,6 +701,76 @@ Apart from these, the following properties are also available, and may be useful +#### Memory Management + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Property NameDefaultMeaning
spark.memory.fraction0.75 + Fraction of the heap space used for execution and storage. The lower this is, the more + frequently spills and cached data eviction occur. The purpose of this config is to set + aside memory for internal metadata, user data structures, and imprecise size estimation + in the case of sparse, unusually large records. +
spark.memory.storageFraction0.5 + T​he size of the storage region within the space set aside by + s​park.memory.fraction. This region is not statically reserved, but dynamically + allocated as cache requests come in. ​Cached data may be evicted only if total storage exceeds + this region. +
spark.memory.useLegacyModefalse + ​Whether to enable the legacy memory management mode used in Spark 1.5 and before. + The legacy mode rigidly partitions the heap space into fixed-size regions, + potentially leading to excessive spilling if the application was not tuned. + The following deprecated memory fraction configurations are not read unless this is enabled: + spark.shuffle.memoryFraction
+ spark.storage.memoryFraction
+ spark.storage.unrollFraction +
spark.shuffle.memoryFraction0.2 + (deprecated) This is read only if spark.memory.useLegacyMode is enabled. + Fraction of Java heap to use for aggregation and cogroups during shuffles. + At any given time, the collective size of + all in-memory maps used for shuffles is bounded by this limit, beyond which the contents will + begin to spill to disk. If spills are often, consider increasing this value at the expense of + spark.storage.memoryFraction. +
spark.storage.memoryFraction0.6 + (deprecated) This is read only if spark.memory.useLegacyMode is enabled. + Fraction of Java heap to use for Spark's memory cache. This should not be larger than the "old" + generation of objects in the JVM, which by default is given 0.6 of the heap, but you can + increase it if you configure your own old generation size. +
spark.storage.unrollFraction0.2 + (deprecated) This is read only if spark.memory.useLegacyMode is enabled. + Fraction of spark.storage.memoryFraction to use for unrolling blocks in memory. + This is dynamically allocated by dropping existing blocks when there is not enough free + storage space to unroll the new block in its entirety. +
+ #### Execution Behavior @@ -824,15 +883,6 @@ Apart from these, the following properties are also available, and may be useful This setting is ignored for jobs generated through Spark Streaming's StreamingContext, since data may need to be rewritten to pre-existing output directories during checkpoint recovery. - - - - - @@ -842,15 +892,6 @@ Apart from these, the following properties are also available, and may be useful mapping has high overhead for blocks close to or below the page size of the operating system. - - - - - diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala index ff65d7bdf8b92..835f52fa566a2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala @@ -57,7 +57,9 @@ class TestShuffleMemoryManager } private class GrantEverythingMemoryManager extends MemoryManager { - override def acquireExecutionMemory(numBytes: Long): Long = numBytes + override def acquireExecutionMemory( + numBytes: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long = numBytes override def acquireStorageMemory( blockId: BlockId, numBytes: Long, @@ -66,12 +68,6 @@ private class GrantEverythingMemoryManager extends MemoryManager { blockId: BlockId, numBytes: Long, evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = true - override def releaseExecutionMemory(numBytes: Long): Unit = { } - override def releaseStorageMemory(numBytes: Long): Unit = { } - override def releaseStorageMemory(): Unit = { } - override def releaseUnrollMemory(numBytes: Long): Unit = { } override def maxExecutionMemory: Long = Long.MaxValue override def maxStorageMemory: Long = Long.MaxValue - override def executionMemoryUsed: Long = Long.MaxValue - override def storageMemoryUsed: Long = Long.MaxValue } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index f7d48bc53ebbc..75d1fced594c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -103,7 +103,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { val conf = new SparkConf() .set("spark.shuffle.spill.initialMemoryThreshold", "1024") .set("spark.shuffle.sort.bypassMergeThreshold", "0") - .set("spark.shuffle.memoryFraction", "0.0001") + .set("spark.testing.memory", "80000") sc = new SparkContext("local", "test", conf) outputFile = File.createTempFile("test-unsafe-row-serializer-spill", "") From 0d1b73b78b600420121ea8e58ff659ae8b4feebe Mon Sep 17 00:00:00 2001 From: trystanleftwich Date: Tue, 13 Oct 2015 22:11:08 +0100 Subject: [PATCH 040/139] =?UTF-8?q?[SPARK-11052]=20Spaces=20in=20the=20bui?= =?UTF-8?q?ld=20dir=20causes=20failures=20in=20the=20build/mv=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …n script Author: trystanleftwich Closes #9065 from trystanleftwich/SPARK-11052. --- build/mvn | 10 +++++----- make-distribution.sh | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/build/mvn b/build/mvn index ec0380afad319..7603ea03deb73 100755 --- a/build/mvn +++ b/build/mvn @@ -104,8 +104,8 @@ install_scala() { "scala-${scala_version}.tgz" \ "scala-${scala_version}/bin/scala" - SCALA_COMPILER="$(cd "$(dirname ${scala_bin})/../lib" && pwd)/scala-compiler.jar" - SCALA_LIBRARY="$(cd "$(dirname ${scala_bin})/../lib" && pwd)/scala-library.jar" + SCALA_COMPILER="$(cd "$(dirname "${scala_bin}")/../lib" && pwd)/scala-compiler.jar" + SCALA_LIBRARY="$(cd "$(dirname "${scala_bin}")/../lib" && pwd)/scala-library.jar" } # Setup healthy defaults for the Zinc port if none were provided from @@ -135,10 +135,10 @@ cd "${_CALLING_DIR}" # Now that zinc is ensured to be installed, check its status and, if its # not running or just installed, start it -if [ -n "${ZINC_INSTALL_FLAG}" -o -z "`${ZINC_BIN} -status -port ${ZINC_PORT}`" ]; then +if [ -n "${ZINC_INSTALL_FLAG}" -o -z "`"${ZINC_BIN}" -status -port ${ZINC_PORT}`" ]; then export ZINC_OPTS=${ZINC_OPTS:-"$_COMPILE_JVM_OPTS"} - ${ZINC_BIN} -shutdown -port ${ZINC_PORT} - ${ZINC_BIN} -start -port ${ZINC_PORT} \ + "${ZINC_BIN}" -shutdown -port ${ZINC_PORT} + "${ZINC_BIN}" -start -port ${ZINC_PORT} \ -scala-compiler "${SCALA_COMPILER}" \ -scala-library "${SCALA_LIBRARY}" &>/dev/null fi diff --git a/make-distribution.sh b/make-distribution.sh index 62c0ba6df7d3f..24418ace26270 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -121,7 +121,7 @@ if [ $(command -v git) ]; then fi -if [ ! $(command -v "$MVN") ] ; then +if [ ! "$(command -v "$MVN")" ] ; then echo -e "Could not locate Maven command: '$MVN'." echo -e "Specify the Maven command with the --mvn flag" exit -1; From ef72673b234579c161b8cbb6cafc851d9eba1bfb Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 13 Oct 2015 15:09:31 -0700 Subject: [PATCH 041/139] [SPARK-11080] [SQL] Incorporate per-JVM id into ExprId to prevent unsafe cross-JVM comparisions In the current implementation of named expressions' `ExprIds`, we rely on a per-JVM AtomicLong to ensure that expression ids are unique within a JVM. However, these expression ids will not be _globally_ unique. This opens the potential for id collisions if new expression ids happen to be created inside of tasks rather than on the driver. There are currently a few cases where tasks allocate expression ids, which happen to be safe because those expressions are never compared to expressions created on the driver. In order to guard against the introduction of invalid comparisons between driver-created and executor-created expression ids, this patch extends `ExprId` to incorporate a UUID to identify the JVM that created the id, which prevents collisions. Author: Josh Rosen Closes #9093 from JoshRosen/SPARK-11080. --- .../catalyst/expressions/namedExpressions.scala | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 5768c6087db32..8957df0be6814 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import java.util.UUID + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -24,16 +26,23 @@ import org.apache.spark.sql.types._ object NamedExpression { private val curId = new java.util.concurrent.atomic.AtomicLong() - def newExprId: ExprId = ExprId(curId.getAndIncrement()) + private[expressions] val jvmId = UUID.randomUUID() + def newExprId: ExprId = ExprId(curId.getAndIncrement(), jvmId) def unapply(expr: NamedExpression): Option[(String, DataType)] = Some(expr.name, expr.dataType) } /** - * A globally unique (within this JVM) id for a given named expression. + * A globally unique id for a given named expression. * Used to identify which attribute output by a relation is being * referenced in a subsequent computation. + * + * The `id` field is unique within a given JVM, while the `uuid` is used to uniquely identify JVMs. */ -case class ExprId(id: Long) +case class ExprId(id: Long, jvmId: UUID) + +object ExprId { + def apply(id: Long): ExprId = ExprId(id, NamedExpression.jvmId) +} /** * An [[Expression]] that is named. From d0482f6af33e976db237405b2a978db1b7c2fd5b Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 13 Oct 2015 15:18:20 -0700 Subject: [PATCH 042/139] [SPARK-10932] [PROJECT INFRA] Port two minor changes to release-build.sh from scripts' old repo Spark's release packaging scripts used to live in a separate repository. Although these scripts are now part of the Spark repo, there are some minor patches made against the old repos that are missing in Spark's copy of the script. This PR ports those changes. /cc shivaram, who originally submitted these changes against https://github.com/rxin/spark-utils Author: Josh Rosen Closes #8986 from JoshRosen/port-release-build-fixes-from-rxin-repo. --- dev/create-release/release-build.sh | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index 9dac43ce54425..cb79e9eba06e2 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -70,7 +70,7 @@ GIT_REF=${GIT_REF:-master} # Destination directory parent on remote server REMOTE_PARENT_DIR=${REMOTE_PARENT_DIR:-/home/$ASF_USERNAME/public_html} -SSH="ssh -o StrictHostKeyChecking=no -i $ASF_RSA_KEY" +SSH="ssh -o ConnectTimeout=300 -o StrictHostKeyChecking=no -i $ASF_RSA_KEY" GPG="gpg --no-tty --batch" NEXUS_ROOT=https://repository.apache.org/service/local/staging NEXUS_PROFILE=d63f592e7eac0 # Profile for Spark staging uploads @@ -141,8 +141,12 @@ if [[ "$1" == "package" ]]; then export ZINC_PORT=$ZINC_PORT echo "Creating distribution: $NAME ($FLAGS)" - ./make-distribution.sh --name $NAME --tgz $FLAGS -DzincPort=$ZINC_PORT 2>&1 > \ - ../binary-release-$NAME.log + + # Get maven home set by MVN + MVN_HOME=`$MVN -version 2>&1 | grep 'Maven home' | awk '{print $NF}'` + + ./make-distribution.sh --name $NAME --mvn $MVN_HOME/bin/mvn --tgz $FLAGS \ + -DzincPort=$ZINC_PORT 2>&1 > ../binary-release-$NAME.log cd .. cp spark-$SPARK_VERSION-bin-$NAME/spark-$SPARK_VERSION-bin-$NAME.tgz . From 3889b1c7a96da1111946fa63ad69489b83468646 Mon Sep 17 00:00:00 2001 From: vectorijk Date: Tue, 13 Oct 2015 15:57:36 -0700 Subject: [PATCH 043/139] [SPARK-11059] [ML] Change range of quantile probabilities in AFTSurvivalRegression Value of the quantile probabilities array should be in the range (0, 1) instead of [0,1] in `AFTSurvivalRegression.scala` according to [Discussion] (https://github.com/apache/spark/pull/8926#discussion-diff-40698242) Author: vectorijk Closes #9083 from vectorijk/spark-11059. --- .../apache/spark/ml/regression/AFTSurvivalRegression.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 717caacad30eb..ac2c3d825f13c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -59,14 +59,14 @@ private[regression] trait AFTSurvivalRegressionParams extends Params /** * Param for quantile probabilities array. - * Values of the quantile probabilities array should be in the range [0, 1] + * Values of the quantile probabilities array should be in the range (0, 1) * and the array should be non-empty. * @group param */ @Since("1.6.0") final val quantileProbabilities: DoubleArrayParam = new DoubleArrayParam(this, "quantileProbabilities", "quantile probabilities array", - (t: Array[Double]) => t.forall(ParamValidators.inRange(0, 1)) && t.length > 0) + (t: Array[Double]) => t.forall(ParamValidators.inRange(0, 1, false, false)) && t.length > 0) /** @group getParam */ @Since("1.6.0") From 328d1b3e4bc39cce653342e04f9e08af12dd7ed8 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 13 Oct 2015 17:09:17 -0700 Subject: [PATCH 044/139] [SPARK-11090] [SQL] Constructor for Product types from InternalRow This is a first draft of the ability to construct expressions that will take a catalyst internal row and construct a Product (case class or tuple) that has fields with the correct names. Support include: - Nested classes - Maps - Efficiently handling of arrays of primitive types Not yet supported: - Case classes that require custom collection types (i.e. List instead of Seq). Author: Michael Armbrust Closes #9100 from marmbrus/productContructor. --- .../catalyst/expressions/UnsafeArrayData.java | 4 + .../spark/sql/catalyst/ScalaReflection.scala | 302 +++++++++++++- .../spark/sql/catalyst/encoders/Encoder.scala | 14 + .../catalyst/encoders/ProductEncoder.scala | 26 +- .../sql/catalyst/expressions/objects.scala | 154 +++++++- .../spark/sql/types/ArrayBasedMapData.scala | 4 + .../apache/spark/sql/types/ArrayData.scala | 5 + .../spark/sql/types/GenericArrayData.scala | 4 +- .../encoders/ProductEncoderSuite.scala | 369 +++++++++++------- 9 files changed, 723 insertions(+), 159 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index 796f8abec9a1d..4c63abb071e3b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -74,6 +74,10 @@ private void assertIndexIsValid(int ordinal) { assert ordinal < numElements : "ordinal (" + ordinal + ") should < " + numElements; } + public Object[] array() { + throw new UnsupportedOperationException("Only supported on GenericArrayData."); + } + /** * Construct a new UnsafeArrayData. The resulting UnsafeArrayData won't be usable until * `pointTo()` has been called, since the value returned by this constructor is equivalent diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 8b733f2a0b91f..8edd6498e5163 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst +import org.apache.spark.sql.catalyst.analysis.{UnresolvedExtractValue, UnresolvedAttribute} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -80,6 +81,9 @@ trait ScalaReflection { * Returns the Spark SQL DataType for a given scala type. Where this is not an exact mapping * to a native type, an ObjectType is returned. Special handling is also used for Arrays including * those that hold primitive types. + * + * Unlike `schemaFor`, this function doesn't do any massaging of types into the Spark SQL type + * system. As a result, ObjectType will be returned for things like boxed Integers */ def dataTypeFor(tpe: `Type`): DataType = tpe match { case t if t <:< definitions.IntTpe => IntegerType @@ -114,6 +118,298 @@ trait ScalaReflection { } } + /** + * Given a type `T` this function constructs and ObjectType that holds a class of type + * Array[T]. Special handling is performed for primitive types to map them back to their raw + * JVM form instead of the Scala Array that handles auto boxing. + */ + def arrayClassFor(tpe: `Type`): DataType = { + val cls = tpe match { + case t if t <:< definitions.IntTpe => classOf[Array[Int]] + case t if t <:< definitions.LongTpe => classOf[Array[Long]] + case t if t <:< definitions.DoubleTpe => classOf[Array[Double]] + case t if t <:< definitions.FloatTpe => classOf[Array[Float]] + case t if t <:< definitions.ShortTpe => classOf[Array[Short]] + case t if t <:< definitions.ByteTpe => classOf[Array[Byte]] + case t if t <:< definitions.BooleanTpe => classOf[Array[Boolean]] + case other => + // There is probably a better way to do this, but I couldn't find it... + val elementType = dataTypeFor(other).asInstanceOf[ObjectType].cls + java.lang.reflect.Array.newInstance(elementType, 1).getClass + + } + ObjectType(cls) + } + + /** + * Returns an expression that can be used to construct an object of type `T` given a an input + * row with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes + * of the same name as the constructor arguments. Nested classes will have their fields accessed + * using UnresolvedExtractValue. + */ + def constructorFor[T : TypeTag]: Expression = constructorFor(typeOf[T], None) + + protected def constructorFor( + tpe: `Type`, + path: Option[Expression]): Expression = ScalaReflectionLock.synchronized { + + /** Returns the current path with a sub-field extracted. */ + def addToPath(part: String) = + path + .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) + .getOrElse(UnresolvedAttribute(part)) + + /** Returns the current path or throws an error. */ + def getPath = path.getOrElse(sys.error("Constructors must start at a class type")) + + tpe match { + case t if !dataTypeFor(t).isInstanceOf[ObjectType] => + getPath + + case t if t <:< localTypeOf[Option[_]] => + val TypeRef(_, _, Seq(optType)) = t + val boxedType = optType match { + // For primitive types we must manually box the primitive value. + case t if t <:< definitions.IntTpe => Some(classOf[java.lang.Integer]) + case t if t <:< definitions.LongTpe => Some(classOf[java.lang.Long]) + case t if t <:< definitions.DoubleTpe => Some(classOf[java.lang.Double]) + case t if t <:< definitions.FloatTpe => Some(classOf[java.lang.Float]) + case t if t <:< definitions.ShortTpe => Some(classOf[java.lang.Short]) + case t if t <:< definitions.ByteTpe => Some(classOf[java.lang.Byte]) + case t if t <:< definitions.BooleanTpe => Some(classOf[java.lang.Boolean]) + case _ => None + } + + boxedType.map { boxedType => + val objectType = ObjectType(boxedType) + WrapOption( + objectType, + NewInstance( + boxedType, + getPath :: Nil, + propagateNull = true, + objectType)) + }.getOrElse { + val className: String = optType.erasure.typeSymbol.asClass.fullName + val cls = Utils.classForName(className) + val objectType = ObjectType(cls) + + WrapOption(objectType, constructorFor(optType, path)) + } + + case t if t <:< localTypeOf[java.lang.Integer] => + val boxedType = classOf[java.lang.Integer] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.lang.Long] => + val boxedType = classOf[java.lang.Long] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.lang.Double] => + val boxedType = classOf[java.lang.Double] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.lang.Float] => + val boxedType = classOf[java.lang.Float] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.lang.Short] => + val boxedType = classOf[java.lang.Short] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.lang.Byte] => + val boxedType = classOf[java.lang.Byte] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.lang.Boolean] => + val boxedType = classOf[java.lang.Boolean] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.sql.Date] => + StaticInvoke( + DateTimeUtils, + ObjectType(classOf[java.sql.Date]), + "toJavaDate", + getPath :: Nil, + propagateNull = true) + + case t if t <:< localTypeOf[java.sql.Timestamp] => + StaticInvoke( + DateTimeUtils, + ObjectType(classOf[java.sql.Timestamp]), + "toJavaTimestamp", + getPath :: Nil, + propagateNull = true) + + case t if t <:< localTypeOf[java.lang.String] => + Invoke(getPath, "toString", ObjectType(classOf[String])) + + case t if t <:< localTypeOf[java.math.BigDecimal] => + Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) + + case t if t <:< localTypeOf[Array[_]] => + val TypeRef(_, _, Seq(elementType)) = t + val elementDataType = dataTypeFor(elementType) + val Schema(dataType, nullable) = schemaFor(elementType) + + val primitiveMethod = elementType match { + case t if t <:< definitions.IntTpe => Some("toIntArray") + case t if t <:< definitions.LongTpe => Some("toLongArray") + case t if t <:< definitions.DoubleTpe => Some("toDoubleArray") + case t if t <:< definitions.FloatTpe => Some("toFloatArray") + case t if t <:< definitions.ShortTpe => Some("toShortArray") + case t if t <:< definitions.ByteTpe => Some("toByteArray") + case t if t <:< definitions.BooleanTpe => Some("toBooleanArray") + case _ => None + } + + primitiveMethod.map { method => + Invoke(getPath, method, dataTypeFor(t)) + }.getOrElse { + val returnType = dataTypeFor(t) + Invoke( + MapObjects(p => constructorFor(elementType, Some(p)), getPath, dataType), + "array", + returnType) + } + + case t if t <:< localTypeOf[Map[_, _]] => + val TypeRef(_, _, Seq(keyType, valueType)) = t + val Schema(keyDataType, _) = schemaFor(keyType) + val Schema(valueDataType, valueNullable) = schemaFor(valueType) + + val primitiveMethodKey = keyType match { + case t if t <:< definitions.IntTpe => Some("toIntArray") + case t if t <:< definitions.LongTpe => Some("toLongArray") + case t if t <:< definitions.DoubleTpe => Some("toDoubleArray") + case t if t <:< definitions.FloatTpe => Some("toFloatArray") + case t if t <:< definitions.ShortTpe => Some("toShortArray") + case t if t <:< definitions.ByteTpe => Some("toByteArray") + case t if t <:< definitions.BooleanTpe => Some("toBooleanArray") + case _ => None + } + + val keyData = + Invoke( + MapObjects( + p => constructorFor(keyType, Some(p)), + Invoke(getPath, "keyArray", ArrayType(keyDataType)), + keyDataType), + "array", + ObjectType(classOf[Array[Any]])) + + val primitiveMethodValue = valueType match { + case t if t <:< definitions.IntTpe => Some("toIntArray") + case t if t <:< definitions.LongTpe => Some("toLongArray") + case t if t <:< definitions.DoubleTpe => Some("toDoubleArray") + case t if t <:< definitions.FloatTpe => Some("toFloatArray") + case t if t <:< definitions.ShortTpe => Some("toShortArray") + case t if t <:< definitions.ByteTpe => Some("toByteArray") + case t if t <:< definitions.BooleanTpe => Some("toBooleanArray") + case _ => None + } + + val valueData = + Invoke( + MapObjects( + p => constructorFor(valueType, Some(p)), + Invoke(getPath, "valueArray", ArrayType(valueDataType)), + valueDataType), + "array", + ObjectType(classOf[Array[Any]])) + + StaticInvoke( + ArrayBasedMapData, + ObjectType(classOf[Map[_, _]]), + "toScalaMap", + keyData :: valueData :: Nil) + + case t if t <:< localTypeOf[Seq[_]] => + val TypeRef(_, _, Seq(elementType)) = t + val elementDataType = dataTypeFor(elementType) + val Schema(dataType, nullable) = schemaFor(elementType) + + // Avoid boxing when possible by just wrapping a primitive array. + val primitiveMethod = elementType match { + case _ if nullable => None + case t if t <:< definitions.IntTpe => Some("toIntArray") + case t if t <:< definitions.LongTpe => Some("toLongArray") + case t if t <:< definitions.DoubleTpe => Some("toDoubleArray") + case t if t <:< definitions.FloatTpe => Some("toFloatArray") + case t if t <:< definitions.ShortTpe => Some("toShortArray") + case t if t <:< definitions.ByteTpe => Some("toByteArray") + case t if t <:< definitions.BooleanTpe => Some("toBooleanArray") + case _ => None + } + + val arrayData = primitiveMethod.map { method => + Invoke(getPath, method, arrayClassFor(elementType)) + }.getOrElse { + Invoke( + MapObjects(p => constructorFor(elementType, Some(p)), getPath, dataType), + "array", + arrayClassFor(elementType)) + } + + StaticInvoke( + scala.collection.mutable.WrappedArray, + ObjectType(classOf[Seq[_]]), + "make", + arrayData :: Nil) + + + case t if t <:< localTypeOf[Product] => + val formalTypeArgs = t.typeSymbol.asClass.typeParams + val TypeRef(_, _, actualTypeArgs) = t + val constructorSymbol = t.member(nme.CONSTRUCTOR) + val params = if (constructorSymbol.isMethod) { + constructorSymbol.asMethod.paramss + } else { + // Find the primary constructor, and use its parameter ordering. + val primaryConstructorSymbol: Option[Symbol] = + constructorSymbol.asTerm.alternatives.find(s => + s.isMethod && s.asMethod.isPrimaryConstructor) + + if (primaryConstructorSymbol.isEmpty) { + sys.error("Internal SQL error: Product object did not have a primary constructor.") + } else { + primaryConstructorSymbol.get.asMethod.paramss + } + } + + val className: String = t.erasure.typeSymbol.asClass.fullName + val cls = Utils.classForName(className) + + val arguments = params.head.map { p => + val fieldName = p.name.toString + val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) + val dataType = dataTypeFor(fieldType) + + constructorFor(fieldType, Some(addToPath(fieldName))) + } + + val newInstance = NewInstance(cls, arguments, propagateNull = false, ObjectType(cls)) + + if (path.nonEmpty) { + expressions.If( + IsNull(getPath), + expressions.Literal.create(null, ObjectType(cls)), + newInstance + ) + } else { + newInstance + } + + } + } + /** Returns expressions for extracting all the fields from the given type. */ def extractorsFor[T : TypeTag](inputObject: Expression): Seq[Expression] = { ScalaReflectionLock.synchronized { @@ -227,13 +523,13 @@ trait ScalaReflection { val elementDataType = dataTypeFor(elementType) val Schema(dataType, nullable) = schemaFor(elementType) - if (!elementDataType.isInstanceOf[AtomicType]) { - MapObjects(extractorFor(_, elementType), inputObject, elementDataType) - } else { + if (dataType.isInstanceOf[AtomicType]) { NewInstance( classOf[GenericArrayData], inputObject :: Nil, dataType = ArrayType(dataType, nullable)) + } else { + MapObjects(extractorFor(_, elementType), inputObject, elementDataType) } case t if t <:< localTypeOf[Map[_, _]] => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala index 8dacfa9477ee6..3618247d5d51a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala @@ -17,8 +17,10 @@ package org.apache.spark.sql.catalyst.encoders + import scala.reflect.ClassTag +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types.StructType @@ -41,4 +43,16 @@ trait Encoder[T] { * copy the result before making another call if required. */ def toRow(t: T): InternalRow + + /** + * Returns an object of type `T`, extracting the required values from the provided row. Note that + * you must bind` and encoder to a specific schema before you can call this function. + */ + def fromRow(row: InternalRow): T + + /** + * Returns a new copy of this encoder, where the expressions used by `fromRow` are bound to the + * given schema + */ + def bind(schema: Seq[Attribute]): Encoder[T] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala index a23613673ebb5..b0381880c3bdb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala @@ -17,8 +17,10 @@ package org.apache.spark.sql.catalyst.encoders +import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} import scala.reflect.ClassTag import scala.reflect.runtime.universe.{typeTag, TypeTag} @@ -31,7 +33,7 @@ import org.apache.spark.sql.types.{ObjectType, StructType} * internal binary representation. */ object ProductEncoder { - def apply[T <: Product : TypeTag]: Encoder[T] = { + def apply[T <: Product : TypeTag]: ClassEncoder[T] = { // We convert the not-serializable TypeTag into StructType and ClassTag. val schema = ScalaReflection.schemaFor[T].dataType.asInstanceOf[StructType] val mirror = typeTag[T].mirror @@ -39,7 +41,8 @@ object ProductEncoder { val inputObject = BoundReference(0, ObjectType(cls), nullable = true) val extractExpressions = ScalaReflection.extractorsFor[T](inputObject) - new ClassEncoder[T](schema, extractExpressions, ClassTag[T](cls)) + val constructExpression = ScalaReflection.constructorFor[T] + new ClassEncoder[T](schema, extractExpressions, constructExpression, ClassTag[T](cls)) } } @@ -54,14 +57,31 @@ object ProductEncoder { case class ClassEncoder[T]( schema: StructType, extractExpressions: Seq[Expression], + constructExpression: Expression, clsTag: ClassTag[T]) extends Encoder[T] { private val extractProjection = GenerateUnsafeProjection.generate(extractExpressions) private val inputRow = new GenericMutableRow(1) + private lazy val constructProjection = GenerateSafeProjection.generate(constructExpression :: Nil) + private val dataType = ObjectType(clsTag.runtimeClass) + override def toRow(t: T): InternalRow = { inputRow(0) = t extractProjection(inputRow) } + + override def fromRow(row: InternalRow): T = { + constructProjection(row).get(0, dataType).asInstanceOf[T] + } + + override def bind(schema: Seq[Attribute]): ClassEncoder[T] = { + val plan = Project(Alias(constructExpression, "object")() :: Nil, LocalRelation(schema)) + val analyzedPlan = SimpleAnalyzer.execute(plan) + val resolvedExpression = analyzedPlan.expressions.head.children.head + val boundExpression = BindReferences.bindReference(resolvedExpression, schema) + + copy(constructExpression = boundExpression) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index e1f960a6e605c..e8c1c93cf5620 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -17,9 +17,12 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer +import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation} + import scala.language.existentials -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.{ScalaReflection, InternalRow} import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.types._ @@ -48,7 +51,7 @@ case class StaticInvoke( case other => other.getClass.getName.stripSuffix("$") } override def nullable: Boolean = true - override def children: Seq[Expression] = Nil + override def children: Seq[Expression] = arguments override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported.") @@ -69,7 +72,7 @@ case class StaticInvoke( s""" ${argGen.map(_.code).mkString("\n")} - boolean ${ev.isNull} = true; + boolean ${ev.isNull} = !$argsNonNull; $javaType ${ev.value} = ${ctx.defaultValue(dataType)}; if ($argsNonNull) { @@ -81,8 +84,8 @@ case class StaticInvoke( s""" ${argGen.map(_.code).mkString("\n")} - final boolean ${ev.isNull} = ${ev.value} == null; $javaType ${ev.value} = $objectName.$functionName($argString); + final boolean ${ev.isNull} = ${ev.value} == null; """ } } @@ -92,6 +95,10 @@ case class StaticInvoke( * Calls the specified function on an object, optionally passing arguments. If the `targetObject` * expression evaluates to null then null will be returned. * + * In some cases, due to erasure, the schema may expect a primitive type when in fact the method + * is returning java.lang.Object. In this case, we will generate code that attempts to unbox the + * value automatically. + * * @param targetObject An expression that will return the object to call the method on. * @param functionName The name of the method to call. * @param dataType The expected return type of the function. @@ -109,6 +116,35 @@ case class Invoke( override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + lazy val method = targetObject.dataType match { + case ObjectType(cls) => + cls + .getMethods + .find(_.getName == functionName) + .getOrElse(sys.error(s"Couldn't find $functionName on $cls")) + .getReturnType + .getName + case _ => "" + } + + lazy val unboxer = (dataType, method) match { + case (IntegerType, "java.lang.Object") => (s: String) => + s"((java.lang.Integer)$s).intValue()" + case (LongType, "java.lang.Object") => (s: String) => + s"((java.lang.Long)$s).longValue()" + case (FloatType, "java.lang.Object") => (s: String) => + s"((java.lang.Float)$s).floatValue()" + case (ShortType, "java.lang.Object") => (s: String) => + s"((java.lang.Short)$s).shortValue()" + case (ByteType, "java.lang.Object") => (s: String) => + s"((java.lang.Byte)$s).byteValue()" + case (DoubleType, "java.lang.Object") => (s: String) => + s"((java.lang.Double)$s).doubleValue()" + case (BooleanType, "java.lang.Object") => (s: String) => + s"((java.lang.Boolean)$s).booleanValue()" + case _ => identity[String] _ + } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val javaType = ctx.javaType(dataType) val obj = targetObject.gen(ctx) @@ -123,6 +159,8 @@ case class Invoke( "" } + val value = unboxer(s"${obj.value}.$functionName($argString)") + s""" ${obj.code} ${argGen.map(_.code).mkString("\n")} @@ -130,7 +168,7 @@ case class Invoke( boolean ${ev.isNull} = ${obj.value} == null; $javaType ${ev.value} = ${ev.isNull} ? - ${ctx.defaultValue(dataType)} : ($javaType) ${obj.value}.$functionName($argString); + ${ctx.defaultValue(dataType)} : ($javaType) $value; $objNullCheck """ } @@ -190,8 +228,8 @@ case class NewInstance( s""" ${argGen.map(_.code).mkString("\n")} - final boolean ${ev.isNull} = ${ev.value} == null; $javaType ${ev.value} = new $className($argString); + final boolean ${ev.isNull} = ${ev.value} == null; """ } } @@ -210,8 +248,6 @@ case class UnwrapOption( override def nullable: Boolean = true - override def children: Seq[Expression] = Nil - override def inputTypes: Seq[AbstractDataType] = ObjectType :: Nil override def eval(input: InternalRow): Any = @@ -231,6 +267,43 @@ case class UnwrapOption( } } +/** + * Converts the result of evaluating `child` into an option, checking both the isNull bit and + * (in the case of reference types) equality with null. + * @param optionType The datatype to be held inside of the Option. + * @param child The expression to evaluate and wrap. + */ +case class WrapOption(optionType: DataType, child: Expression) + extends UnaryExpression with ExpectsInputTypes { + + override def dataType: DataType = ObjectType(classOf[Option[_]]) + + override def nullable: Boolean = true + + override def inputTypes: Seq[AbstractDataType] = ObjectType :: Nil + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported") + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val javaType = ctx.javaType(optionType) + val inputObject = child.gen(ctx) + + s""" + ${inputObject.code} + + boolean ${ev.isNull} = false; + scala.Option<$javaType> ${ev.value} = + ${inputObject.isNull} ? + scala.Option$$.MODULE$$.apply(null) : new scala.Some(${inputObject.value}); + """ + } +} + +/** + * A place holder for the loop variable used in [[MapObjects]]. This should never be constructed + * manually, but will instead be passed into the provided lambda function. + */ case class LambdaVariable(value: String, isNull: String, dataType: DataType) extends Expression { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = @@ -251,7 +324,7 @@ case class LambdaVariable(value: String, isNull: String, dataType: DataType) ext * as an ArrayType. This is similar to a typical map operation, but where the lambda function * is expressed using catalyst expressions. * - * The following collection ObjectTypes are currently supported: Seq, Array + * The following collection ObjectTypes are currently supported: Seq, Array, ArrayData * * @param function A function that returns an expression, given an attribute that can be used * to access the current value. This is does as a lambda function so that @@ -265,14 +338,32 @@ case class MapObjects( inputData: Expression, elementType: DataType) extends Expression { - private val loopAttribute = AttributeReference("loopVar", elementType)() - private val completeFunction = function(loopAttribute) + private lazy val loopAttribute = AttributeReference("loopVar", elementType)() + private lazy val completeFunction = function(loopAttribute) - private val (lengthFunction, itemAccessor) = inputData.dataType match { - case ObjectType(cls) if cls.isAssignableFrom(classOf[Seq[_]]) => - (".size()", (i: String) => s".apply($i)") + private lazy val (lengthFunction, itemAccessor, primitiveElement) = inputData.dataType match { + case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) => + (".size()", (i: String) => s".apply($i)", false) case ObjectType(cls) if cls.isArray => - (".length", (i: String) => s"[$i]") + (".length", (i: String) => s"[$i]", false) + case ArrayType(s: StructType, _) => + (".numElements()", (i: String) => s".getStruct($i, ${s.size})", false) + case ArrayType(a: ArrayType, _) => + (".numElements()", (i: String) => s".getArray($i)", true) + case ArrayType(IntegerType, _) => + (".numElements()", (i: String) => s".getInt($i)", true) + case ArrayType(LongType, _) => + (".numElements()", (i: String) => s".getLong($i)", true) + case ArrayType(FloatType, _) => + (".numElements()", (i: String) => s".getFloat($i)", true) + case ArrayType(DoubleType, _) => + (".numElements()", (i: String) => s".getDouble($i)", true) + case ArrayType(ByteType, _) => + (".numElements()", (i: String) => s".getByte($i)", true) + case ArrayType(ShortType, _) => + (".numElements()", (i: String) => s".getShort($i)", true) + case ArrayType(BooleanType, _) => + (".numElements()", (i: String) => s".getBoolean($i)", true) } override def nullable: Boolean = true @@ -294,15 +385,38 @@ case class MapObjects( val loopIsNull = ctx.freshName("loopIsNull") val loopVariable = LambdaVariable(loopValue, loopIsNull, elementType) - val boundFunction = completeFunction transform { + val substitutedFunction = completeFunction transform { case a: AttributeReference if a == loopAttribute => loopVariable } + // A hack to run this through the analyzer (to bind extractions). + val boundFunction = + SimpleAnalyzer.execute(Project(Alias(substitutedFunction, "")() :: Nil, LocalRelation(Nil))) + .expressions.head.children.head val genFunction = boundFunction.gen(ctx) val dataLength = ctx.freshName("dataLength") val convertedArray = ctx.freshName("convertedArray") val loopIndex = ctx.freshName("loopIndex") + val convertedType = ctx.javaType(boundFunction.dataType) + + // Because of the way Java defines nested arrays, we have to handle the syntax specially. + // Specifically, we have to insert the [$dataLength] in between the type and any extra nested + // array declarations (i.e. new String[1][]). + val arrayConstructor = if (convertedType contains "[]") { + val rawType = convertedType.takeWhile(_ != '[') + val arrayPart = convertedType.reverse.takeWhile(c => c == '[' || c == ']').reverse + s"new $rawType[$dataLength]$arrayPart" + } else { + s"new $convertedType[$dataLength]" + } + + val loopNullCheck = if (primitiveElement) { + s"boolean $loopIsNull = ${genInputData.value}.isNullAt($loopIndex);" + } else { + s"boolean $loopIsNull = ${genInputData.isNull} || $loopValue == null;" + } + s""" ${genInputData.code} @@ -310,19 +424,19 @@ case class MapObjects( $javaType ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { - Object[] $convertedArray = null; + $convertedType[] $convertedArray = null; int $dataLength = ${genInputData.value}$lengthFunction; - $convertedArray = new Object[$dataLength]; + $convertedArray = $arrayConstructor; int $loopIndex = 0; while ($loopIndex < $dataLength) { $elementJavaType $loopValue = ($elementJavaType)${genInputData.value}${itemAccessor(loopIndex)}; - boolean $loopIsNull = $loopValue == null; + $loopNullCheck ${genFunction.code} - $convertedArray[$loopIndex] = ${genFunction.value}; + $convertedArray[$loopIndex] = ($convertedType)${genFunction.value}; $loopIndex += 1; } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala index 52069598ee30e..5f22e59d5f1d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala @@ -62,4 +62,8 @@ object ArrayBasedMapData { val values = map.valueArray.asInstanceOf[GenericArrayData].array keys.zip(values).toMap } + + def toScalaMap(keys: Array[Any], values: Array[Any]): Map[Any, Any] = { + keys.zip(values).toMap + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala index 642c56f12ded1..b4ea300f5f306 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala @@ -26,6 +26,8 @@ abstract class ArrayData extends SpecializedGetters with Serializable { def copy(): ArrayData + def array: Array[Any] + def toBooleanArray(): Array[Boolean] = { val size = numElements() val values = new Array[Boolean](size) @@ -103,6 +105,9 @@ abstract class ArrayData extends SpecializedGetters with Serializable { values } + def toObjectArray(elementType: DataType): Array[AnyRef] = + toArray[AnyRef](elementType: DataType) + def toArray[T: ClassTag](elementType: DataType): Array[T] = { val size = numElements() val values = new Array[T](size) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala index c3816033275d5..9448d88d6c5f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.types import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} -class GenericArrayData(private[sql] val array: Array[Any]) extends ArrayData { +class GenericArrayData(val array: Array[Any]) extends ArrayData { def this(seq: scala.collection.GenIterable[Any]) = this(seq.toArray) @@ -29,6 +29,8 @@ class GenericArrayData(private[sql] val array: Array[Any]) extends ArrayData { def this(primitiveArray: Array[Long]) = this(primitiveArray.toSeq) def this(primitiveArray: Array[Float]) = this(primitiveArray.toSeq) def this(primitiveArray: Array[Double]) = this(primitiveArray.toSeq) + def this(primitiveArray: Array[Short]) = this(primitiveArray.toSeq) + def this(primitiveArray: Array[Byte]) = this(primitiveArray.toSeq) def this(primitiveArray: Array[Boolean]) = this(primitiveArray.toSeq) override def copy(): ArrayData = new GenericArrayData(array.clone()) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala index 99c993d3febc2..02e43ddb35478 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala @@ -17,158 +17,263 @@ package org.apache.spark.sql.catalyst.encoders -import java.sql.{Date, Timestamp} +import java.util + +import org.apache.spark.sql.types.{StructField, ArrayType, ArrayData} + +import scala.collection.mutable.ArrayBuffer +import scala.reflect.runtime.universe._ import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.ScalaReflection._ -import org.apache.spark.sql.catalyst.expressions.UnsafeProjection import org.apache.spark.sql.catalyst._ - case class RepeatedStruct(s: Seq[PrimitiveData]) case class NestedArray(a: Array[Array[Int]]) -class ProductEncoderSuite extends SparkFunSuite { +case class BoxedData( + intField: java.lang.Integer, + longField: java.lang.Long, + doubleField: java.lang.Double, + floatField: java.lang.Float, + shortField: java.lang.Short, + byteField: java.lang.Byte, + booleanField: java.lang.Boolean) - test("convert PrimitiveData to InternalRow") { - val inputData = PrimitiveData(1, 1, 1, 1, 1, 1, true) - val encoder = ProductEncoder[PrimitiveData] - val convertedData = encoder.toRow(inputData) - - assert(convertedData.getInt(0) == 1) - assert(convertedData.getLong(1) == 1.toLong) - assert(convertedData.getDouble(2) == 1.toDouble) - assert(convertedData.getFloat(3) == 1.toFloat) - assert(convertedData.getShort(4) == 1.toShort) - assert(convertedData.getByte(5) == 1.toByte) - assert(convertedData.getBoolean(6) == true) - } +case class RepeatedData( + arrayField: Seq[Int], + arrayFieldContainsNull: Seq[java.lang.Integer], + mapField: scala.collection.Map[Int, Long], + mapFieldNull: scala.collection.Map[Int, java.lang.Long], + structField: PrimitiveData) - test("convert Some[_] to InternalRow") { - val primitiveData = PrimitiveData(1, 1, 1, 1, 1, 1, true) - val inputData = OptionalData(Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true), - Some(primitiveData)) - - val encoder = ProductEncoder[OptionalData] - val convertedData = encoder.toRow(inputData) - - assert(convertedData.getInt(0) == 2) - assert(convertedData.getLong(1) == 2.toLong) - assert(convertedData.getDouble(2) == 2.toDouble) - assert(convertedData.getFloat(3) == 2.toFloat) - assert(convertedData.getShort(4) == 2.toShort) - assert(convertedData.getByte(5) == 2.toByte) - assert(convertedData.getBoolean(6) == true) - - val nestedRow = convertedData.getStruct(7, 7) - assert(nestedRow.getInt(0) == 1) - assert(nestedRow.getLong(1) == 1.toLong) - assert(nestedRow.getDouble(2) == 1.toDouble) - assert(nestedRow.getFloat(3) == 1.toFloat) - assert(nestedRow.getShort(4) == 1.toShort) - assert(nestedRow.getByte(5) == 1.toByte) - assert(nestedRow.getBoolean(6) == true) - } +case class SpecificCollection(l: List[Int]) - test("convert None to InternalRow") { - val inputData = OptionalData(None, None, None, None, None, None, None, None) - val encoder = ProductEncoder[OptionalData] - val convertedData = encoder.toRow(inputData) - - assert(convertedData.isNullAt(0)) - assert(convertedData.isNullAt(1)) - assert(convertedData.isNullAt(2)) - assert(convertedData.isNullAt(3)) - assert(convertedData.isNullAt(4)) - assert(convertedData.isNullAt(5)) - assert(convertedData.isNullAt(6)) - assert(convertedData.isNullAt(7)) - } +class ProductEncoderSuite extends SparkFunSuite { - test("convert nullable but present data to InternalRow") { - val inputData = NullableData( - 1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true, "test", new java.math.BigDecimal(1), new Date(0), - new Timestamp(0), Array[Byte](1, 2, 3)) - - val encoder = ProductEncoder[NullableData] - val convertedData = encoder.toRow(inputData) - - assert(convertedData.getInt(0) == 1) - assert(convertedData.getLong(1) == 1.toLong) - assert(convertedData.getDouble(2) == 1.toDouble) - assert(convertedData.getFloat(3) == 1.toFloat) - assert(convertedData.getShort(4) == 1.toShort) - assert(convertedData.getByte(5) == 1.toByte) - assert(convertedData.getBoolean(6) == true) - } + encodeDecodeTest(PrimitiveData(1, 1, 1, 1, 1, 1, true)) - test("convert nullable data to InternalRow") { - val inputData = - NullableData(null, null, null, null, null, null, null, null, null, null, null, null) - - val encoder = ProductEncoder[NullableData] - val convertedData = encoder.toRow(inputData) - - assert(convertedData.isNullAt(0)) - assert(convertedData.isNullAt(1)) - assert(convertedData.isNullAt(2)) - assert(convertedData.isNullAt(3)) - assert(convertedData.isNullAt(4)) - assert(convertedData.isNullAt(5)) - assert(convertedData.isNullAt(6)) - assert(convertedData.isNullAt(7)) - assert(convertedData.isNullAt(8)) - assert(convertedData.isNullAt(9)) - assert(convertedData.isNullAt(10)) - assert(convertedData.isNullAt(11)) - } + // TODO: Support creating specific subclasses of Seq. + ignore("Specific collection types") { encodeDecodeTest(SpecificCollection(1 :: Nil)) } - test("convert repeated struct") { - val inputData = RepeatedStruct(PrimitiveData(1, 1, 1, 1, 1, 1, true) :: Nil) - val encoder = ProductEncoder[RepeatedStruct] - - val converted = encoder.toRow(inputData) - val convertedStruct = converted.getArray(0).getStruct(0, 7) - assert(convertedStruct.getInt(0) == 1) - assert(convertedStruct.getLong(1) == 1.toLong) - assert(convertedStruct.getDouble(2) == 1.toDouble) - assert(convertedStruct.getFloat(3) == 1.toFloat) - assert(convertedStruct.getShort(4) == 1.toShort) - assert(convertedStruct.getByte(5) == 1.toByte) - assert(convertedStruct.getBoolean(6) == true) - } + encodeDecodeTest( + OptionalData( + Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true), + Some(PrimitiveData(1, 1, 1, 1, 1, 1, true)))) - test("convert nested seq") { - val convertedData = ProductEncoder[Tuple1[Seq[Seq[Int]]]].toRow(Tuple1(Seq(Seq(1)))) - assert(convertedData.getArray(0).getArray(0).getInt(0) == 1) + encodeDecodeTest(OptionalData(None, None, None, None, None, None, None, None)) - val convertedData2 = ProductEncoder[Tuple1[Seq[Seq[Seq[Int]]]]].toRow(Tuple1(Seq(Seq(Seq(1))))) - assert(convertedData2.getArray(0).getArray(0).getArray(0).getInt(0) == 1) - } + encodeDecodeTest( + BoxedData(1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true)) - test("convert nested array") { - val convertedData = ProductEncoder[Tuple1[Array[Array[Int]]]].toRow(Tuple1(Array(Array(1)))) - } + encodeDecodeTest( + BoxedData(null, null, null, null, null, null, null)) + + encodeDecodeTest( + RepeatedStruct(PrimitiveData(1, 1, 1, 1, 1, 1, true) :: Nil)) - test("convert complex") { - val inputData = ComplexData( + encodeDecodeTest( + RepeatedData( Seq(1, 2), - Array(1, 2), - 1 :: 2 :: Nil, Seq(new Integer(1), null, new Integer(2)), Map(1 -> 2L), - Map(1 -> new java.lang.Long(2)), - PrimitiveData(1, 1, 1, 1, 1, 1, true), - Array(Array(1))) - - val encoder = ProductEncoder[ComplexData] - val convertedData = encoder.toRow(inputData) - - assert(!convertedData.isNullAt(0)) - val seq = convertedData.getArray(0) - assert(seq.numElements() == 2) - assert(seq.getInt(0) == 1) - assert(seq.getInt(1) == 2) + Map(1 -> null), + PrimitiveData(1, 1, 1, 1, 1, 1, true))) + + encodeDecodeTest(("nullable Seq[Integer]", Seq[Integer](1, null))) + + encodeDecodeTest(("Seq[(String, String)]", + Seq(("a", "b")))) + encodeDecodeTest(("Seq[(Int, Int)]", + Seq((1, 2)))) + encodeDecodeTest(("Seq[(Long, Long)]", + Seq((1L, 2L)))) + encodeDecodeTest(("Seq[(Float, Float)]", + Seq((1.toFloat, 2.toFloat)))) + encodeDecodeTest(("Seq[(Double, Double)]", + Seq((1.toDouble, 2.toDouble)))) + encodeDecodeTest(("Seq[(Short, Short)]", + Seq((1.toShort, 2.toShort)))) + encodeDecodeTest(("Seq[(Byte, Byte)]", + Seq((1.toByte, 2.toByte)))) + encodeDecodeTest(("Seq[(Boolean, Boolean)]", + Seq((true, false)))) + + // TODO: Decoding/encoding of complex maps. + ignore("complex maps") { + encodeDecodeTest(("Map[Int, (String, String)]", + Map(1 ->("a", "b")))) + } + + encodeDecodeTest(("ArrayBuffer[(String, String)]", + ArrayBuffer(("a", "b")))) + encodeDecodeTest(("ArrayBuffer[(Int, Int)]", + ArrayBuffer((1, 2)))) + encodeDecodeTest(("ArrayBuffer[(Long, Long)]", + ArrayBuffer((1L, 2L)))) + encodeDecodeTest(("ArrayBuffer[(Float, Float)]", + ArrayBuffer((1.toFloat, 2.toFloat)))) + encodeDecodeTest(("ArrayBuffer[(Double, Double)]", + ArrayBuffer((1.toDouble, 2.toDouble)))) + encodeDecodeTest(("ArrayBuffer[(Short, Short)]", + ArrayBuffer((1.toShort, 2.toShort)))) + encodeDecodeTest(("ArrayBuffer[(Byte, Byte)]", + ArrayBuffer((1.toByte, 2.toByte)))) + encodeDecodeTest(("ArrayBuffer[(Boolean, Boolean)]", + ArrayBuffer((true, false)))) + + encodeDecodeTest(("Seq[Seq[(Int, Int)]]", + Seq(Seq((1, 2))))) + + encodeDecodeTestCustom(("Array[Array[(Int, Int)]]", + Array(Array((1, 2))))) + { (l, r) => l._2(0)(0) == r._2(0)(0) } + + encodeDecodeTestCustom(("Array[Array[(Int, Int)]]", + Array(Array(Array((1, 2)))))) + { (l, r) => l._2(0)(0)(0) == r._2(0)(0)(0) } + + encodeDecodeTestCustom(("Array[Array[Array[(Int, Int)]]]", + Array(Array(Array(Array((1, 2))))))) + { (l, r) => l._2(0)(0)(0)(0) == r._2(0)(0)(0)(0) } + + encodeDecodeTestCustom(("Array[Array[Array[Array[(Int, Int)]]]]", + Array(Array(Array(Array(Array((1, 2)))))))) + { (l, r) => l._2(0)(0)(0)(0)(0) == r._2(0)(0)(0)(0)(0) } + + + encodeDecodeTestCustom(("Array[Array[Integer]]", + Array(Array[Integer](1)))) + { (l, r) => l._2(0)(0) == r._2(0)(0) } + + encodeDecodeTestCustom(("Array[Array[Int]]", + Array(Array(1)))) + { (l, r) => l._2(0)(0) == r._2(0)(0) } + + encodeDecodeTestCustom(("Array[Array[Int]]", + Array(Array(Array(1))))) + { (l, r) => l._2(0)(0)(0) == r._2(0)(0)(0) } + + encodeDecodeTestCustom(("Array[Array[Array[Int]]]", + Array(Array(Array(Array(1)))))) + { (l, r) => l._2(0)(0)(0)(0) == r._2(0)(0)(0)(0) } + + encodeDecodeTestCustom(("Array[Array[Array[Array[Int]]]]", + Array(Array(Array(Array(Array(1))))))) + { (l, r) => l._2(0)(0)(0)(0)(0) == r._2(0)(0)(0)(0)(0) } + + encodeDecodeTest(("Array[Byte] null", + null: Array[Byte])) + encodeDecodeTestCustom(("Array[Byte]", + Array[Byte](1, 2, 3))) + { (l, r) => util.Arrays.equals(l._2, r._2) } + + encodeDecodeTest(("Array[Int] null", + null: Array[Int])) + encodeDecodeTestCustom(("Array[Int]", + Array[Int](1, 2, 3))) + { (l, r) => util.Arrays.equals(l._2, r._2) } + + encodeDecodeTest(("Array[Long] null", + null: Array[Long])) + encodeDecodeTestCustom(("Array[Long]", + Array[Long](1, 2, 3))) + { (l, r) => util.Arrays.equals(l._2, r._2) } + + encodeDecodeTest(("Array[Double] null", + null: Array[Double])) + encodeDecodeTestCustom(("Array[Double]", + Array[Double](1, 2, 3))) + { (l, r) => util.Arrays.equals(l._2, r._2) } + + encodeDecodeTest(("Array[Float] null", + null: Array[Float])) + encodeDecodeTestCustom(("Array[Float]", + Array[Float](1, 2, 3))) + { (l, r) => util.Arrays.equals(l._2, r._2) } + + encodeDecodeTest(("Array[Boolean] null", + null: Array[Boolean])) + encodeDecodeTestCustom(("Array[Boolean]", + Array[Boolean](true, false))) + { (l, r) => util.Arrays.equals(l._2, r._2) } + + encodeDecodeTest(("Array[Short] null", + null: Array[Short])) + encodeDecodeTestCustom(("Array[Short]", + Array[Short](1, 2, 3))) + { (l, r) => util.Arrays.equals(l._2, r._2) } + + encodeDecodeTestCustom(("java.sql.Timestamp", + new java.sql.Timestamp(1))) + { (l, r) => l._2.toString == r._2.toString } + + encodeDecodeTestCustom(("java.sql.Date", new java.sql.Date(1))) + { (l, r) => l._2.toString == r._2.toString } + + /** Simplified encodeDecodeTestCustom, where the comparison function can be `Object.equals`. */ + protected def encodeDecodeTest[T <: Product : TypeTag](inputData: T) = + encodeDecodeTestCustom[T](inputData)((l, r) => l == r) + + /** + * Constructs a test that round-trips `t` through an encoder, checking the results to ensure it + * matches the original. + */ + protected def encodeDecodeTestCustom[T <: Product : TypeTag]( + inputData: T)( + c: (T, T) => Boolean) = { + test(s"encode/decode: $inputData") { + val encoder = try ProductEncoder[T] catch { + case e: Exception => + fail(s"Exception thrown generating encoder", e) + } + val convertedData = encoder.toRow(inputData) + val schema = encoder.schema.toAttributes + val boundEncoder = encoder.bind(schema) + val convertedBack = try boundEncoder.fromRow(convertedData) catch { + case e: Exception => + fail( + s"""Exception thrown while decoding + |Converted: $convertedData + |Schema: ${schema.mkString(",")} + |${encoder.schema.treeString} + | + |Construct Expressions: + |${boundEncoder.constructExpression.treeString} + | + """.stripMargin, e) + } + + if (!c(inputData, convertedBack)) { + val types = + convertedBack.productIterator.filter(_ != null).map(_.getClass.getName).mkString(",") + + val encodedData = convertedData.toSeq(encoder.schema).zip(encoder.schema).map { + case (a: ArrayData, StructField(_, at: ArrayType, _, _)) => + a.toArray[Any](at.elementType).toSeq + case (other, _) => + other + }.mkString("[", ",", "]") + + fail( + s"""Encoded/Decoded data does not match input data + | + |in: $inputData + |out: $convertedBack + |types: $types + | + |Encoded Data: $encodedData + |Schema: ${schema.mkString(",")} + |${encoder.schema.treeString} + | + |Extract Expressions: + |${boundEncoder.extractExpressions.map(_.treeString).mkString("\n")} + | + |Construct Expressions: + |${boundEncoder.constructExpression.treeString} + | + """.stripMargin) + } + } } } From e170c22160bb452f98c340489ebf8390116a8cbb Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 13 Oct 2015 17:11:22 -0700 Subject: [PATCH 045/139] [SPARK-11032] [SQL] correctly handle having We should not stop resolving having when the having condtion is resolved, or something like `count(1)` will crash. Author: Wenchen Fan Closes #9105 from cloud-fan/having. --- .../apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 +- .../test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index f5597a08d3595..041ab22827399 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -553,7 +553,7 @@ class Analyzer( def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case filter @ Filter(havingCondition, aggregate @ Aggregate(grouping, originalAggExprs, child)) - if aggregate.resolved && !filter.resolved => + if aggregate.resolved => // Try resolving the condition of the filter as though it is in the aggregate clause val aggregatedCondition = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index eca6f1073889a..636591630e136 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1809,4 +1809,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { df1.withColumn("diff", lit(0))) } } + + test("SPARK-11032: resolve having correctly") { + withTempTable("src") { + Seq(1 -> "a").toDF("i", "j").registerTempTable("src") + checkAnswer( + sql("SELECT MIN(t.i) FROM (SELECT * FROM src WHERE i > 0) t HAVING(COUNT(1) > 0)"), + Row(1)) + } + } } From 15ff85b3163acbe8052d4489a00bcf1d2332fcf0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 13 Oct 2015 17:59:32 -0700 Subject: [PATCH 046/139] [SPARK-11068] [SQL] add callback to query execution With this feature, we can track the query plan, time cost, exception during query execution for spark users. Author: Wenchen Fan Closes #9078 from cloud-fan/callback. --- .../org/apache/spark/sql/DataFrame.scala | 46 +++++- .../spark/sql/QueryExecutionListener.scala | 136 ++++++++++++++++++ .../org/apache/spark/sql/SQLContext.scala | 3 + .../spark/sql/DataFrameCallbackSuite.scala | 82 +++++++++++ 4 files changed, 261 insertions(+), 6 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/QueryExecutionListener.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DataFrameCallbackSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 01f60aba87ede..bfe8d3c8ef957 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1344,7 +1344,9 @@ class DataFrame private[sql]( * @group action * @since 1.3.0 */ - def head(n: Int): Array[Row] = limit(n).collect() + def head(n: Int): Array[Row] = withCallback("head", limit(n)) { df => + df.collect(needCallback = false) + } /** * Returns the first row. @@ -1414,8 +1416,18 @@ class DataFrame private[sql]( * @group action * @since 1.3.0 */ - def collect(): Array[Row] = withNewExecutionId { - queryExecution.executedPlan.executeCollectPublic() + def collect(): Array[Row] = collect(needCallback = true) + + private def collect(needCallback: Boolean): Array[Row] = { + def execute(): Array[Row] = withNewExecutionId { + queryExecution.executedPlan.executeCollectPublic() + } + + if (needCallback) { + withCallback("collect", this)(_ => execute()) + } else { + execute() + } } /** @@ -1423,8 +1435,10 @@ class DataFrame private[sql]( * @group action * @since 1.3.0 */ - def collectAsList(): java.util.List[Row] = withNewExecutionId { - java.util.Arrays.asList(rdd.collect() : _*) + def collectAsList(): java.util.List[Row] = withCallback("collectAsList", this) { _ => + withNewExecutionId { + java.util.Arrays.asList(rdd.collect() : _*) + } } /** @@ -1432,7 +1446,9 @@ class DataFrame private[sql]( * @group action * @since 1.3.0 */ - def count(): Long = groupBy().count().collect().head.getLong(0) + def count(): Long = withCallback("count", groupBy().count()) { df => + df.collect(needCallback = false).head.getLong(0) + } /** * Returns a new [[DataFrame]] that has exactly `numPartitions` partitions. @@ -1936,6 +1952,24 @@ class DataFrame private[sql]( SQLExecution.withNewExecutionId(sqlContext, queryExecution)(body) } + /** + * Wrap a DataFrame action to track the QueryExecution and time cost, then report to the + * user-registered callback functions. + */ + private def withCallback[T](name: String, df: DataFrame)(action: DataFrame => T) = { + try { + val start = System.nanoTime() + val result = action(df) + val end = System.nanoTime() + sqlContext.listenerManager.onSuccess(name, df.queryExecution, end - start) + result + } catch { + case e: Exception => + sqlContext.listenerManager.onFailure(name, df.queryExecution, e) + throw e + } + } + //////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////// // End of deprecated methods diff --git a/sql/core/src/main/scala/org/apache/spark/sql/QueryExecutionListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/QueryExecutionListener.scala new file mode 100644 index 0000000000000..14fbebb45f8b7 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/QueryExecutionListener.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.util.concurrent.locks.ReentrantReadWriteLock +import scala.collection.mutable.ListBuffer + +import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.Logging +import org.apache.spark.sql.execution.QueryExecution + + +/** + * The interface of query execution listener that can be used to analyze execution metrics. + * + * Note that implementations should guarantee thread-safety as they will be used in a non + * thread-safe way. + */ +@Experimental +trait QueryExecutionListener { + + /** + * A callback function that will be called when a query executed successfully. + * Implementations should guarantee thread-safe. + * + * @param funcName the name of the action that triggered this query. + * @param qe the QueryExecution object that carries detail information like logical plan, + * physical plan, etc. + * @param duration the execution time for this query in nanoseconds. + */ + @DeveloperApi + def onSuccess(funcName: String, qe: QueryExecution, duration: Long) + + /** + * A callback function that will be called when a query execution failed. + * Implementations should guarantee thread-safe. + * + * @param funcName the name of the action that triggered this query. + * @param qe the QueryExecution object that carries detail information like logical plan, + * physical plan, etc. + * @param exception the exception that failed this query. + */ + @DeveloperApi + def onFailure(funcName: String, qe: QueryExecution, exception: Exception) +} + +@Experimental +class ExecutionListenerManager extends Logging { + private[this] val listeners = ListBuffer.empty[QueryExecutionListener] + private[this] val lock = new ReentrantReadWriteLock() + + /** Acquires a read lock on the cache for the duration of `f`. */ + private def readLock[A](f: => A): A = { + val rl = lock.readLock() + rl.lock() + try f finally { + rl.unlock() + } + } + + /** Acquires a write lock on the cache for the duration of `f`. */ + private def writeLock[A](f: => A): A = { + val wl = lock.writeLock() + wl.lock() + try f finally { + wl.unlock() + } + } + + /** + * Registers the specified QueryExecutionListener. + */ + @DeveloperApi + def register(listener: QueryExecutionListener): Unit = writeLock { + listeners += listener + } + + /** + * Unregisters the specified QueryExecutionListener. + */ + @DeveloperApi + def unregister(listener: QueryExecutionListener): Unit = writeLock { + listeners -= listener + } + + /** + * clears out all registered QueryExecutionListeners. + */ + @DeveloperApi + def clear(): Unit = writeLock { + listeners.clear() + } + + private[sql] def onSuccess( + funcName: String, + qe: QueryExecution, + duration: Long): Unit = readLock { + withErrorHandling { listener => + listener.onSuccess(funcName, qe, duration) + } + } + + private[sql] def onFailure( + funcName: String, + qe: QueryExecution, + exception: Exception): Unit = readLock { + withErrorHandling { listener => + listener.onFailure(funcName, qe, exception) + } + } + + private def withErrorHandling(f: QueryExecutionListener => Unit): Unit = { + for (listener <- listeners) { + try { + f(listener) + } catch { + case e: Exception => logWarning("error executing query execution listener", e) + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index cd937257d31a8..a835408f8af3a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -177,6 +177,9 @@ class SQLContext private[sql]( */ def getAllConfs: immutable.Map[String, String] = conf.getAllConfs + @transient + lazy val listenerManager: ExecutionListenerManager = new ExecutionListenerManager + @transient protected[sql] lazy val catalog: Catalog = new SimpleCatalog(conf) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameCallbackSuite.scala new file mode 100644 index 0000000000000..4e286a0076205 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameCallbackSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.SparkException +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Project} +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.test.SharedSQLContext + +import scala.collection.mutable.ArrayBuffer + +class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + import functions._ + + test("execute callback functions when a DataFrame action finished successfully") { + val metrics = ArrayBuffer.empty[(String, QueryExecution, Long)] + val listener = new QueryExecutionListener { + // Only test successful case here, so no need to implement `onFailure` + override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {} + + override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = { + metrics += ((funcName, qe, duration)) + } + } + sqlContext.listenerManager.register(listener) + + val df = Seq(1 -> "a").toDF("i", "j") + df.select("i").collect() + df.filter($"i" > 0).count() + + assert(metrics.length == 2) + + assert(metrics(0)._1 == "collect") + assert(metrics(0)._2.analyzed.isInstanceOf[Project]) + assert(metrics(0)._3 > 0) + + assert(metrics(1)._1 == "count") + assert(metrics(1)._2.analyzed.isInstanceOf[Aggregate]) + assert(metrics(1)._3 > 0) + } + + test("execute callback functions when a DataFrame action failed") { + val metrics = ArrayBuffer.empty[(String, QueryExecution, Exception)] + val listener = new QueryExecutionListener { + override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { + metrics += ((funcName, qe, exception)) + } + + // Only test failed case here, so no need to implement `onSuccess` + override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {} + } + sqlContext.listenerManager.register(listener) + + val errorUdf = udf[Int, Int] { _ => throw new RuntimeException("udf error") } + val df = sparkContext.makeRDD(Seq(1 -> "a")).toDF("i", "j") + + // Ignore the log when we are expecting an exception. + sparkContext.setLogLevel("FATAL") + val e = intercept[SparkException](df.select(errorUdf($"i")).collect()) + + assert(metrics.length == 1) + assert(metrics(0)._1 == "collect") + assert(metrics(0)._2.analyzed.isInstanceOf[Project]) + assert(metrics(0)._3.getMessage == e.getMessage) + } +} From ce3f9a80657751ee0bc0ed6a9b6558acbb40af4f Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 13 Oct 2015 18:21:24 -0700 Subject: [PATCH 047/139] [SPARK-11091] [SQL] Change spark.sql.canonicalizeView to spark.sql.nativeView. https://issues.apache.org/jira/browse/SPARK-11091 Author: Yin Huai Closes #9103 from yhuai/SPARK-11091. --- .../main/scala/org/apache/spark/sql/SQLConf.scala | 4 ++-- .../spark/sql/hive/HiveMetastoreCatalog.scala | 2 +- .../scala/org/apache/spark/sql/hive/HiveQl.scala | 2 +- .../spark/sql/hive/execution/SQLQuerySuite.scala | 14 +++++++------- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index f62df9bdebcc0..b08cc8e830737 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -328,7 +328,7 @@ private[spark] object SQLConf { doc = "When true, some predicates will be pushed down into the Hive metastore so that " + "unmatching partitions can be eliminated earlier.") - val CANONICALIZE_VIEW = booleanConf("spark.sql.canonicalizeView", + val NATIVE_VIEW = booleanConf("spark.sql.nativeView", defaultValue = Some(false), doc = "When true, CREATE VIEW will be handled by Spark SQL instead of Hive native commands. " + "Note that this function is experimental and should ony be used when you are using " + @@ -489,7 +489,7 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def metastorePartitionPruning: Boolean = getConf(HIVE_METASTORE_PARTITION_PRUNING) - private[spark] def canonicalizeView: Boolean = getConf(CANONICALIZE_VIEW) + private[spark] def nativeView: Boolean = getConf(NATIVE_VIEW) private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index cf59bc0d590b0..1f8223e1ff507 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -591,7 +591,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive case p: LogicalPlan if p.resolved => p case CreateViewAsSelect(table, child, allowExisting, replace, sql) => - if (conf.canonicalizeView) { + if (conf.nativeView) { if (allowExisting && replace) { throw new AnalysisException( "It is not allowed to define a view with both IF NOT EXISTS and OR REPLACE.") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 250c232856885..1d505019400bc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -537,7 +537,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C serde = None, viewText = Some(originalText)) - // We need to keep the original SQL string so that if `spark.sql.canonicalizeView` is + // We need to keep the original SQL string so that if `spark.sql.nativeView` is // false, we can fall back to use hive native command later. // We can remove this when parser is configurable(can access SQLConf) in the future. val sql = context.getTokenRewriteStream diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 51b63f3688783..6aa34605b05a8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1282,7 +1282,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("correctly parse CREATE VIEW statement") { - withSQLConf(SQLConf.CANONICALIZE_VIEW.key -> "true") { + withSQLConf(SQLConf.NATIVE_VIEW.key -> "true") { withTable("jt") { val df = (1 until 10).map(i => i -> i).toDF("i", "j") df.write.format("json").saveAsTable("jt") @@ -1299,7 +1299,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("correctly handle CREATE VIEW IF NOT EXISTS") { - withSQLConf(SQLConf.CANONICALIZE_VIEW.key -> "true") { + withSQLConf(SQLConf.NATIVE_VIEW.key -> "true") { withTable("jt", "jt2") { sqlContext.range(1, 10).write.format("json").saveAsTable("jt") sql("CREATE VIEW testView AS SELECT id FROM jt") @@ -1316,7 +1316,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("correctly handle CREATE OR REPLACE VIEW") { - withSQLConf(SQLConf.CANONICALIZE_VIEW.key -> "true") { + withSQLConf(SQLConf.NATIVE_VIEW.key -> "true") { withTable("jt", "jt2") { sqlContext.range(1, 10).write.format("json").saveAsTable("jt") sql("CREATE OR REPLACE VIEW testView AS SELECT id FROM jt") @@ -1339,7 +1339,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("correctly handle ALTER VIEW") { - withSQLConf(SQLConf.CANONICALIZE_VIEW.key -> "true") { + withSQLConf(SQLConf.NATIVE_VIEW.key -> "true") { withTable("jt", "jt2") { sqlContext.range(1, 10).write.format("json").saveAsTable("jt") sql("CREATE VIEW testView AS SELECT id FROM jt") @@ -1357,7 +1357,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { test("create hive view for json table") { // json table is not hive-compatible, make sure the new flag fix it. - withSQLConf(SQLConf.CANONICALIZE_VIEW.key -> "true") { + withSQLConf(SQLConf.NATIVE_VIEW.key -> "true") { withTable("jt") { sqlContext.range(1, 10).write.format("json").saveAsTable("jt") sql("CREATE VIEW testView AS SELECT id FROM jt") @@ -1369,7 +1369,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { test("create hive view for partitioned parquet table") { // partitioned parquet table is not hive-compatible, make sure the new flag fix it. - withSQLConf(SQLConf.CANONICALIZE_VIEW.key -> "true") { + withSQLConf(SQLConf.NATIVE_VIEW.key -> "true") { withTable("parTable") { val df = Seq(1 -> "a").toDF("i", "j") df.write.format("parquet").partitionBy("i").saveAsTable("parTable") @@ -1382,7 +1382,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { test("create hive view for joined tables") { // make sure the new flag can handle some complex cases like join and schema change. - withSQLConf(SQLConf.CANONICALIZE_VIEW.key -> "true") { + withSQLConf(SQLConf.NATIVE_VIEW.key -> "true") { withTable("jt1", "jt2") { sqlContext.range(1, 10).toDF("id1").write.format("json").saveAsTable("jt1") sqlContext.range(1, 10).toDF("id2").write.format("json").saveAsTable("jt2") From 8b32885704502ab2a715cf5142d7517181074428 Mon Sep 17 00:00:00 2001 From: Monica Liu Date: Tue, 13 Oct 2015 22:24:52 -0700 Subject: [PATCH 048/139] [SPARK-10981] [SPARKR] SparkR Join improvements I was having issues with collect() and orderBy() in Spark 1.5.0 so I used the DataFrame.R file and test_sparkSQL.R file from the Spark 1.5.1 download. I only modified the join() function in DataFrame.R to include "full", "fullouter", "left", "right", and "leftsemi" and added corresponding test cases in the test for join() and merge() in test_sparkSQL.R file. Pull request because I filed this JIRA bug report: https://issues.apache.org/jira/browse/SPARK-10981 Author: Monica Liu Closes #9029 from mfliu/master. --- R/pkg/R/DataFrame.R | 13 +++++++++---- R/pkg/inst/tests/test_sparkSQL.R | 27 +++++++++++++++++++++++++-- 2 files changed, 34 insertions(+), 6 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index e0ce056243585..b7f5f978ebc2c 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1414,9 +1414,10 @@ setMethod("where", #' @param x A Spark DataFrame #' @param y A Spark DataFrame #' @param joinExpr (Optional) The expression used to perform the join. joinExpr must be a -#' Column expression. If joinExpr is omitted, join() wil perform a Cartesian join +#' Column expression. If joinExpr is omitted, join() will perform a Cartesian join #' @param joinType The type of join to perform. The following join types are available: -#' 'inner', 'outer', 'left_outer', 'right_outer', 'semijoin'. The default joinType is "inner". +#' 'inner', 'outer', 'full', 'fullouter', leftouter', 'left_outer', 'left', +#' 'right_outer', 'rightouter', 'right', and 'leftsemi'. The default joinType is "inner". #' @return A DataFrame containing the result of the join operation. #' @rdname join #' @name join @@ -1441,11 +1442,15 @@ setMethod("join", if (is.null(joinType)) { sdf <- callJMethod(x@sdf, "join", y@sdf, joinExpr@jc) } else { - if (joinType %in% c("inner", "outer", "left_outer", "right_outer", "semijoin")) { + if (joinType %in% c("inner", "outer", "full", "fullouter", + "leftouter", "left_outer", "left", + "rightouter", "right_outer", "right", "leftsemi")) { + joinType <- gsub("_", "", joinType) sdf <- callJMethod(x@sdf, "join", y@sdf, joinExpr@jc, joinType) } else { stop("joinType must be one of the following types: ", - "'inner', 'outer', 'left_outer', 'right_outer', 'semijoin'") + "'inner', 'outer', 'full', 'fullouter', 'leftouter', 'left_outer', 'left', + 'rightouter', 'right_outer', 'right', 'leftsemi'") } } } diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index d5509e475de05..46cab7646dcf9 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -1071,7 +1071,7 @@ test_that("join() and merge() on a DataFrame", { expect_equal(names(joined2), c("age", "name", "name", "test")) expect_equal(count(joined2), 3) - joined3 <- join(df, df2, df$name == df2$name, "right_outer") + joined3 <- join(df, df2, df$name == df2$name, "rightouter") expect_equal(names(joined3), c("age", "name", "name", "test")) expect_equal(count(joined3), 4) expect_true(is.na(collect(orderBy(joined3, joined3$age))$age[2])) @@ -1082,11 +1082,34 @@ test_that("join() and merge() on a DataFrame", { expect_equal(count(joined4), 4) expect_equal(collect(orderBy(joined4, joined4$name))$newAge[3], 24) + joined5 <- join(df, df2, df$name == df2$name, "leftouter") + expect_equal(names(joined5), c("age", "name", "name", "test")) + expect_equal(count(joined5), 3) + expect_true(is.na(collect(orderBy(joined5, joined5$age))$age[1])) + + joined6 <- join(df, df2, df$name == df2$name, "inner") + expect_equal(names(joined6), c("age", "name", "name", "test")) + expect_equal(count(joined6), 3) + + joined7 <- join(df, df2, df$name == df2$name, "leftsemi") + expect_equal(names(joined7), c("age", "name")) + expect_equal(count(joined7), 3) + + joined8 <- join(df, df2, df$name == df2$name, "left_outer") + expect_equal(names(joined8), c("age", "name", "name", "test")) + expect_equal(count(joined8), 3) + expect_true(is.na(collect(orderBy(joined8, joined8$age))$age[1])) + + joined9 <- join(df, df2, df$name == df2$name, "right_outer") + expect_equal(names(joined9), c("age", "name", "name", "test")) + expect_equal(count(joined9), 4) + expect_true(is.na(collect(orderBy(joined9, joined9$age))$age[2])) + merged <- select(merge(df, df2, df$name == df2$name, "outer"), alias(df$age + 5, "newAge"), df$name, df2$test) expect_equal(names(merged), c("newAge", "name", "test")) expect_equal(count(merged), 4) - expect_equal(collect(orderBy(merged, joined4$name))$newAge[3], 24) + expect_equal(collect(orderBy(merged, merged$name))$newAge[3], 24) }) test_that("toJSON() returns an RDD of the correct values", { From 390b22fad69a33eb6daee25b6b858a2e768670a5 Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Tue, 13 Oct 2015 22:31:23 -0700 Subject: [PATCH 049/139] [SPARK-10996] [SPARKR] Implement sampleBy() in DataFrameStatFunctions. Author: Sun Rui Closes #9023 from sun-rui/SPARK-10996. --- R/pkg/NAMESPACE | 3 ++- R/pkg/R/DataFrame.R | 14 ++++++-------- R/pkg/R/generics.R | 6 +++++- R/pkg/R/sparkR.R | 12 +++--------- R/pkg/R/stats.R | 32 ++++++++++++++++++++++++++++++++ R/pkg/R/utils.R | 18 ++++++++++++++++++ R/pkg/inst/tests/test_sparkSQL.R | 10 ++++++++++ 7 files changed, 76 insertions(+), 19 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index ed9cd94e03b13..52f7a0106aae6 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -65,6 +65,7 @@ exportMethods("arrange", "repartition", "sample", "sample_frac", + "sampleBy", "saveAsParquetFile", "saveAsTable", "saveDF", @@ -254,4 +255,4 @@ export("structField", "structType.structField", "print.structType") -export("as.data.frame") \ No newline at end of file +export("as.data.frame") diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index b7f5f978ebc2c..993be82a47f75 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1831,17 +1831,15 @@ setMethod("fillna", if (length(colNames) == 0 || !all(colNames != "")) { stop("value should be an a named list with each name being a column name.") } - - # Convert to the named list to an environment to be passed to JVM - valueMap <- new.env() - for (col in colNames) { - # Check each item in the named list is of valid type - v <- value[[col]] + # Check each item in the named list is of valid type + lapply(value, function(v) { if (!(class(v) %in% c("integer", "numeric", "character"))) { stop("Each item in value should be an integer, numeric or charactor.") } - valueMap[[col]] <- v - } + }) + + # Convert to the named list to an environment to be passed to JVM + valueMap <- convertNamedListToEnv(value) # When value is a named list, caller is expected not to pass in cols if (!is.null(cols)) { diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index c106a0024583e..4a419f785e92c 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -509,6 +509,10 @@ setGeneric("sample", setGeneric("sample_frac", function(x, withReplacement, fraction, seed) { standardGeneric("sample_frac") }) +#' @rdname statfunctions +#' @export +setGeneric("sampleBy", function(x, col, fractions, seed) { standardGeneric("sampleBy") }) + #' @rdname saveAsParquetFile #' @export setGeneric("saveAsParquetFile", function(x, path) { standardGeneric("saveAsParquetFile") }) @@ -1006,4 +1010,4 @@ setGeneric("as.data.frame") #' @rdname attach #' @export -setGeneric("attach") \ No newline at end of file +setGeneric("attach") diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index cc47110f54732..9cf2f1a361cf2 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -163,19 +163,13 @@ sparkR.init <- function( sparkHome <- suppressWarnings(normalizePath(sparkHome)) } - sparkEnvirMap <- new.env() - for (varname in names(sparkEnvir)) { - sparkEnvirMap[[varname]] <- sparkEnvir[[varname]] - } + sparkEnvirMap <- convertNamedListToEnv(sparkEnvir) - sparkExecutorEnvMap <- new.env() - if (!any(names(sparkExecutorEnv) == "LD_LIBRARY_PATH")) { + sparkExecutorEnvMap <- convertNamedListToEnv(sparkExecutorEnv) + if(is.null(sparkExecutorEnvMap$LD_LIBRARY_PATH)) { sparkExecutorEnvMap[["LD_LIBRARY_PATH"]] <- paste0("$LD_LIBRARY_PATH:",Sys.getenv("LD_LIBRARY_PATH")) } - for (varname in names(sparkExecutorEnv)) { - sparkExecutorEnvMap[[varname]] <- sparkExecutorEnv[[varname]] - } nonEmptyJars <- Filter(function(x) { x != "" }, jars) localJarPaths <- lapply(nonEmptyJars, diff --git a/R/pkg/R/stats.R b/R/pkg/R/stats.R index 4928cf4d4367d..f79329b115404 100644 --- a/R/pkg/R/stats.R +++ b/R/pkg/R/stats.R @@ -127,3 +127,35 @@ setMethod("freqItems", signature(x = "DataFrame", cols = "character"), sct <- callJMethod(statFunctions, "freqItems", as.list(cols), support) collect(dataFrame(sct)) }) + +#' sampleBy +#' +#' Returns a stratified sample without replacement based on the fraction given on each stratum. +#' +#' @param x A SparkSQL DataFrame +#' @param col column that defines strata +#' @param fractions A named list giving sampling fraction for each stratum. If a stratum is +#' not specified, we treat its fraction as zero. +#' @param seed random seed +#' @return A new DataFrame that represents the stratified sample +#' +#' @rdname statfunctions +#' @name sampleBy +#' @export +#' @examples +#'\dontrun{ +#' df <- jsonFile(sqlContext, "/path/to/file.json") +#' sample <- sampleBy(df, "key", fractions, 36) +#' } +setMethod("sampleBy", + signature(x = "DataFrame", col = "character", + fractions = "list", seed = "numeric"), + function(x, col, fractions, seed) { + fractionsEnv <- convertNamedListToEnv(fractions) + + statFunctions <- callJMethod(x@sdf, "stat") + # Seed is expected to be Long on Scala side, here convert it to an integer + # due to SerDe limitation now. + sdf <- callJMethod(statFunctions, "sampleBy", col, fractionsEnv, as.integer(seed)) + dataFrame(sdf) + }) diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 94f16c7ac52cc..0b9e2957fe9a5 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -605,3 +605,21 @@ structToList <- function(struct) { class(struct) <- "list" struct } + +# Convert a named list to an environment to be passed to JVM +convertNamedListToEnv <- function(namedList) { + # Make sure each item in the list has a name + names <- names(namedList) + stopifnot( + if (is.null(names)) { + length(namedList) == 0 + } else { + !any(is.na(names)) + }) + + env <- new.env() + for (name in names) { + env[[name]] <- namedList[[name]] + } + env +} diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 46cab7646dcf9..e1b42b0804933 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -1416,6 +1416,16 @@ test_that("freqItems() on a DataFrame", { expect_identical(result[[2]], list(list(-1, -99))) }) +test_that("sampleBy() on a DataFrame", { + l <- lapply(c(0:99), function(i) { as.character(i %% 3) }) + df <- createDataFrame(sqlContext, l, "key") + fractions <- list("0" = 0.1, "1" = 0.2) + sample <- sampleBy(df, "key", fractions, 0) + result <- collect(orderBy(count(groupBy(sample, "key")), "key")) + expect_identical(as.list(result[1, ]), list(key = "0", count = 2)) + expect_identical(as.list(result[2, ]), list(key = "1", count = 10)) +}) + test_that("SQL error message is returned from JVM", { retError <- tryCatch(sql(sqlContext, "select * from blah"), error = function(e) e) expect_equal(grepl("Table Not Found: blah", retError), TRUE) From 135a2ce5b0b927b512c832d61c25e7b9d57e30be Mon Sep 17 00:00:00 2001 From: Tom Graves Date: Wed, 14 Oct 2015 10:12:25 -0700 Subject: [PATCH 050/139] [SPARK-10619] Can't sort columns on Executor Page should pick into spark 1.5.2 also. https://issues.apache.org/jira/browse/SPARK-10619 looks like this was broken by commit: https://github.com/apache/spark/commit/fb1d06fc242ec00320f1a3049673fbb03c4a6eb9#diff-b8adb646ef90f616c34eb5c98d1ebd16 It looks like somethings were change to use the UIUtils.listingTable but executor page wasn't converted so when it removed sortable from the UIUtils. TABLE_CLASS_NOT_STRIPED it broke this page. Simply add the sortable tag back in and it fixes both active UI and the history server UI. Author: Tom Graves Closes #9101 from tgravescs/SPARK-10619. --- core/src/main/scala/org/apache/spark/ui/UIUtils.scala | 1 + .../src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala | 2 +- .../src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala | 2 +- .../main/scala/org/apache/spark/streaming/ui/BatchPage.scala | 2 +- 4 files changed, 4 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 21dc8f0b65485..68a9f912a5d2c 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -31,6 +31,7 @@ import org.apache.spark.ui.scope.RDDOperationGraph private[spark] object UIUtils extends Logging { val TABLE_CLASS_NOT_STRIPED = "table table-bordered table-condensed" val TABLE_CLASS_STRIPED = TABLE_CLASS_NOT_STRIPED + " table-striped" + val TABLE_CLASS_STRIPED_SORTABLE = TABLE_CLASS_STRIPED + " sortable" // SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use. private val dateFormat = new ThreadLocal[SimpleDateFormat]() { diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index 01cddda4c62cd..1a29b0f412603 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -62,7 +62,7 @@ private[ui] class ExecutorsPage( val logsExist = execInfo.filter(_.executorLogs.nonEmpty).nonEmpty val execTable = -
Property NameDefaultMeaning
spark.storage.memoryFraction0.6 - Fraction of Java heap to use for Spark's memory cache. This should not be larger than the "old" - generation of objects in the JVM, which by default is given 0.6 of the heap, but you can - increase it if you configure your own old generation size. -
spark.storage.memoryMapThreshold 2m
spark.storage.unrollFraction0.2 - Fraction of spark.storage.memoryFraction to use for unrolling blocks in memory. - This is dynamically allocated by dropping existing blocks when there is not enough free - storage space to unroll the new block in its entirety. -
spark.externalBlockStore.blockManager org.apache.spark.storage.TachyonBlockManager
+
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index d5cdbfac104f8..be144f6065baa 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -50,7 +50,7 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: Stage hasBytesSpilled = data.hasBytesSpilled }) -
Executor ID Address
+
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala index 1b717b64542d5..a19b85a51d289 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala @@ -443,7 +443,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { } def generateInputMetadataTable(inputMetadatas: Seq[(Int, String)]): Seq[Node] = { -
Executor ID Address
+
From 31f315981709251d5d26c508a3dc62cf0e6f87e1 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 14 Oct 2015 10:25:09 -0700 Subject: [PATCH 051/139] [SPARK-11040] [NETWORK] Make sure SASL handler delegates all events. Author: Marcelo Vanzin Closes #9053 from vanzin/SPARK-11040. --- .../spark/network/sasl/SaslRpcHandler.java | 13 +++++++++++-- .../server/TransportRequestHandler.java | 8 +++++++- .../spark/network/sasl/SparkSaslSuite.java | 19 +++++++++++++++++++ 3 files changed, 37 insertions(+), 3 deletions(-) diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index 3f2ebe32887b8..7033adb9cae6f 100644 --- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -115,9 +115,18 @@ public StreamManager getStreamManager() { @Override public void connectionTerminated(TransportClient client) { - if (saslServer != null) { - saslServer.dispose(); + try { + delegate.connectionTerminated(client); + } finally { + if (saslServer != null) { + saslServer.dispose(); + } } } + @Override + public void exceptionCaught(Throwable cause, TransportClient client) { + delegate.exceptionCaught(cause, client); + } + } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index 96941d26be19d..9b8b047b49a86 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -76,7 +76,13 @@ public void exceptionCaught(Throwable cause) { @Override public void channelUnregistered() { - streamManager.connectionTerminated(channel); + if (streamManager != null) { + try { + streamManager.connectionTerminated(channel); + } catch (RuntimeException e) { + logger.error("StreamManager connectionTerminated() callback failed.", e); + } + } rpcHandler.connectionTerminated(reverseClient); } diff --git a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index 8104004847a24..3469e84e7f4da 100644 --- a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -153,6 +153,8 @@ public Void answer(InvocationOnMock invocation) { assertEquals("Pong", new String(response, StandardCharsets.UTF_8)); } finally { ctx.close(); + // There should be 2 terminated events; one for the client, one for the server. + verify(rpcHandler, times(2)).connectionTerminated(any(TransportClient.class)); } } @@ -334,6 +336,23 @@ public void testDataEncryptionIsActuallyEnabled() throws Exception { } } + @Test + public void testRpcHandlerDelegate() throws Exception { + // Tests all delegates exception for receive(), which is more complicated and already handled + // by all other tests. + RpcHandler handler = mock(RpcHandler.class); + RpcHandler saslHandler = new SaslRpcHandler(null, null, handler, null); + + saslHandler.getStreamManager(); + verify(handler).getStreamManager(); + + saslHandler.connectionTerminated(null); + verify(handler).connectionTerminated(any(TransportClient.class)); + + saslHandler.exceptionCaught(null, null); + verify(handler).exceptionCaught(any(Throwable.class), any(TransportClient.class)); + } + private static class SaslTestCtx { final TransportClient client; From 7e1308d37f6ca35f063e67e4b87a77e932ad89a5 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Wed, 14 Oct 2015 12:31:29 -0700 Subject: [PATCH 052/139] [SPARK-8386] [SQL] add write.mode for insertIntoJDBC when the parm overwrite is false the fix is for jira https://issues.apache.org/jira/browse/SPARK-8386 Author: Huaxin Gao Closes #9042 from huaxingao/spark8386. --- sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index bfe8d3c8ef957..174bc6f42ad8d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1674,7 +1674,7 @@ class DataFrame private[sql]( */ @deprecated("Use write.jdbc()", "1.4.0") def insertIntoJDBC(url: String, table: String, overwrite: Boolean): Unit = { - val w = if (overwrite) write.mode(SaveMode.Overwrite) else write + val w = if (overwrite) write.mode(SaveMode.Overwrite) else write.mode(SaveMode.Append) w.jdbc(url, table, new Properties) } From 615cc858cf913522059b6ebdde65f0204f4fb030 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 14 Oct 2015 12:36:31 -0700 Subject: [PATCH 053/139] [SPARK-10973] Close #9064 Close #9063 Close #9062 These pull requests were merged into branch-1.5, branch-1.4, and branch-1.3. From cf2e0ae7205443f052463e8cb9334ae2b6df2d0e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 14 Oct 2015 12:41:02 -0700 Subject: [PATCH 054/139] [SPARK-11096] Post-hoc review Netty based RPC implementation - round 2 A few more changes: 1. Renamed IDVerifier -> RpcEndpointVerifier 2. Renamed NettyRpcAddress -> RpcEndpointAddress 3. Simplified NettyRpcHandler a bit by removing the connection count tracking. This is OK because I now force spark.shuffle.io.numConnectionsPerPeer to 1 4. Reduced spark.rpc.connect.threads to 64. It would be great to eventually remove this extra thread pool. 5. Minor cleanup & documentation. Author: Reynold Xin Closes #9112 from rxin/SPARK-11096. --- .../scala/org/apache/spark/rpc/RpcEnv.scala | 9 -- .../apache/spark/rpc/netty/Dispatcher.scala | 7 +- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 114 +++++++----------- ...Address.scala => RpcEndpointAddress.scala} | 32 ++--- ...rifier.scala => RpcEndpointVerifier.scala} | 21 ++-- .../rpc/netty/NettyRpcAddressSuite.scala | 2 +- .../rpc/netty/NettyRpcHandlerSuite.scala | 3 - 7 files changed, 81 insertions(+), 107 deletions(-) rename core/src/main/scala/org/apache/spark/rpc/netty/{NettyRpcAddress.scala => RpcEndpointAddress.scala} (65%) rename core/src/main/scala/org/apache/spark/rpc/netty/{IDVerifier.scala => RpcEndpointVerifier.scala} (65%) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index ef491a0ae4f09..2c4a8b9a0a878 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -93,15 +93,6 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { defaultLookupTimeout.awaitResult(asyncSetupEndpointRefByURI(uri)) } - /** - * Retrieve the [[RpcEndpointRef]] represented by `systemName`, `address` and `endpointName` - * asynchronously. - */ - def asyncSetupEndpointRef( - systemName: String, address: RpcAddress, endpointName: String): Future[RpcEndpointRef] = { - asyncSetupEndpointRefByURI(uriOf(systemName, address, endpointName)) - } - /** * Retrieve the [[RpcEndpointRef]] represented by `systemName`, `address` and `endpointName`. * This is a blocking action. diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala index 398e9eafc1444..f1a8273f157ef 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala @@ -29,6 +29,9 @@ import org.apache.spark.network.client.RpcResponseCallback import org.apache.spark.rpc._ import org.apache.spark.util.ThreadUtils +/** + * A message dispatcher, responsible for routing RPC messages to the appropriate endpoint(s). + */ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { private class EndpointData( @@ -42,7 +45,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { private val endpointRefs = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef] // Track the receivers whose inboxes may contain messages. - private val receivers = new LinkedBlockingQueue[EndpointData]() + private val receivers = new LinkedBlockingQueue[EndpointData] /** * True if the dispatcher has been stopped. Once stopped, all messages posted will be bounced @@ -52,7 +55,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { private var stopped = false def registerRpcEndpoint(name: String, endpoint: RpcEndpoint): NettyRpcEndpointRef = { - val addr = new NettyRpcAddress(nettyEnv.address.host, nettyEnv.address.port, name) + val addr = new RpcEndpointAddress(nettyEnv.address.host, nettyEnv.address.port, name) val endpointRef = new NettyRpcEndpointRef(nettyEnv.conf, addr, nettyEnv) synchronized { if (stopped) { diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 89b6df76c2707..a2b28c524df9c 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -22,7 +22,6 @@ import java.nio.ByteBuffer import java.util.concurrent._ import javax.annotation.concurrent.GuardedBy -import scala.collection.JavaConverters._ import scala.collection.mutable import scala.concurrent.{Future, Promise} import scala.reflect.ClassTag @@ -45,8 +44,10 @@ private[netty] class NettyRpcEnv( host: String, securityManager: SecurityManager) extends RpcEnv(conf) with Logging { - private val transportConf = - SparkTransportConf.fromSparkConf(conf, conf.getInt("spark.rpc.io.threads", 0)) + // Override numConnectionsPerPeer to 1 for RPC. + private val transportConf = SparkTransportConf.fromSparkConf( + conf.clone.set("spark.shuffle.io.numConnectionsPerPeer", "1"), + conf.getInt("spark.rpc.io.threads", 0)) private val dispatcher: Dispatcher = new Dispatcher(this) @@ -54,14 +55,14 @@ private[netty] class NettyRpcEnv( new TransportContext(transportConf, new NettyRpcHandler(dispatcher, this)) private val clientFactory = { - val bootstraps: Seq[TransportClientBootstrap] = + val bootstraps: java.util.List[TransportClientBootstrap] = if (securityManager.isAuthenticationEnabled()) { - Seq(new SaslClientBootstrap(transportConf, "", securityManager, + java.util.Arrays.asList(new SaslClientBootstrap(transportConf, "", securityManager, securityManager.isSaslEncryptionEnabled())) } else { - Nil + java.util.Collections.emptyList[TransportClientBootstrap] } - transportContext.createClientFactory(bootstraps.asJava) + transportContext.createClientFactory(bootstraps) } val timeoutScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("netty-rpc-env-timeout") @@ -71,7 +72,7 @@ private[netty] class NettyRpcEnv( // TODO: a non-blocking TransportClientFactory.createClient in future private val clientConnectionExecutor = ThreadUtils.newDaemonCachedThreadPool( "netty-rpc-connection", - conf.getInt("spark.rpc.connect.threads", 256)) + conf.getInt("spark.rpc.connect.threads", 64)) @volatile private var server: TransportServer = _ @@ -83,7 +84,8 @@ private[netty] class NettyRpcEnv( java.util.Collections.emptyList() } server = transportContext.createServer(port, bootstraps) - dispatcher.registerRpcEndpoint(IDVerifier.NAME, new IDVerifier(this, dispatcher)) + dispatcher.registerRpcEndpoint( + RpcEndpointVerifier.NAME, new RpcEndpointVerifier(this, dispatcher)) } override lazy val address: RpcAddress = { @@ -96,11 +98,11 @@ private[netty] class NettyRpcEnv( } def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = { - val addr = NettyRpcAddress(uri) + val addr = RpcEndpointAddress(uri) val endpointRef = new NettyRpcEndpointRef(conf, addr, this) - val idVerifierRef = - new NettyRpcEndpointRef(conf, NettyRpcAddress(addr.host, addr.port, IDVerifier.NAME), this) - idVerifierRef.ask[Boolean](ID(endpointRef.name)).flatMap { find => + val verifier = new NettyRpcEndpointRef( + conf, RpcEndpointAddress(addr.host, addr.port, RpcEndpointVerifier.NAME), this) + verifier.ask[Boolean](RpcEndpointVerifier.CheckExistence(endpointRef.name)).flatMap { find => if (find) { Future.successful(endpointRef) } else { @@ -117,16 +119,18 @@ private[netty] class NettyRpcEnv( private[netty] def send(message: RequestMessage): Unit = { val remoteAddr = message.receiver.address if (remoteAddr == address) { + // Message to a local RPC endpoint. val promise = Promise[Any]() dispatcher.postLocalMessage(message, promise) promise.future.onComplete { case Success(response) => val ack = response.asInstanceOf[Ack] - logDebug(s"Receive ack from ${ack.sender}") + logTrace(s"Received ack from ${ack.sender}") case Failure(e) => logError(s"Exception when sending $message", e) }(ThreadUtils.sameThread) } else { + // Message to a remote RPC endpoint. try { // `createClient` will block if it cannot find a known connection, so we should run it in // clientConnectionExecutor @@ -204,11 +208,10 @@ private[netty] class NettyRpcEnv( } }) } catch { - case e: RejectedExecutionException => { + case e: RejectedExecutionException => if (!promise.tryFailure(e)) { logWarning(s"Ignore failure", e) } - } } } promise.future @@ -231,7 +234,7 @@ private[netty] class NettyRpcEnv( } override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String = - new NettyRpcAddress(address.host, address.port, endpointName).toString + new RpcEndpointAddress(address.host, address.port, endpointName).toString override def shutdown(): Unit = { cleanup() @@ -310,9 +313,9 @@ private[netty] class NettyRpcEndpointRef(@transient conf: SparkConf) @transient @volatile private var nettyEnv: NettyRpcEnv = _ - @transient @volatile private var _address: NettyRpcAddress = _ + @transient @volatile private var _address: RpcEndpointAddress = _ - def this(conf: SparkConf, _address: NettyRpcAddress, nettyEnv: NettyRpcEnv) { + def this(conf: SparkConf, _address: RpcEndpointAddress, nettyEnv: NettyRpcEnv) { this(conf) this._address = _address this.nettyEnv = nettyEnv @@ -322,7 +325,7 @@ private[netty] class NettyRpcEndpointRef(@transient conf: SparkConf) private def readObject(in: ObjectInputStream): Unit = { in.defaultReadObject() - _address = in.readObject().asInstanceOf[NettyRpcAddress] + _address = in.readObject().asInstanceOf[RpcEndpointAddress] nettyEnv = NettyRpcEnv.currentEnv.value } @@ -406,49 +409,37 @@ private[netty] class NettyRpcHandler( private type RemoteEnvAddress = RpcAddress // Store all client addresses and their NettyRpcEnv addresses. + // TODO: Is this even necessary? @GuardedBy("this") private val remoteAddresses = new mutable.HashMap[ClientAddress, RemoteEnvAddress]() - // Store the connections from other NettyRpcEnv addresses. We need to keep track of the connection - // count because `TransportClientFactory.createClient` will create multiple connections - // (at most `spark.shuffle.io.numConnectionsPerPeer` connections) and randomly select a connection - // to send the message. See `TransportClientFactory.createClient` for more details. - @GuardedBy("this") - private val remoteConnectionCount = new mutable.HashMap[RemoteEnvAddress, Int]() - override def receive( client: TransportClient, message: Array[Byte], callback: RpcResponseCallback): Unit = { val requestMessage = nettyEnv.deserialize[RequestMessage](message) - val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] + val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress] assert(addr != null) val remoteEnvAddress = requestMessage.senderAddress val clientAddr = RpcAddress(addr.getHostName, addr.getPort) - val broadcastMessage: Option[RemoteProcessConnected] = - synchronized { - // If the first connection to a remote RpcEnv is found, we should broadcast "Associated" - if (remoteAddresses.put(clientAddr, remoteEnvAddress).isEmpty) { - // clientAddr connects at the first time - val count = remoteConnectionCount.getOrElse(remoteEnvAddress, 0) - // Increase the connection number of remoteEnvAddress - remoteConnectionCount.put(remoteEnvAddress, count + 1) - if (count == 0) { - // This is the first connection, so fire "Associated" - Some(RemoteProcessConnected(remoteEnvAddress)) - } else { - None - } - } else { - None - } + + // TODO: Can we add connection callback (channel registered) to the underlying framework? + // A variable to track whether we should dispatch the RemoteProcessConnected message. + var dispatchRemoteProcessConnected = false + synchronized { + if (remoteAddresses.put(clientAddr, remoteEnvAddress).isEmpty) { + // clientAddr connects at the first time, fire "RemoteProcessConnected" + dispatchRemoteProcessConnected = true } - broadcastMessage.foreach(dispatcher.postToAll) + } + if (dispatchRemoteProcessConnected) { + dispatcher.postToAll(RemoteProcessConnected(remoteEnvAddress)) + } dispatcher.postRemoteMessage(requestMessage, callback) } override def getStreamManager: StreamManager = new OneForOneStreamManager override def exceptionCaught(cause: Throwable, client: TransportClient): Unit = { - val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] + val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress] if (addr != null) { val clientAddr = RpcAddress(addr.getHostName, addr.getPort) val broadcastMessage = @@ -469,34 +460,21 @@ private[netty] class NettyRpcHandler( } override def connectionTerminated(client: TransportClient): Unit = { - val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] + val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress] if (addr != null) { val clientAddr = RpcAddress(addr.getHostName, addr.getPort) - val broadcastMessage = - synchronized { - // If the last connection to a remote RpcEnv is terminated, we should broadcast - // "Disassociated" - remoteAddresses.get(clientAddr).flatMap { remoteEnvAddress => - remoteAddresses -= clientAddr - val count = remoteConnectionCount.getOrElse(remoteEnvAddress, 0) - assert(count != 0, "remoteAddresses and remoteConnectionCount are not consistent") - if (count - 1 == 0) { - // We lost all clients, so clean up and fire "Disassociated" - remoteConnectionCount.remove(remoteEnvAddress) - Some(RemoteProcessDisconnected(remoteEnvAddress)) - } else { - // Decrease the connection number of remoteEnvAddress - remoteConnectionCount.put(remoteEnvAddress, count - 1) - None - } - } + val messageOpt: Option[RemoteProcessDisconnected] = + synchronized { + remoteAddresses.get(clientAddr).flatMap { remoteEnvAddress => + remoteAddresses -= clientAddr + Some(RemoteProcessDisconnected(remoteEnvAddress)) } - broadcastMessage.foreach(dispatcher.postToAll) + } + messageOpt.foreach(dispatcher.postToAll) } else { // If the channel is closed before connecting, its remoteAddress will be null. In this case, // we can ignore it since we don't fire "Associated". // See java.net.Socket.getRemoteSocketAddress } } - } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala similarity index 65% rename from core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala rename to core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala index 1876b25592086..87b6236936817 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala @@ -17,40 +17,44 @@ package org.apache.spark.rpc.netty -import java.net.URI - import org.apache.spark.SparkException import org.apache.spark.rpc.RpcAddress -private[netty] case class NettyRpcAddress(host: String, port: Int, name: String) { +/** + * An address identifier for an RPC endpoint. + * + * @param host host name of the remote process. + * @param port the port the remote RPC environment binds to. + * @param name name of the remote endpoint. + */ +private[netty] case class RpcEndpointAddress(host: String, port: Int, name: String) { def toRpcAddress: RpcAddress = RpcAddress(host, port) override val toString = s"spark://$name@$host:$port" } -private[netty] object NettyRpcAddress { +private[netty] object RpcEndpointAddress { - def apply(sparkUrl: String): NettyRpcAddress = { + def apply(sparkUrl: String): RpcEndpointAddress = { try { - val uri = new URI(sparkUrl) + val uri = new java.net.URI(sparkUrl) val host = uri.getHost val port = uri.getPort val name = uri.getUserInfo if (uri.getScheme != "spark" || - host == null || - port < 0 || - name == null || - (uri.getPath != null && !uri.getPath.isEmpty) || // uri.getPath returns "" instead of null - uri.getFragment != null || - uri.getQuery != null) { + host == null || + port < 0 || + name == null || + (uri.getPath != null && !uri.getPath.isEmpty) || // uri.getPath returns "" instead of null + uri.getFragment != null || + uri.getQuery != null) { throw new SparkException("Invalid Spark URL: " + sparkUrl) } - NettyRpcAddress(host, port, name) + RpcEndpointAddress(host, port, name) } catch { case e: java.net.URISyntaxException => throw new SparkException("Invalid Spark URL: " + sparkUrl, e) } } - } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointVerifier.scala similarity index 65% rename from core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala rename to core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointVerifier.scala index fa9a3eb99b02a..99f20da2d66aa 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointVerifier.scala @@ -14,26 +14,27 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.spark.rpc.netty import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv} /** - * A message used to ask the remote [[IDVerifier]] if an [[RpcEndpoint]] exists - */ -private[netty] case class ID(name: String) - -/** - * An [[RpcEndpoint]] for remote [[RpcEnv]]s to query if a [[RpcEndpoint]] exists in this [[RpcEnv]] + * An [[RpcEndpoint]] for remote [[RpcEnv]]s to query if an [[RpcEndpoint]] exists. + * + * This is used when setting up a remote endpoint reference. */ -private[netty] class IDVerifier(override val rpcEnv: RpcEnv, dispatcher: Dispatcher) +private[netty] class RpcEndpointVerifier(override val rpcEnv: RpcEnv, dispatcher: Dispatcher) extends RpcEndpoint { override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case ID(name) => context.reply(dispatcher.verify(name)) + case RpcEndpointVerifier.CheckExistence(name) => context.reply(dispatcher.verify(name)) } } -private[netty] object IDVerifier { - val NAME = "id-verifier" +private[netty] object RpcEndpointVerifier { + val NAME = "endpoint-verifier" + + /** A message used to ask the remote [[RpcEndpointVerifier]] if an [[RpcEndpoint]] exists. */ + case class CheckExistence(name: String) } diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala index a5d43d3704e37..973a07a0bde3a 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.SparkFunSuite class NettyRpcAddressSuite extends SparkFunSuite { test("toString") { - val addr = NettyRpcAddress("localhost", 12345, "test") + val addr = RpcEndpointAddress("localhost", 12345, "test") assert(addr.toString === "spark://test@localhost:12345") } diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala index f24f78b8c4542..5430e4c0c4d6c 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala @@ -42,9 +42,6 @@ class NettyRpcHandlerSuite extends SparkFunSuite { when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000)) nettyRpcHandler.receive(client, null, null) - when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40001)) - nettyRpcHandler.receive(client, null, null) - verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 12345))) } From 9a430a027faafb083ca569698effb697af26a1db Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 14 Oct 2015 15:08:13 -0700 Subject: [PATCH 055/139] [SPARK-11068] [SQL] [FOLLOW-UP] move execution listener to util Author: Wenchen Fan Closes #9119 from cloud-fan/callback. --- sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 1 + .../apache/spark/sql/{ => util}/QueryExecutionListener.scala | 2 +- .../apache/spark/sql/{ => util}/DataFrameCallbackSuite.scala | 3 ++- 3 files changed, 4 insertions(+), 2 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/{ => util}/QueryExecutionListener.scala (99%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => util}/DataFrameCallbackSuite.scala (97%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index a835408f8af3a..3d5e35ab315eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -45,6 +45,7 @@ import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab} import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types._ import org.apache.spark.sql.{execution => sparkexecution} +import org.apache.spark.sql.util.ExecutionListenerManager import org.apache.spark.util.Utils /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/QueryExecutionListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/QueryExecutionListener.scala rename to sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala index 14fbebb45f8b7..909a8abd225b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/QueryExecutionListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.util import java.util.concurrent.locks.ReentrantReadWriteLock import scala.collection.mutable.ListBuffer diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala similarity index 97% rename from sql/core/src/test/scala/org/apache/spark/sql/DataFrameCallbackSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala index 4e286a0076205..eb056cd519717 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameCallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala @@ -15,9 +15,10 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.util import org.apache.spark.SparkException +import org.apache.spark.sql.{functions, QueryTest} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Project} import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.test.SharedSQLContext From 56d7da14ab8f89bf4f303b27f51fd22d23967ffb Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 14 Oct 2015 16:05:37 -0700 Subject: [PATCH 056/139] [SPARK-10104] [SQL] Consolidate different forms of table identifiers Right now, we have QualifiedTableName, TableIdentifier, and Seq[String] to represent table identifiers. We should only have one form and TableIdentifier is the best one because it provides methods to get table name, database name, return unquoted string, and return quoted string. Author: Wenchen Fan Author: Wenchen Fan Closes #8453 from cloud-fan/table-name. --- .../apache/spark/sql/catalyst/SqlParser.scala | 2 +- .../spark/sql/catalyst/TableIdentifier.scala | 14 +- .../sql/catalyst/analysis/Analyzer.scala | 4 +- .../spark/sql/catalyst/analysis/Catalog.scala | 174 ++++++------------ .../sql/catalyst/analysis/unresolved.scala | 6 +- .../spark/sql/catalyst/dsl/package.scala | 3 +- .../sql/catalyst/analysis/AnalysisSuite.scala | 24 ++- .../sql/catalyst/analysis/AnalysisTest.scala | 10 +- .../analysis/DecimalPrecisionSuite.scala | 4 +- .../apache/spark/sql/DataFrameReader.scala | 3 +- .../apache/spark/sql/DataFrameWriter.scala | 4 +- .../org/apache/spark/sql/SQLContext.scala | 6 +- .../sql/execution/datasources/DDLParser.scala | 2 +- .../spark/sql/execution/datasources/ddl.scala | 7 +- .../sql/execution/datasources/rules.scala | 4 +- .../apache/spark/sql/CachedTableSuite.scala | 7 +- .../org/apache/spark/sql/JoinSuite.scala | 5 +- .../apache/spark/sql/ListTablesSuite.scala | 7 +- .../parquet/ParquetQuerySuite.scala | 6 +- .../apache/spark/sql/hive/HiveContext.scala | 2 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 134 ++++---------- .../org/apache/spark/sql/hive/HiveQl.scala | 42 ++--- .../hive/execution/CreateTableAsSelect.scala | 12 +- .../hive/execution/CreateViewAsSelect.scala | 9 +- .../spark/sql/hive/execution/commands.scala | 10 +- .../apache/spark/sql/hive/test/TestHive.scala | 2 +- .../hive/JavaMetastoreDataSourcesSuite.java | 6 +- .../spark/sql/hive/ListTablesSuite.scala | 5 +- .../sql/hive/MetastoreDataSourcesSuite.scala | 9 +- .../spark/sql/hive/StatisticsSuite.scala | 5 +- .../sql/hive/execution/SQLQuerySuite.scala | 6 +- .../spark/sql/hive/orc/OrcQuerySuite.scala | 5 +- 32 files changed, 212 insertions(+), 327 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index dfab2398857e8..2595e1f90c837 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -170,7 +170,7 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser { joinedRelation | relationFactor protected lazy val relationFactor: Parser[LogicalPlan] = - ( rep1sep(ident, ".") ~ (opt(AS) ~> opt(ident)) ^^ { + ( tableIdentifier ~ (opt(AS) ~> opt(ident)) ^^ { case tableIdent ~ alias => UnresolvedRelation(tableIdent, alias) } | ("(" ~> start <~ ")") ~ (AS.? ~> ident) ^^ { case s ~ a => Subquery(a, s) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/TableIdentifier.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/TableIdentifier.scala index d701559bf2d9b..4d4e4ded99477 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/TableIdentifier.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/TableIdentifier.scala @@ -20,14 +20,16 @@ package org.apache.spark.sql.catalyst /** * Identifies a `table` in `database`. If `database` is not defined, the current database is used. */ -private[sql] case class TableIdentifier(table: String, database: Option[String] = None) { - def withDatabase(database: String): TableIdentifier = this.copy(database = Some(database)) - - def toSeq: Seq[String] = database.toSeq :+ table +private[sql] case class TableIdentifier(table: String, database: Option[String]) { + def this(table: String) = this(table, None) override def toString: String = quotedString - def quotedString: String = toSeq.map("`" + _ + "`").mkString(".") + def quotedString: String = database.map(db => s"`$db`.`$table`").getOrElse(s"`$table`") + + def unquotedString: String = database.map(db => s"$db.$table").getOrElse(table) +} - def unquotedString: String = toSeq.mkString(".") +private[sql] object TableIdentifier { + def apply(tableName: String): TableIdentifier = new TableIdentifier(tableName) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 041ab22827399..e6046055bf0f6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -105,7 +105,7 @@ class Analyzer( // here use the CTE definition first, check table name only and ignore database name // see https://github.com/apache/spark/pull/4929#discussion_r27186638 for more info case u : UnresolvedRelation => - val substituted = cteRelations.get(u.tableIdentifier.last).map { relation => + val substituted = cteRelations.get(u.tableIdentifier.table).map { relation => val withAlias = u.alias.map(Subquery(_, relation)) withAlias.getOrElse(relation) } @@ -257,7 +257,7 @@ class Analyzer( catalog.lookupRelation(u.tableIdentifier, u.alias) } catch { case _: NoSuchTableException => - u.failAnalysis(s"no such table ${u.tableName}") + u.failAnalysis(s"Table Not Found: ${u.tableName}") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index 4cc9a5520a085..8f4ce74a2ea38 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -42,11 +42,9 @@ trait Catalog { val conf: CatalystConf - def tableExists(tableIdentifier: Seq[String]): Boolean + def tableExists(tableIdent: TableIdentifier): Boolean - def lookupRelation( - tableIdentifier: Seq[String], - alias: Option[String] = None): LogicalPlan + def lookupRelation(tableIdent: TableIdentifier, alias: Option[String] = None): LogicalPlan /** * Returns tuples of (tableName, isTemporary) for all tables in the given database. @@ -56,89 +54,59 @@ trait Catalog { def refreshTable(tableIdent: TableIdentifier): Unit - // TODO: Refactor it in the work of SPARK-10104 - def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit + def registerTable(tableIdent: TableIdentifier, plan: LogicalPlan): Unit - // TODO: Refactor it in the work of SPARK-10104 - def unregisterTable(tableIdentifier: Seq[String]): Unit + def unregisterTable(tableIdent: TableIdentifier): Unit def unregisterAllTables(): Unit - // TODO: Refactor it in the work of SPARK-10104 - protected def processTableIdentifier(tableIdentifier: Seq[String]): Seq[String] = { - if (conf.caseSensitiveAnalysis) { - tableIdentifier - } else { - tableIdentifier.map(_.toLowerCase) - } - } - - // TODO: Refactor it in the work of SPARK-10104 - protected def getDbTableName(tableIdent: Seq[String]): String = { - val size = tableIdent.size - if (size <= 2) { - tableIdent.mkString(".") - } else { - tableIdent.slice(size - 2, size).mkString(".") - } - } - - // TODO: Refactor it in the work of SPARK-10104 - protected def getDBTable(tableIdent: Seq[String]) : (Option[String], String) = { - (tableIdent.lift(tableIdent.size - 2), tableIdent.last) - } - /** - * It is not allowed to specifiy database name for tables stored in [[SimpleCatalog]]. - * We use this method to check it. + * Get the table name of TableIdentifier for temporary tables. */ - protected def checkTableIdentifier(tableIdentifier: Seq[String]): Unit = { - if (tableIdentifier.length > 1) { + protected def getTableName(tableIdent: TableIdentifier): String = { + // It is not allowed to specify database name for temporary tables. + // We check it here and throw exception if database is defined. + if (tableIdent.database.isDefined) { throw new AnalysisException("Specifying database name or other qualifiers are not allowed " + "for temporary tables. If the table name has dots (.) in it, please quote the " + "table name with backticks (`).") } + if (conf.caseSensitiveAnalysis) { + tableIdent.table + } else { + tableIdent.table.toLowerCase + } } } class SimpleCatalog(val conf: CatalystConf) extends Catalog { - val tables = new ConcurrentHashMap[String, LogicalPlan] - - override def registerTable( - tableIdentifier: Seq[String], - plan: LogicalPlan): Unit = { - checkTableIdentifier(tableIdentifier) - val tableIdent = processTableIdentifier(tableIdentifier) - tables.put(getDbTableName(tableIdent), plan) + private[this] val tables = new ConcurrentHashMap[String, LogicalPlan] + + override def registerTable(tableIdent: TableIdentifier, plan: LogicalPlan): Unit = { + tables.put(getTableName(tableIdent), plan) } - override def unregisterTable(tableIdentifier: Seq[String]): Unit = { - checkTableIdentifier(tableIdentifier) - val tableIdent = processTableIdentifier(tableIdentifier) - tables.remove(getDbTableName(tableIdent)) + override def unregisterTable(tableIdent: TableIdentifier): Unit = { + tables.remove(getTableName(tableIdent)) } override def unregisterAllTables(): Unit = { tables.clear() } - override def tableExists(tableIdentifier: Seq[String]): Boolean = { - checkTableIdentifier(tableIdentifier) - val tableIdent = processTableIdentifier(tableIdentifier) - tables.containsKey(getDbTableName(tableIdent)) + override def tableExists(tableIdent: TableIdentifier): Boolean = { + tables.containsKey(getTableName(tableIdent)) } override def lookupRelation( - tableIdentifier: Seq[String], + tableIdent: TableIdentifier, alias: Option[String] = None): LogicalPlan = { - checkTableIdentifier(tableIdentifier) - val tableIdent = processTableIdentifier(tableIdentifier) - val tableFullName = getDbTableName(tableIdent) - val table = tables.get(tableFullName) + val tableName = getTableName(tableIdent) + val table = tables.get(tableName) if (table == null) { - sys.error(s"Table Not Found: $tableFullName") + throw new NoSuchTableException } - val tableWithQualifiers = Subquery(tableIdent.last, table) + val tableWithQualifiers = Subquery(tableName, table) // If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are // properly qualified with this alias. @@ -146,11 +114,7 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog { } override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = { - val result = ArrayBuffer.empty[(String, Boolean)] - for (name <- tables.keySet().asScala) { - result += ((name, true)) - } - result + tables.keySet().asScala.map(_ -> true).toSeq } override def refreshTable(tableIdent: TableIdentifier): Unit = { @@ -165,68 +129,50 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog { * lost when the JVM exits. */ trait OverrideCatalog extends Catalog { + private[this] val overrides = new ConcurrentHashMap[String, LogicalPlan] - // TODO: This doesn't work when the database changes... - val overrides = new mutable.HashMap[(Option[String], String), LogicalPlan]() - - abstract override def tableExists(tableIdentifier: Seq[String]): Boolean = { - val tableIdent = processTableIdentifier(tableIdentifier) - // A temporary tables only has a single part in the tableIdentifier. - val overriddenTable = if (tableIdentifier.length > 1) { - None: Option[LogicalPlan] + private def getOverriddenTable(tableIdent: TableIdentifier): Option[LogicalPlan] = { + if (tableIdent.database.isDefined) { + None } else { - overrides.get(getDBTable(tableIdent)) + Option(overrides.get(getTableName(tableIdent))) } - overriddenTable match { + } + + abstract override def tableExists(tableIdent: TableIdentifier): Boolean = { + getOverriddenTable(tableIdent) match { case Some(_) => true - case None => super.tableExists(tableIdentifier) + case None => super.tableExists(tableIdent) } } abstract override def lookupRelation( - tableIdentifier: Seq[String], + tableIdent: TableIdentifier, alias: Option[String] = None): LogicalPlan = { - val tableIdent = processTableIdentifier(tableIdentifier) - // A temporary tables only has a single part in the tableIdentifier. - val overriddenTable = if (tableIdentifier.length > 1) { - None: Option[LogicalPlan] - } else { - overrides.get(getDBTable(tableIdent)) - } - val tableWithQualifers = overriddenTable.map(r => Subquery(tableIdent.last, r)) + getOverriddenTable(tableIdent) match { + case Some(table) => + val tableName = getTableName(tableIdent) + val tableWithQualifiers = Subquery(tableName, table) - // If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are - // properly qualified with this alias. - val withAlias = - tableWithQualifers.map(r => alias.map(a => Subquery(a, r)).getOrElse(r)) + // If an alias was specified by the lookup, wrap the plan in a sub-query so that attributes + // are properly qualified with this alias. + alias.map(a => Subquery(a, tableWithQualifiers)).getOrElse(tableWithQualifiers) - withAlias.getOrElse(super.lookupRelation(tableIdentifier, alias)) + case None => super.lookupRelation(tableIdent, alias) + } } abstract override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = { - // We always return all temporary tables. - val temporaryTables = overrides.map { - case ((_, tableName), _) => (tableName, true) - }.toSeq - - temporaryTables ++ super.getTables(databaseName) + overrides.keySet().asScala.map(_ -> true).toSeq ++ super.getTables(databaseName) } - override def registerTable( - tableIdentifier: Seq[String], - plan: LogicalPlan): Unit = { - checkTableIdentifier(tableIdentifier) - val tableIdent = processTableIdentifier(tableIdentifier) - overrides.put(getDBTable(tableIdent), plan) + override def registerTable(tableIdent: TableIdentifier, plan: LogicalPlan): Unit = { + overrides.put(getTableName(tableIdent), plan) } - override def unregisterTable(tableIdentifier: Seq[String]): Unit = { - // A temporary tables only has a single part in the tableIdentifier. - // If tableIdentifier has more than one parts, it is not a temporary table - // and we do not need to do anything at here. - if (tableIdentifier.length == 1) { - val tableIdent = processTableIdentifier(tableIdentifier) - overrides.remove(getDBTable(tableIdent)) + override def unregisterTable(tableIdent: TableIdentifier): Unit = { + if (tableIdent.database.isEmpty) { + overrides.remove(getTableName(tableIdent)) } } @@ -243,12 +189,12 @@ object EmptyCatalog extends Catalog { override val conf: CatalystConf = EmptyConf - override def tableExists(tableIdentifier: Seq[String]): Boolean = { + override def tableExists(tableIdent: TableIdentifier): Boolean = { throw new UnsupportedOperationException } override def lookupRelation( - tableIdentifier: Seq[String], + tableIdent: TableIdentifier, alias: Option[String] = None): LogicalPlan = { throw new UnsupportedOperationException } @@ -257,15 +203,17 @@ object EmptyCatalog extends Catalog { throw new UnsupportedOperationException } - override def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit = { + override def registerTable(tableIdent: TableIdentifier, plan: LogicalPlan): Unit = { throw new UnsupportedOperationException } - override def unregisterTable(tableIdentifier: Seq[String]): Unit = { + override def unregisterTable(tableIdent: TableIdentifier): Unit = { throw new UnsupportedOperationException } - override def unregisterAllTables(): Unit = {} + override def unregisterAllTables(): Unit = { + throw new UnsupportedOperationException + } override def refreshTable(tableIdent: TableIdentifier): Unit = { throw new UnsupportedOperationException diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 43ee3191935eb..c97365003935e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.catalyst.errors +import org.apache.spark.sql.catalyst.{TableIdentifier, errors} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LeafNode import org.apache.spark.sql.catalyst.trees.TreeNode @@ -36,11 +36,11 @@ class UnresolvedException[TreeType <: TreeNode[_]](tree: TreeType, function: Str * Holds the name of a relation that has yet to be looked up in a [[Catalog]]. */ case class UnresolvedRelation( - tableIdentifier: Seq[String], + tableIdentifier: TableIdentifier, alias: Option[String] = None) extends LeafNode { /** Returns a `.` separated name for this relation. */ - def tableName: String = tableIdentifier.mkString(".") + def tableName: String = tableIdentifier.unquotedString override def output: Seq[Attribute] = Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 699c4cc63d09a..27b3cd84b3846 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -286,7 +286,8 @@ package object dsl { def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan = InsertIntoTable( - analysis.UnresolvedRelation(Seq(tableName)), Map.empty, logicalPlan, overwrite, false) + analysis.UnresolvedRelation(TableIdentifier(tableName)), + Map.empty, logicalPlan, overwrite, false) def analyze: LogicalPlan = EliminateSubQueries(analysis.SimpleAnalyzer.execute(logicalPlan)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 820b336aac759..ec05cfa63c5bf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ @@ -53,32 +54,39 @@ class AnalysisSuite extends AnalysisTest { Project(testRelation.output, testRelation)) checkAnalysis( - Project(Seq(UnresolvedAttribute("TbL.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL"))), + Project(Seq(UnresolvedAttribute("TbL.a")), + UnresolvedRelation(TableIdentifier("TaBlE"), Some("TbL"))), Project(testRelation.output, testRelation)) assertAnalysisError( - Project(Seq(UnresolvedAttribute("tBl.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL"))), + Project(Seq(UnresolvedAttribute("tBl.a")), UnresolvedRelation( + TableIdentifier("TaBlE"), Some("TbL"))), Seq("cannot resolve")) checkAnalysis( - Project(Seq(UnresolvedAttribute("TbL.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL"))), + Project(Seq(UnresolvedAttribute("TbL.a")), UnresolvedRelation( + TableIdentifier("TaBlE"), Some("TbL"))), Project(testRelation.output, testRelation), caseSensitive = false) checkAnalysis( - Project(Seq(UnresolvedAttribute("tBl.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL"))), + Project(Seq(UnresolvedAttribute("tBl.a")), UnresolvedRelation( + TableIdentifier("TaBlE"), Some("TbL"))), Project(testRelation.output, testRelation), caseSensitive = false) } test("resolve relations") { - assertAnalysisError(UnresolvedRelation(Seq("tAbLe"), None), Seq("Table Not Found: tAbLe")) + assertAnalysisError( + UnresolvedRelation(TableIdentifier("tAbLe"), None), Seq("Table Not Found: tAbLe")) - checkAnalysis(UnresolvedRelation(Seq("TaBlE"), None), testRelation) + checkAnalysis(UnresolvedRelation(TableIdentifier("TaBlE"), None), testRelation) - checkAnalysis(UnresolvedRelation(Seq("tAbLe"), None), testRelation, caseSensitive = false) + checkAnalysis( + UnresolvedRelation(TableIdentifier("tAbLe"), None), testRelation, caseSensitive = false) - checkAnalysis(UnresolvedRelation(Seq("TaBlE"), None), testRelation, caseSensitive = false) + checkAnalysis( + UnresolvedRelation(TableIdentifier("TaBlE"), None), testRelation, caseSensitive = false) } test("divide should be casted into fractional types") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index 53b3695a86be5..23861ed15da61 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.SimpleCatalystConf +import org.apache.spark.sql.catalyst.{TableIdentifier, SimpleCatalystConf} trait AnalysisTest extends PlanTest { @@ -30,8 +31,8 @@ trait AnalysisTest extends PlanTest { val caseSensitiveCatalog = new SimpleCatalog(caseSensitiveConf) val caseInsensitiveCatalog = new SimpleCatalog(caseInsensitiveConf) - caseSensitiveCatalog.registerTable(Seq("TaBlE"), TestRelations.testRelation) - caseInsensitiveCatalog.registerTable(Seq("TaBlE"), TestRelations.testRelation) + caseSensitiveCatalog.registerTable(TableIdentifier("TaBlE"), TestRelations.testRelation) + caseInsensitiveCatalog.registerTable(TableIdentifier("TaBlE"), TestRelations.testRelation) new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitiveConf) { override val extendedResolutionRules = EliminateSubQueries :: Nil @@ -67,8 +68,7 @@ trait AnalysisTest extends PlanTest { expectedErrors: Seq[String], caseSensitive: Boolean = true): Unit = { val analyzer = getAnalyzer(caseSensitive) - // todo: make sure we throw AnalysisException during analysis - val e = intercept[Exception] { + val e = intercept[AnalysisException] { analyzer.checkAnalysis(analyzer.execute(inputPlan)) } assert(expectedErrors.map(_.toLowerCase).forall(e.getMessage.toLowerCase.contains), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index b4ad618c23e39..40c4ae7920918 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{Union, Project, LocalRelation} import org.apache.spark.sql.types._ -import org.apache.spark.sql.catalyst.SimpleCatalystConf +import org.apache.spark.sql.catalyst.{TableIdentifier, SimpleCatalystConf} class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter { val conf = new SimpleCatalystConf(true) @@ -47,7 +47,7 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter { val b: Expression = UnresolvedAttribute("b") before { - catalog.registerTable(Seq("table"), relation) + catalog.registerTable(TableIdentifier("table"), relation) } private def checkType(expression: Expression, expectedType: DataType): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 97a8b6518a832..eacdea2c1e5b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} import org.apache.spark.sql.types.StructType import org.apache.spark.{Logging, Partition} +import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier} /** * :: Experimental :: @@ -287,7 +288,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * @since 1.4.0 */ def table(tableName: String): DataFrame = { - DataFrame(sqlContext, sqlContext.catalog.lookupRelation(Seq(tableName))) + DataFrame(sqlContext, sqlContext.catalog.lookupRelation(TableIdentifier(tableName))) } /////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 03e973666e888..764510ab4b4bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -171,7 +171,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { val overwrite = mode == SaveMode.Overwrite df.sqlContext.executePlan( InsertIntoTable( - UnresolvedRelation(tableIdent.toSeq), + UnresolvedRelation(tableIdent), partitions.getOrElse(Map.empty[String, Option[String]]), df.logicalPlan, overwrite, @@ -201,7 +201,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { } private def saveAsTable(tableIdent: TableIdentifier): Unit = { - val tableExists = df.sqlContext.catalog.tableExists(tableIdent.toSeq) + val tableExists = df.sqlContext.catalog.tableExists(tableIdent) (tableExists, mode) match { case (true, SaveMode.Ignore) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 3d5e35ab315eb..361eb576c567a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -714,7 +714,7 @@ class SQLContext private[sql]( * only during the lifetime of this instance of SQLContext. */ private[sql] def registerDataFrameAsTable(df: DataFrame, tableName: String): Unit = { - catalog.registerTable(Seq(tableName), df.logicalPlan) + catalog.registerTable(TableIdentifier(tableName), df.logicalPlan) } /** @@ -728,7 +728,7 @@ class SQLContext private[sql]( */ def dropTempTable(tableName: String): Unit = { cacheManager.tryUncacheQuery(table(tableName)) - catalog.unregisterTable(Seq(tableName)) + catalog.unregisterTable(TableIdentifier(tableName)) } /** @@ -795,7 +795,7 @@ class SQLContext private[sql]( } private def table(tableIdent: TableIdentifier): DataFrame = { - DataFrame(this, catalog.lookupRelation(tableIdent.toSeq)) + DataFrame(this, catalog.lookupRelation(tableIdent)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala index f7a88b98c0b48..446739d5b8a2c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala @@ -140,7 +140,7 @@ class DDLParser(parseQuery: String => LogicalPlan) protected lazy val describeTable: Parser[LogicalPlan] = (DESCRIBE ~> opt(EXTENDED)) ~ tableIdentifier ^^ { case e ~ tableIdent => - DescribeCommand(UnresolvedRelation(tableIdent.toSeq, None), e.isDefined) + DescribeCommand(UnresolvedRelation(tableIdent, None), e.isDefined) } protected lazy val refreshTable: Parser[LogicalPlan] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index 31d6b75e13477..e7deeff13dc4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -71,7 +71,6 @@ case class CreateTableUsing( * can analyze the logical plan that will be used to populate the table. * So, [[PreWriteCheck]] can detect cases that are not allowed. */ -// TODO: Use TableIdentifier instead of String for tableName (SPARK-10104). case class CreateTableUsingAsSelect( tableIdent: TableIdentifier, provider: String, @@ -93,7 +92,7 @@ case class CreateTempTableUsing( val resolved = ResolvedDataSource( sqlContext, userSpecifiedSchema, Array.empty[String], provider, options) sqlContext.catalog.registerTable( - tableIdent.toSeq, + tableIdent, DataFrame(sqlContext, LogicalRelation(resolved.relation)).logicalPlan) Seq.empty[Row] @@ -112,7 +111,7 @@ case class CreateTempTableUsingAsSelect( val df = DataFrame(sqlContext, query) val resolved = ResolvedDataSource(sqlContext, provider, partitionColumns, mode, options, df) sqlContext.catalog.registerTable( - tableIdent.toSeq, + tableIdent, DataFrame(sqlContext, LogicalRelation(resolved.relation)).logicalPlan) Seq.empty[Row] @@ -128,7 +127,7 @@ case class RefreshTable(tableIdent: TableIdentifier) // If this table is cached as a InMemoryColumnarRelation, drop the original // cached version and make the new version cached lazily. - val logicalPlan = sqlContext.catalog.lookupRelation(tableIdent.toSeq) + val logicalPlan = sqlContext.catalog.lookupRelation(tableIdent) // Use lookupCachedData directly since RefreshTable also takes databaseName. val isCached = sqlContext.cacheManager.lookupCachedData(logicalPlan).nonEmpty if (isCached) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 8efc8016f94dd..b00e5680fef9e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -143,9 +143,9 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => case CreateTableUsingAsSelect(tableIdent, _, _, partitionColumns, mode, _, query) => // When the SaveMode is Overwrite, we need to check if the table is an input table of // the query. If so, we will throw an AnalysisException to let users know it is not allowed. - if (mode == SaveMode.Overwrite && catalog.tableExists(tableIdent.toSeq)) { + if (mode == SaveMode.Overwrite && catalog.tableExists(tableIdent)) { // Need to remove SubQuery operator. - EliminateSubQueries(catalog.lookupRelation(tableIdent.toSeq)) match { + EliminateSubQueries(catalog.lookupRelation(tableIdent)) match { // Only do the check if the table is a data source table // (the relation is a BaseRelation). case l @ LogicalRelation(dest: BaseRelation, _) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 356d4ff3fa837..fd566c8276bc1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException import org.apache.spark.sql.execution.PhysicalRDD import scala.concurrent.duration._ @@ -287,8 +288,7 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { testData.select('key).registerTempTable("t1") sqlContext.table("t1") sqlContext.dropTempTable("t1") - assert( - intercept[RuntimeException](sqlContext.table("t1")).getMessage.startsWith("Table Not Found")) + intercept[NoSuchTableException](sqlContext.table("t1")) } test("Drops cached temporary table") { @@ -300,8 +300,7 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { assert(sqlContext.isCached("t2")) sqlContext.dropTempTable("t1") - assert( - intercept[RuntimeException](sqlContext.table("t1")).getMessage.startsWith("Table Not Found")) + intercept[NoSuchTableException](sqlContext.table("t1")) assert(!sqlContext.isCached("t2")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 7a027e13089e3..b1fb06815868c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.test.SharedSQLContext @@ -359,8 +360,8 @@ class JoinSuite extends QueryTest with SharedSQLContext { upperCaseData.where('N <= 4).registerTempTable("left") upperCaseData.where('N >= 3).registerTempTable("right") - val left = UnresolvedRelation(Seq("left"), None) - val right = UnresolvedRelation(Seq("right"), None) + val left = UnresolvedRelation(TableIdentifier("left"), None) + val right = UnresolvedRelation(TableIdentifier("right"), None) checkAnswer( left.join(right, $"left.N" === $"right.N", "full"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala index eab0fbb196eb6..5688f46e5e3d4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala @@ -21,6 +21,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType} +import org.apache.spark.sql.catalyst.TableIdentifier class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContext { import testImplicits._ @@ -32,7 +33,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex } after { - sqlContext.catalog.unregisterTable(Seq("ListTablesSuiteTable")) + sqlContext.catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable")) } test("get all tables") { @@ -44,7 +45,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) - sqlContext.catalog.unregisterTable(Seq("ListTablesSuiteTable")) + sqlContext.catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable")) assert(sqlContext.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) } @@ -57,7 +58,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) - sqlContext.catalog.unregisterTable(Seq("ListTablesSuiteTable")) + sqlContext.catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable")) assert(sqlContext.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index cc02ef81c9f8b..baff7f5752a75 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -22,7 +22,7 @@ import java.io.File import org.apache.hadoop.fs.Path import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.{TableIdentifier, InternalRow} import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow import org.apache.spark.sql.execution.datasources.parquet.TestingUDT.{NestedStruct, NestedStructUDT} import org.apache.spark.sql.test.SharedSQLContext @@ -49,7 +49,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext sql("INSERT INTO TABLE t SELECT * FROM tmp") checkAnswer(sqlContext.table("t"), (data ++ data).map(Row.fromTuple)) } - sqlContext.catalog.unregisterTable(Seq("tmp")) + sqlContext.catalog.unregisterTable(TableIdentifier("tmp")) } test("overwriting") { @@ -59,7 +59,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp") checkAnswer(sqlContext.table("t"), data.map(Row.fromTuple)) } - sqlContext.catalog.unregisterTable(Seq("tmp")) + sqlContext.catalog.unregisterTable(TableIdentifier("tmp")) } test("self-join") { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index e620d7fb82af9..4d8a3f728e6b5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -358,7 +358,7 @@ class HiveContext private[hive]( @Experimental def analyze(tableName: String) { val tableIdent = SqlParser.parseTableIdentifier(tableName) - val relation = EliminateSubQueries(catalog.lookupRelation(tableIdent.toSeq)) + val relation = EliminateSubQueries(catalog.lookupRelation(tableIdent)) relation match { case relation: MetastoreRelation => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 1f8223e1ff507..5819cb9d08778 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.catalyst.{InternalRow, SqlParser, TableIdentifier} +import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, ResolvedDataSource} import org.apache.spark.sql.execution.{FileRelation, datasources} @@ -103,10 +103,19 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive /** Usages should lock on `this`. */ protected[hive] lazy val hiveWarehouse = new Warehouse(hive.hiveconf) - // TODO: Use this everywhere instead of tuples or databaseName, tableName,. /** A fully qualified identifier for a table (i.e., database.tableName) */ - case class QualifiedTableName(database: String, name: String) { - def toLowerCase: QualifiedTableName = QualifiedTableName(database.toLowerCase, name.toLowerCase) + case class QualifiedTableName(database: String, name: String) + + private def getQualifiedTableName(tableIdent: TableIdentifier) = { + QualifiedTableName( + tableIdent.database.getOrElse(client.currentDatabase).toLowerCase, + tableIdent.table.toLowerCase) + } + + private def getQualifiedTableName(hiveTable: HiveTable) = { + QualifiedTableName( + hiveTable.specifiedDatabase.getOrElse(client.currentDatabase).toLowerCase, + hiveTable.name.toLowerCase) } /** A cache of Spark SQL data source tables that have been accessed. */ @@ -179,33 +188,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive } def invalidateTable(tableIdent: TableIdentifier): Unit = { - val databaseName = tableIdent.database.getOrElse(client.currentDatabase) - val tableName = tableIdent.table - - cachedDataSourceTables.invalidate(QualifiedTableName(databaseName, tableName).toLowerCase) - } - - val caseSensitive: Boolean = false - - /** - * Creates a data source table (a table created with USING clause) in Hive's metastore. - * Returns true when the table has been created. Otherwise, false. - */ - // TODO: Remove this in SPARK-10104. - def createDataSourceTable( - tableName: String, - userSpecifiedSchema: Option[StructType], - partitionColumns: Array[String], - provider: String, - options: Map[String, String], - isExternal: Boolean): Unit = { - createDataSourceTable( - SqlParser.parseTableIdentifier(tableName), - userSpecifiedSchema, - partitionColumns, - provider, - options, - isExternal) + cachedDataSourceTables.invalidate(getQualifiedTableName(tableIdent)) } def createDataSourceTable( @@ -215,10 +198,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive provider: String, options: Map[String, String], isExternal: Boolean): Unit = { - val (dbName, tblName) = { - val database = tableIdent.database.getOrElse(client.currentDatabase) - processDatabaseAndTableName(database, tableIdent.table) - } + val QualifiedTableName(dbName, tblName) = getQualifiedTableName(tableIdent) val tableProperties = new mutable.HashMap[String, String] tableProperties.put("spark.sql.sources.provider", provider) @@ -311,7 +291,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive // TODO: Support persisting partitioned data source relations in Hive compatible format val qualifiedTableName = tableIdent.quotedString - val (hiveCompitiableTable, logMessage) = (maybeSerDe, dataSource.relation) match { + val (hiveCompatibleTable, logMessage) = (maybeSerDe, dataSource.relation) match { case (Some(serde), relation: HadoopFsRelation) if relation.paths.length == 1 && relation.partitionColumns.isEmpty => val hiveTable = newHiveCompatibleMetastoreTable(relation, serde) @@ -349,9 +329,9 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive (None, message) } - (hiveCompitiableTable, logMessage) match { + (hiveCompatibleTable, logMessage) match { case (Some(table), message) => - // We first try to save the metadata of the table in a Hive compatiable way. + // We first try to save the metadata of the table in a Hive compatible way. // If Hive throws an error, we fall back to save its metadata in the Spark SQL // specific way. try { @@ -374,48 +354,29 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive } } - def hiveDefaultTableFilePath(tableName: String): String = { - hiveDefaultTableFilePath(SqlParser.parseTableIdentifier(tableName)) - } - def hiveDefaultTableFilePath(tableIdent: TableIdentifier): String = { // Code based on: hiveWarehouse.getTablePath(currentDatabase, tableName) - val database = tableIdent.database.getOrElse(client.currentDatabase) - - new Path( - new Path(client.getDatabase(database).location), - tableIdent.table.toLowerCase).toString + val QualifiedTableName(dbName, tblName) = getQualifiedTableName(tableIdent) + new Path(new Path(client.getDatabase(dbName).location), tblName).toString } - def tableExists(tableIdentifier: Seq[String]): Boolean = { - val tableIdent = processTableIdentifier(tableIdentifier) - val databaseName = - tableIdent - .lift(tableIdent.size - 2) - .getOrElse(client.currentDatabase) - val tblName = tableIdent.last - client.getTableOption(databaseName, tblName).isDefined + override def tableExists(tableIdent: TableIdentifier): Boolean = { + val QualifiedTableName(dbName, tblName) = getQualifiedTableName(tableIdent) + client.getTableOption(dbName, tblName).isDefined } - def lookupRelation( - tableIdentifier: Seq[String], + override def lookupRelation( + tableIdent: TableIdentifier, alias: Option[String]): LogicalPlan = { - val tableIdent = processTableIdentifier(tableIdentifier) - val databaseName = tableIdent.lift(tableIdent.size - 2).getOrElse( - client.currentDatabase) - val tblName = tableIdent.last - val table = client.getTable(databaseName, tblName) + val qualifiedTableName = getQualifiedTableName(tableIdent) + val table = client.getTable(qualifiedTableName.database, qualifiedTableName.name) if (table.properties.get("spark.sql.sources.provider").isDefined) { - val dataSourceTable = - cachedDataSourceTables(QualifiedTableName(databaseName, tblName).toLowerCase) + val dataSourceTable = cachedDataSourceTables(qualifiedTableName) + val tableWithQualifiers = Subquery(qualifiedTableName.name, dataSourceTable) // Then, if alias is specified, wrap the table with a Subquery using the alias. // Otherwise, wrap the table with a Subquery using the table name. - val withAlias = - alias.map(a => Subquery(a, dataSourceTable)).getOrElse( - Subquery(tableIdent.last, dataSourceTable)) - - withAlias + alias.map(a => Subquery(a, tableWithQualifiers)).getOrElse(tableWithQualifiers) } else if (table.tableType == VirtualView) { val viewText = table.viewText.getOrElse(sys.error("Invalid view without text.")) alias match { @@ -425,7 +386,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive case Some(aliasText) => Subquery(aliasText, HiveQl.createPlan(viewText)) } } else { - MetastoreRelation(databaseName, tblName, alias)(table)(hive) + MetastoreRelation(qualifiedTableName.database, qualifiedTableName.name, alias)(table)(hive) } } @@ -524,26 +485,6 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive client.listTables(db).map(tableName => (tableName, false)) } - protected def processDatabaseAndTableName( - databaseName: Option[String], - tableName: String): (Option[String], String) = { - if (!caseSensitive) { - (databaseName.map(_.toLowerCase), tableName.toLowerCase) - } else { - (databaseName, tableName) - } - } - - protected def processDatabaseAndTableName( - databaseName: String, - tableName: String): (String, String) = { - if (!caseSensitive) { - (databaseName.toLowerCase, tableName.toLowerCase) - } else { - (databaseName, tableName) - } - } - /** * When scanning or writing to non-partitioned Metastore Parquet tables, convert them to Parquet * data source relations for better performance. @@ -597,8 +538,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive "It is not allowed to define a view with both IF NOT EXISTS and OR REPLACE.") } - val (dbName, tblName) = processDatabaseAndTableName( - table.specifiedDatabase.getOrElse(client.currentDatabase), table.name) + val QualifiedTableName(dbName, tblName) = getQualifiedTableName(table) execution.CreateViewAsSelect( table.copy( @@ -636,7 +576,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive val mode = if (allowExisting) SaveMode.Ignore else SaveMode.ErrorIfExists CreateTableUsingAsSelect( TableIdentifier(desc.name), - hive.conf.defaultDataSourceName, + conf.defaultDataSourceName, temporary = false, Array.empty[String], mode, @@ -652,9 +592,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive table } - val (dbName, tblName) = - processDatabaseAndTableName( - desc.specifiedDatabase.getOrElse(client.currentDatabase), desc.name) + val QualifiedTableName(dbName, tblName) = getQualifiedTableName(table) execution.CreateTableAsSelect( desc.copy( @@ -712,7 +650,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive * UNIMPLEMENTED: It needs to be decided how we will persist in-memory tables to the metastore. * For now, if this functionality is desired mix in the in-memory [[OverrideCatalog]]. */ - override def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit = { + override def registerTable(tableIdent: TableIdentifier, plan: LogicalPlan): Unit = { throw new UnsupportedOperationException } @@ -720,7 +658,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive * UNIMPLEMENTED: It needs to be decided how we will persist in-memory tables to the metastore. * For now, if this functionality is desired mix in the in-memory [[OverrideCatalog]]. */ - override def unregisterTable(tableIdentifier: Seq[String]): Unit = { + override def unregisterTable(tableIdent: TableIdentifier): Unit = { throw new UnsupportedOperationException } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 1d505019400bc..d4ff5cc0f12a2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -41,6 +41,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.{logical, _} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.trees.CurrentOrigin +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.ExplainCommand import org.apache.spark.sql.execution.datasources.DescribeCommand import org.apache.spark.sql.hive.HiveShim._ @@ -442,24 +443,12 @@ private[hive] object HiveQl extends Logging { throw new NotImplementedError(s"No parse rules for StructField:\n ${dumpTree(a).toString} ") } - protected def extractDbNameTableName(tableNameParts: Node): (Option[String], String) = { - val (db, tableName) = - tableNameParts.getChildren.asScala.map { - case Token(part, Nil) => cleanIdentifier(part) - } match { - case Seq(tableOnly) => (None, tableOnly) - case Seq(databaseName, table) => (Some(databaseName), table) - } - - (db, tableName) - } - - protected def extractTableIdent(tableNameParts: Node): Seq[String] = { + protected def extractTableIdent(tableNameParts: Node): TableIdentifier = { tableNameParts.getChildren.asScala.map { case Token(part, Nil) => cleanIdentifier(part) } match { - case Seq(tableOnly) => Seq(tableOnly) - case Seq(databaseName, table) => Seq(databaseName, table) + case Seq(tableOnly) => TableIdentifier(tableOnly) + case Seq(databaseName, table) => TableIdentifier(table, Some(databaseName)) case other => sys.error("Hive only supports tables names like 'tableName' " + s"or 'databaseName.tableName', found '$other'") } @@ -518,13 +507,13 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C properties: Map[String, String], allowExist: Boolean, replace: Boolean): CreateViewAsSelect = { - val (db, viewName) = extractDbNameTableName(viewNameParts) + val TableIdentifier(viewName, dbName) = extractTableIdent(viewNameParts) val originalText = context.getTokenRewriteStream .toString(query.getTokenStartIndex, query.getTokenStopIndex) val tableDesc = HiveTable( - specifiedDatabase = db, + specifiedDatabase = dbName, name = viewName, schema = schema, partitionColumns = Seq.empty[HiveColumn], @@ -611,7 +600,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case tableName => // It is describing a table with the format like "describe table". DescribeCommand( - UnresolvedRelation(Seq(tableName.getText), None), isExtended = extended.isDefined) + UnresolvedRelation(TableIdentifier(tableName.getText), None), + isExtended = extended.isDefined) } } // All other cases. @@ -716,12 +706,12 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C "TOK_TABLELOCATION", "TOK_TABLEPROPERTIES"), children) - val (db, tableName) = extractDbNameTableName(tableNameParts) + val TableIdentifier(tblName, dbName) = extractTableIdent(tableNameParts) // TODO add bucket support var tableDesc: HiveTable = HiveTable( - specifiedDatabase = db, - name = tableName, + specifiedDatabase = dbName, + name = tblName, schema = Seq.empty[HiveColumn], partitionColumns = Seq.empty[HiveColumn], properties = Map[String, String](), @@ -1264,15 +1254,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C nonAliasClauses) } - val tableIdent = - tableNameParts.getChildren.asScala.map { - case Token(part, Nil) => cleanIdentifier(part) - } match { - case Seq(tableOnly) => Seq(tableOnly) - case Seq(databaseName, table) => Seq(databaseName, table) - case other => sys.error("Hive only supports tables names like 'tableName' " + - s"or 'databaseName.tableName', found '$other'") - } + val tableIdent = extractTableIdent(tableNameParts) val alias = aliasClause.map { case Token(a, Nil) => cleanIdentifier(a) } val relation = UnresolvedRelation(tableIdent, alias) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala index 8422287e177e5..e72a60b42e653 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.hive.execution +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan} import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.hive.client.{HiveColumn, HiveTable} @@ -37,8 +38,7 @@ case class CreateTableAsSelect( allowExisting: Boolean) extends RunnableCommand { - def database: String = tableDesc.database - def tableName: String = tableDesc.name + val tableIdentifier = TableIdentifier(tableDesc.name, Some(tableDesc.database)) override def children: Seq[LogicalPlan] = Seq(query) @@ -72,18 +72,18 @@ case class CreateTableAsSelect( hiveContext.catalog.client.createTable(withSchema) // Get the Metastore Relation - hiveContext.catalog.lookupRelation(Seq(database, tableName), None) match { + hiveContext.catalog.lookupRelation(tableIdentifier, None) match { case r: MetastoreRelation => r } } // TODO ideally, we should get the output data ready first and then // add the relation into catalog, just in case of failure occurs while data // processing. - if (hiveContext.catalog.tableExists(Seq(database, tableName))) { + if (hiveContext.catalog.tableExists(tableIdentifier)) { if (allowExisting) { // table already exists, will do nothing, to keep consistent with Hive } else { - throw new AnalysisException(s"$database.$tableName already exists.") + throw new AnalysisException(s"$tableIdentifier already exists.") } } else { hiveContext.executePlan(InsertIntoTable(metastoreRelation, Map(), query, true, false)).toRdd @@ -93,6 +93,6 @@ case class CreateTableAsSelect( } override def argString: String = { - s"[Database:$database, TableName: $tableName, InsertIntoHiveTable]" + s"[Database:${tableDesc.database}}, TableName: ${tableDesc.name}, InsertIntoHiveTable]" } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala index 2b504ac974f07..2c81115ee4fed 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.hive.execution +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.hive.{HiveMetastoreTypes, HiveContext} import org.apache.spark.sql.{AnalysisException, Row, SQLContext} @@ -38,18 +39,18 @@ private[hive] case class CreateViewAsSelect( assert(tableDesc.schema == Nil || tableDesc.schema.length == childSchema.length) assert(tableDesc.viewText.isDefined) + val tableIdentifier = TableIdentifier(tableDesc.name, Some(tableDesc.database)) + override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] - val database = tableDesc.database - val viewName = tableDesc.name - if (hiveContext.catalog.tableExists(Seq(database, viewName))) { + if (hiveContext.catalog.tableExists(tableIdentifier)) { if (allowExisting) { // view already exists, will do nothing, to keep consistent with Hive } else if (orReplace) { hiveContext.catalog.client.alertView(prepareTable()) } else { - throw new AnalysisException(s"View $database.$viewName already exists. " + + throw new AnalysisException(s"View $tableIdentifier already exists. " + "If you want to update the view definition, please use ALTER VIEW AS or " + "CREATE OR REPLACE VIEW AS") } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 51ec92afd06ed..94210a5394f9b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -71,7 +71,7 @@ case class DropTable( } hiveContext.invalidateTable(tableName) hiveContext.runSqlHive(s"DROP TABLE $ifExistsClause$tableName") - hiveContext.catalog.unregisterTable(Seq(tableName)) + hiveContext.catalog.unregisterTable(TableIdentifier(tableName)) Seq.empty[Row] } } @@ -103,7 +103,6 @@ case class AddFile(path: String) extends RunnableCommand { } } -// TODO: Use TableIdentifier instead of String for tableName (SPARK-10104). private[hive] case class CreateMetastoreDataSource( tableIdent: TableIdentifier, @@ -131,7 +130,7 @@ case class CreateMetastoreDataSource( val tableName = tableIdent.unquotedString val hiveContext = sqlContext.asInstanceOf[HiveContext] - if (hiveContext.catalog.tableExists(tableIdent.toSeq)) { + if (hiveContext.catalog.tableExists(tableIdent)) { if (allowExisting) { return Seq.empty[Row] } else { @@ -160,7 +159,6 @@ case class CreateMetastoreDataSource( } } -// TODO: Use TableIdentifier instead of String for tableName (SPARK-10104). private[hive] case class CreateMetastoreDataSourceAsSelect( tableIdent: TableIdentifier, @@ -198,7 +196,7 @@ case class CreateMetastoreDataSourceAsSelect( } var existingSchema = None: Option[StructType] - if (sqlContext.catalog.tableExists(tableIdent.toSeq)) { + if (sqlContext.catalog.tableExists(tableIdent)) { // Check if we need to throw an exception or just return. mode match { case SaveMode.ErrorIfExists => @@ -215,7 +213,7 @@ case class CreateMetastoreDataSourceAsSelect( val resolved = ResolvedDataSource( sqlContext, Some(query.schema.asNullable), partitionColumns, provider, optionsWithPath) val createdRelation = LogicalRelation(resolved.relation) - EliminateSubQueries(sqlContext.catalog.lookupRelation(tableIdent.toSeq)) match { + EliminateSubQueries(sqlContext.catalog.lookupRelation(tableIdent)) match { case l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation, _) => if (l.relation != createdRelation.relation) { val errorDescription = diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index ff39ccb7c1ea5..6883d305cbead 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -191,7 +191,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { // Make sure any test tables referenced are loaded. val referencedTables = describedTables ++ - logical.collect { case UnresolvedRelation(tableIdent, _) => tableIdent.last } + logical.collect { case UnresolvedRelation(tableIdent, _) => tableIdent.table } val referencedTestTables = referencedTables.filter(testTables.contains) logDebug(s"Query references test tables: ${referencedTestTables.mkString(", ")}") referencedTestTables.foreach(loadTestTable) diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java index c8d272794d10b..8c4af1b8eaf44 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java @@ -26,7 +26,6 @@ import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; -import org.apache.spark.sql.SaveMode; import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -41,6 +40,8 @@ import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.catalyst.TableIdentifier; import org.apache.spark.util.Utils; public class JavaMetastoreDataSourcesSuite { @@ -71,7 +72,8 @@ public void setUp() throws IOException { if (path.exists()) { path.delete(); } - hiveManagedPath = new Path(sqlContext.catalog().hiveDefaultTableFilePath("javaSavedTable")); + hiveManagedPath = new Path(sqlContext.catalog().hiveDefaultTableFilePath( + new TableIdentifier("javaSavedTable"))); fs = hiveManagedPath.getFileSystem(sc.hadoopConfiguration()); if (fs.exists(hiveManagedPath)){ fs.delete(hiveManagedPath, true); diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala index 579631df772b5..183aca29cf98d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.QueryTest import org.apache.spark.sql.Row @@ -31,14 +32,14 @@ class ListTablesSuite extends QueryTest with TestHiveSingleton with BeforeAndAft override def beforeAll(): Unit = { // The catalog in HiveContext is a case insensitive one. - catalog.registerTable(Seq("ListTablesSuiteTable"), df.logicalPlan) + catalog.registerTable(TableIdentifier("ListTablesSuiteTable"), df.logicalPlan) sql("CREATE TABLE HiveListTablesSuiteTable (key int, value string)") sql("CREATE DATABASE IF NOT EXISTS ListTablesSuiteDB") sql("CREATE TABLE ListTablesSuiteDB.HiveInDBListTablesSuiteTable (key int, value string)") } override def afterAll(): Unit = { - catalog.unregisterTable(Seq("ListTablesSuiteTable")) + catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable")) sql("DROP TABLE IF EXISTS HiveListTablesSuiteTable") sql("DROP TABLE IF EXISTS ListTablesSuiteDB.HiveInDBListTablesSuiteTable") sql("DROP DATABASE IF EXISTS ListTablesSuiteDB") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index d3565380005a0..d2928876887bd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.util.Utils /** @@ -367,7 +368,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv |) """.stripMargin) - val expectedPath = catalog.hiveDefaultTableFilePath("ctasJsonTable") + val expectedPath = catalog.hiveDefaultTableFilePath(TableIdentifier("ctasJsonTable")) val filesystemPath = new Path(expectedPath) val fs = filesystemPath.getFileSystem(sparkContext.hadoopConfiguration) if (fs.exists(filesystemPath)) fs.delete(filesystemPath, true) @@ -472,7 +473,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv // Drop table will also delete the data. sql("DROP TABLE savedJsonTable") intercept[IOException] { - read.json(catalog.hiveDefaultTableFilePath("savedJsonTable")) + read.json(catalog.hiveDefaultTableFilePath(TableIdentifier("savedJsonTable"))) } } @@ -703,7 +704,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv // Manually create a metastore data source table. catalog.createDataSourceTable( - tableName = "wide_schema", + tableIdent = TableIdentifier("wide_schema"), userSpecifiedSchema = Some(schema), partitionColumns = Array.empty[String], provider = "json", @@ -733,7 +734,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv "EXTERNAL" -> "FALSE"), tableType = ManagedTable, serdeProperties = Map( - "path" -> catalog.hiveDefaultTableFilePath(tableName))) + "path" -> catalog.hiveDefaultTableFilePath(TableIdentifier(tableName)))) catalog.client.createTable(hiveTable) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 6a692d6fce562..9bb32f11b76bd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive import scala.reflect.ClassTag import org.apache.spark.sql.{Row, SQLConf, QueryTest} +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.hive.execution._ import org.apache.spark.sql.hive.test.TestHiveSingleton @@ -68,7 +69,7 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton { test("analyze MetastoreRelations") { def queryTotalSize(tableName: String): BigInt = - hiveContext.catalog.lookupRelation(Seq(tableName)).statistics.sizeInBytes + hiveContext.catalog.lookupRelation(TableIdentifier(tableName)).statistics.sizeInBytes // Non-partitioned table sql("CREATE TABLE analyzeTable (key STRING, value STRING)").collect() @@ -115,7 +116,7 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton { intercept[UnsupportedOperationException] { hiveContext.analyze("tempTable") } - hiveContext.catalog.unregisterTable(Seq("tempTable")) + hiveContext.catalog.unregisterTable(TableIdentifier("tempTable")) } test("estimates the size of a test MetastoreRelation") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 6aa34605b05a8..c929ba50680bc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp} import scala.collection.JavaConverters._ import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.DefaultParserDialect +import org.apache.spark.sql.catalyst.{TableIdentifier, DefaultParserDialect} import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, EliminateSubQueries} import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.execution.datasources.LogicalRelation @@ -266,7 +266,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { test("CTAS without serde") { def checkRelation(tableName: String, isDataSourceParquet: Boolean): Unit = { - val relation = EliminateSubQueries(catalog.lookupRelation(Seq(tableName))) + val relation = EliminateSubQueries(catalog.lookupRelation(TableIdentifier(tableName))) relation match { case LogicalRelation(r: ParquetRelation, _) => if (!isDataSourceParquet) { @@ -723,7 +723,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { (1 to 100).par.map { i => val tableName = s"SPARK_6618_table_$i" sql(s"CREATE TABLE $tableName (col1 string)") - catalog.lookupRelation(Seq(tableName)) + catalog.lookupRelation(TableIdentifier(tableName)) table(tableName) tables() sql(s"DROP TABLE $tableName") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index 5eb39b1129701..7efeab528c1dd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -24,6 +24,7 @@ import org.apache.hadoop.hive.ql.io.orc.CompressionKind import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ @@ -218,7 +219,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { sql("INSERT INTO TABLE t SELECT * FROM tmp") checkAnswer(table("t"), (data ++ data).map(Row.fromTuple)) } - catalog.unregisterTable(Seq("tmp")) + catalog.unregisterTable(TableIdentifier("tmp")) } test("overwriting") { @@ -228,7 +229,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp") checkAnswer(table("t"), data.map(Row.fromTuple)) } - catalog.unregisterTable(Seq("tmp")) + catalog.unregisterTable(TableIdentifier("tmp")) } test("self-join") { From 2b5e31c7e97811ef7b4da47609973b7f51444346 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 14 Oct 2015 16:27:43 -0700 Subject: [PATCH 057/139] [SPARK-11113] [SQL] Remove DeveloperApi annotation from private classes. o.a.s.sql.catalyst and o.a.s.sql.execution are supposed to be private. Author: Reynold Xin Closes #9121 from rxin/SPARK-11113. --- .../expressions/codegen/package.scala | 3 -- .../spark/sql/execution/Aggregate.scala | 3 -- .../apache/spark/sql/execution/Exchange.scala | 5 +--- .../spark/sql/execution/ExistingRDD.scala | 6 +--- .../apache/spark/sql/execution/Expand.scala | 2 -- .../apache/spark/sql/execution/Generate.scala | 3 -- .../spark/sql/execution/LocalTableScan.scala | 3 +- .../spark/sql/execution/QueryExecution.scala | 7 ++--- .../spark/sql/execution/ShuffledRowRDD.scala | 1 - .../spark/sql/execution/SparkPlan.scala | 6 ++-- .../apache/spark/sql/execution/Window.scala | 10 ++----- .../spark/sql/execution/basicOperators.scala | 28 ++---------------- .../apache/spark/sql/execution/commands.scala | 29 +++---------------- .../execution/joins/BroadcastHashJoin.scala | 3 -- .../joins/BroadcastHashOuterJoin.scala | 3 -- .../joins/BroadcastLeftSemiJoinHash.scala | 3 -- .../joins/BroadcastNestedLoopJoin.scala | 6 +--- .../execution/joins/CartesianProduct.scala | 6 +--- .../sql/execution/joins/HashOuterJoin.scala | 9 ++---- .../sql/execution/joins/LeftSemiJoinBNL.scala | 3 -- .../execution/joins/LeftSemiJoinHash.scala | 3 -- .../execution/joins/ShuffledHashJoin.scala | 3 -- .../joins/ShuffledHashOuterJoin.scala | 3 -- .../sql/execution/joins/SortMergeJoin.scala | 3 -- .../execution/joins/SortMergeOuterJoin.scala | 3 -- .../spark/sql/execution/joins/package.scala | 6 ---- .../apache/spark/sql/execution/python.scala | 7 +---- .../sql/execution/rowFormatConverters.scala | 5 ---- .../spark/sql/test/ExamplePointUDT.scala | 3 -- 29 files changed, 22 insertions(+), 153 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala index 606fecbe06e47..41128fe389d46 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.rules import org.apache.spark.util.Utils @@ -40,10 +39,8 @@ package object codegen { } /** - * :: DeveloperApi :: * Dumps the bytecode from a class to the screen using javap. */ - @DeveloperApi object DumpByteCode { import scala.sys.process._ val dumpDirectory = Utils.createTempDir() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala index f3b6a3a5f4a33..6f3f1bd97ad52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution import java.util.HashMap -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ @@ -28,7 +27,6 @@ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetrics /** - * :: DeveloperApi :: * Groups input data by `groupingExpressions` and computes the `aggregateExpressions` for each * group. * @@ -38,7 +36,6 @@ import org.apache.spark.sql.execution.metric.SQLMetrics * @param aggregateExpressions expressions that are computed for each group. * @param child the input data source. */ -@DeveloperApi case class Aggregate( partial: Boolean, groupingExpressions: Seq[Expression], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 8efa471600b1b..289453753f18d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution import java.util.Random -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.hash.HashShuffleManager @@ -33,13 +33,10 @@ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjectio import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.util.MutablePair -import org.apache.spark._ /** - * :: DeveloperApi :: * Performs a shuffle that will result in the desired `newPartitioning`. */ -@DeveloperApi case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode { override def nodeName: String = if (tungstenMode) "TungstenExchange" else "Exchange" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index abb60cf12e3a5..87bd92e00a2c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.execution -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation @@ -27,10 +26,7 @@ import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types.DataType import org.apache.spark.sql.{Row, SQLContext} -/** - * :: DeveloperApi :: - */ -@DeveloperApi + object RDDConversions { def productToRowRdd[A <: Product](data: RDD[A], outputTypes: Seq[DataType]): RDD[InternalRow] = { data.mapPartitions { iterator => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala index d90cae1c4c060..a458881f40948 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.execution -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ @@ -32,7 +31,6 @@ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartit * @param output The output Schema * @param child Child operator */ -@DeveloperApi case class Expand( projections: Seq[Seq[Expression]], output: Seq[Attribute], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala index c3c0dc441c928..78e33d9f233a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.execution -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -35,7 +34,6 @@ private[execution] sealed case class LazyIterator(func: () => TraversableOnce[In } /** - * :: DeveloperApi :: * Applies a [[Generator]] to a stream of input rows, combining the * output of each into a new stream of rows. This operation is similar to a `flatMap` in functional * programming with one important additional feature, which allows the input rows to be joined with @@ -48,7 +46,6 @@ private[execution] sealed case class LazyIterator(func: () => TraversableOnce[In * @param output the output attributes of this node, which constructed in analysis phase, * and we can not change it, as the parent node bound with it already. */ -@DeveloperApi case class Generate( generator: Generator, join: Boolean, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala index adb6bbc4acc5b..ba7f6287ac6c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 7bb4133a29059..fc9174549e642 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -17,18 +17,15 @@ package org.apache.spark.sql.execution -import org.apache.spark.annotation.{Experimental, DeveloperApi} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.{InternalRow, optimizer} -import org.apache.spark.sql.{SQLContext, Row} +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan /** - * :: DeveloperApi :: * The primary workflow for executing relational queries using Spark. Designed to allow easy * access to the intermediate phases of query execution for developers. */ -@DeveloperApi class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) { val analyzer = sqlContext.analyzer val optimizer = sqlContext.optimizer diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala index 743c99a899c61..fb338b90bf79b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala @@ -21,7 +21,6 @@ import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.types.DataType private class ShuffledRowRDDPartition(val idx: Int) extends Partition { override val index: Int = idx diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index fcb42047ffe60..8bb293ae87e64 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -22,7 +22,6 @@ import java.util.concurrent.atomic.AtomicBoolean import scala.collection.mutable.ArrayBuffer import org.apache.spark.Logging -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.{RDD, RDDOperationScope} import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow @@ -32,7 +31,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetric} import org.apache.spark.sql.types.DataType object SparkPlan { @@ -40,9 +39,8 @@ object SparkPlan { } /** - * :: DeveloperApi :: + * The base class for physical operators. */ -@DeveloperApi abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index 55035f4bc5f2a..53c5ccf8fa37e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -17,19 +17,14 @@ package org.apache.spark.sql.execution -import java.util - -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.types.IntegerType import org.apache.spark.rdd.RDD import org.apache.spark.util.collection.CompactBuffer -import scala.collection.mutable /** - * :: DeveloperApi :: * This class calculates and outputs (windowed) aggregates over the rows in a single (sorted) * partition. The aggregates are calculated for each row in the group. Special processing * instructions, frames, are used to calculate these aggregates. Frames are processed in the order @@ -76,7 +71,6 @@ import scala.collection.mutable * Entire Partition, Sliding, Growing & Shrinking. Boundary evaluation is also delegated to a pair * of specialized classes: [[RowBoundOrdering]] & [[RangeBoundOrdering]]. */ -@DeveloperApi case class Window( projectList: Seq[Attribute], windowExpression: Seq[NamedExpression], @@ -229,7 +223,7 @@ case class Window( // function result buffer. val framedWindowExprs = windowExprs.groupBy(_.windowSpec.frameSpecification) val factories = Array.ofDim[() => WindowFunctionFrame](framedWindowExprs.size) - val unboundExpressions = mutable.Buffer.empty[Expression] + val unboundExpressions = scala.collection.mutable.Buffer.empty[Expression] framedWindowExprs.zipWithIndex.foreach { case ((frame, unboundFrameExpressions), index) => // Track the ordinal. @@ -529,7 +523,7 @@ private[execution] final class SlidingWindowFunctionFrame( private[this] var inputLowIndex = 0 /** Buffer used for storing prepared input for the window functions. */ - private[this] val buffer = new util.ArrayDeque[Array[AnyRef]] + private[this] val buffer = new java.util.ArrayDeque[Array[AnyRef]] /** Index of the row we are currently writing. */ private[this] var outputIndex = 0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 7804b67ac2367..4db9f4ee67bb0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.execution -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD, ShuffledRDD} import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.sql.catalyst.InternalRow @@ -28,10 +27,7 @@ import org.apache.spark.util.MutablePair import org.apache.spark.util.random.PoissonSampler import org.apache.spark.{HashPartitioner, SparkEnv} -/** - * :: DeveloperApi :: - */ -@DeveloperApi + case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { override def output: Seq[Attribute] = projectList.map(_.toAttribute) @@ -90,10 +86,6 @@ case class TungstenProject(projectList: Seq[NamedExpression], child: SparkPlan) } -/** - * :: DeveloperApi :: - */ -@DeveloperApi case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output @@ -125,8 +117,8 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { } /** - * :: DeveloperApi :: * Sample the dataset. + * * @param lowerBound Lower-bound of the sampling probability (usually 0.0) * @param upperBound Upper-bound of the sampling probability. The expected fraction sampled * will be ub - lb. @@ -134,7 +126,6 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { * @param seed the random seed * @param child the SparkPlan */ -@DeveloperApi case class Sample( lowerBound: Double, upperBound: Double, @@ -165,9 +156,8 @@ case class Sample( } /** - * :: DeveloperApi :: + * Union two plans, without a distinct. This is UNION ALL in SQL. */ -@DeveloperApi case class Union(children: Seq[SparkPlan]) extends SparkPlan { // TODO: attributes output by union should be distinct for nullability purposes override def output: Seq[Attribute] = children.head.output @@ -179,14 +169,12 @@ case class Union(children: Seq[SparkPlan]) extends SparkPlan { } /** - * :: DeveloperApi :: * Take the first limit elements. Note that the implementation is different depending on whether * this is a terminal operator or not. If it is terminal and is invoked using executeCollect, * this operator uses something similar to Spark's take method on the Spark driver. If it is not * terminal or is invoked using execute, we first take the limit on each partition, and then * repartition all the data to a single partition to compute the global limit. */ -@DeveloperApi case class Limit(limit: Int, child: SparkPlan) extends UnaryNode { // TODO: Implement a partition local limit, and use a strategy to generate the proper limit plan: @@ -219,14 +207,12 @@ case class Limit(limit: Int, child: SparkPlan) } /** - * :: DeveloperApi :: * Take the first limit elements as defined by the sortOrder, and do projection if needed. * This is logically equivalent to having a [[Limit]] operator after a [[Sort]] operator, * or having a [[Project]] operator between them. * This could have been named TopK, but Spark's top operator does the opposite in ordering * so we name it TakeOrdered to avoid confusion. */ -@DeveloperApi case class TakeOrderedAndProject( limit: Int, sortOrder: Seq[SortOrder], @@ -271,13 +257,11 @@ case class TakeOrderedAndProject( } /** - * :: DeveloperApi :: * Return a new RDD that has exactly `numPartitions` partitions. * Similar to coalesce defined on an [[RDD]], this operation results in a narrow dependency, e.g. * if you go from 1000 partitions to 100 partitions, there will not be a shuffle, instead each of * the 100 new partitions will claim 10 of the current partitions. */ -@DeveloperApi case class Coalesce(numPartitions: Int, child: SparkPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output @@ -294,11 +278,9 @@ case class Coalesce(numPartitions: Int, child: SparkPlan) extends UnaryNode { } /** - * :: DeveloperApi :: * Returns a table with the elements from left that are not in right using * the built-in spark subtract function. */ -@DeveloperApi case class Except(left: SparkPlan, right: SparkPlan) extends BinaryNode { override def output: Seq[Attribute] = left.output @@ -308,11 +290,9 @@ case class Except(left: SparkPlan, right: SparkPlan) extends BinaryNode { } /** - * :: DeveloperApi :: * Returns the rows in left that also appear in right using the built in spark * intersection function. */ -@DeveloperApi case class Intersect(left: SparkPlan, right: SparkPlan) extends BinaryNode { override def output: Seq[Attribute] = children.head.output @@ -322,12 +302,10 @@ case class Intersect(left: SparkPlan, right: SparkPlan) extends BinaryNode { } /** - * :: DeveloperApi :: * A plan node that does nothing but lie about the output of its child. Used to spice a * (hopefully structurally equivalent) tree from a different optimization sequence into an already * resolved tree. */ -@DeveloperApi case class OutputFaker(output: Seq[Attribute], child: SparkPlan) extends SparkPlan { def children: Seq[SparkPlan] = child :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 05ccc53830bd1..856607615ae87 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -20,11 +20,10 @@ package org.apache.spark.sql.execution import java.util.NoSuchElementException import org.apache.spark.Logging -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters} import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.expressions.{ExpressionDescription, Expression, Attribute, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.types._ @@ -74,10 +73,7 @@ private[sql] case class ExecutedCommand(cmd: RunnableCommand) extends SparkPlan override def argString: String = cmd.toString } -/** - * :: DeveloperApi :: - */ -@DeveloperApi + case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableCommand with Logging { private def keyValueOutput: Seq[Attribute] = { @@ -180,10 +176,7 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm * * Note that this command takes in a logical plan, runs the optimizer on the logical plan * (but do NOT actually execute it). - * - * :: DeveloperApi :: */ -@DeveloperApi case class ExplainCommand( logicalPlan: LogicalPlan, override val output: Seq[Attribute] = @@ -203,10 +196,7 @@ case class ExplainCommand( } } -/** - * :: DeveloperApi :: - */ -@DeveloperApi + case class CacheTableCommand( tableName: String, plan: Option[LogicalPlan], @@ -231,10 +221,6 @@ case class CacheTableCommand( } -/** - * :: DeveloperApi :: - */ -@DeveloperApi case class UncacheTableCommand(tableName: String) extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { @@ -246,10 +232,8 @@ case class UncacheTableCommand(tableName: String) extends RunnableCommand { } /** - * :: DeveloperApi :: * Clear all cached data from the in-memory cache. */ -@DeveloperApi case object ClearCacheCommand extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { @@ -260,10 +244,7 @@ case object ClearCacheCommand extends RunnableCommand { override def output: Seq[Attribute] = Seq.empty } -/** - * :: DeveloperApi :: - */ -@DeveloperApi + case class DescribeCommand( child: SparkPlan, override val output: Seq[Attribute], @@ -286,9 +267,7 @@ case class DescribeCommand( * {{{ * SHOW TABLES [IN databaseName] * }}} - * :: DeveloperApi :: */ -@DeveloperApi case class ShowTablesCommand(databaseName: Option[String]) extends RunnableCommand { // The result of SHOW TABLES has two columns, tableName and isTemporary. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 2e108cb814516..1d381e2eaef38 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.joins import scala.concurrent._ import scala.concurrent.duration._ -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression @@ -31,13 +30,11 @@ import org.apache.spark.util.ThreadUtils import org.apache.spark.{InternalAccumulator, TaskContext} /** - * :: DeveloperApi :: * Performs an inner hash join of two child relations. When the output RDD of this operator is * being constructed, a Spark job is asynchronously started to calculate the values for the * broadcasted relation. This data is then placed in a Spark broadcast variable. The streamed * relation is not shuffled. */ -@DeveloperApi case class BroadcastHashJoin( leftKeys: Seq[Expression], rightKeys: Seq[Expression], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala index 69a8b95eaa7ec..ab81bd7b3fc04 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.joins import scala.concurrent._ import scala.concurrent.duration._ -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -31,13 +30,11 @@ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.{InternalAccumulator, TaskContext} /** - * :: DeveloperApi :: * Performs a outer hash join for two child relations. When the output RDD of this operator is * being constructed, a Spark job is asynchronously started to calculate the values for the * broadcasted relation. This data is then placed in a Spark broadcast variable. The streamed * relation is not shuffled. */ -@DeveloperApi case class BroadcastHashOuterJoin( leftKeys: Seq[Expression], rightKeys: Seq[Expression], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala index 78a8c16c62bca..c5cd6a2fd6372 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.{InternalAccumulator, TaskContext} -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -26,11 +25,9 @@ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics /** - * :: DeveloperApi :: * Build the right table's join keys into a HashSet, and iteratively go through the left * table, to find the if join keys are in the Hash set. */ -@DeveloperApi case class BroadcastLeftSemiJoinHash( leftKeys: Seq[Expression], rightKeys: Seq[Expression], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index 28c88b1b03d02..efef8c8a8b96a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -27,10 +26,7 @@ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.util.collection.CompactBuffer -/** - * :: DeveloperApi :: - */ -@DeveloperApi + case class BroadcastNestedLoopJoin( left: SparkPlan, right: SparkPlan, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala index 2115f40702286..0243e196dbc37 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala @@ -17,17 +17,13 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, JoinedRow} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics -/** - * :: DeveloperApi :: - */ -@DeveloperApi + case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode { override def output: Seq[Attribute] = left.output ++ right.output diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index 66903347c88c1..15b06b1537f8c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -17,9 +17,6 @@ package org.apache.spark.sql.execution.joins -import java.util.{HashMap => JavaHashMap} - -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ @@ -27,7 +24,7 @@ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.metric.LongSQLMetric import org.apache.spark.util.collection.CompactBuffer -@DeveloperApi + trait HashOuterJoin { self: SparkPlan => @@ -230,8 +227,8 @@ trait HashOuterJoin { protected[this] def buildHashTable( iter: Iterator[InternalRow], numIterRows: LongSQLMetric, - keyGenerator: Projection): JavaHashMap[InternalRow, CompactBuffer[InternalRow]] = { - val hashTable = new JavaHashMap[InternalRow, CompactBuffer[InternalRow]]() + keyGenerator: Projection): java.util.HashMap[InternalRow, CompactBuffer[InternalRow]] = { + val hashTable = new java.util.HashMap[InternalRow, CompactBuffer[InternalRow]]() while (iter.hasNext) { val currentRow = iter.next() numIterRows += 1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala index ad6362542f2ff..efa7b49410edc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -26,11 +25,9 @@ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics /** - * :: DeveloperApi :: * Using BroadcastNestedLoopJoin to calculate left semi join result when there's no join keys * for hash join. */ -@DeveloperApi case class LeftSemiJoinBNL( streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression]) extends BinaryNode { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala index 18808adaac63f..bf3b05be981fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -26,11 +25,9 @@ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics /** - * :: DeveloperApi :: * Build the right table's join keys into a HashSet, and iteratively go through the left * table, to find the if join keys are in the Hash set. */ -@DeveloperApi case class LeftSemiJoinHash( leftKeys: Seq[Expression], rightKeys: Seq[Expression], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index fc8c9439a6f07..755986af8b95e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression @@ -26,11 +25,9 @@ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics /** - * :: DeveloperApi :: * Performs an inner hash join of two child relations by first shuffling the data using the join * keys. */ -@DeveloperApi case class ShuffledHashJoin( leftKeys: Seq[Expression], rightKeys: Seq[Expression], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala index d800c7456bdac..6b2cb9d8f6893 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.joins import scala.collection.JavaConverters._ -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -29,11 +28,9 @@ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics /** - * :: DeveloperApi :: * Performs a hash based outer join for two child relations by shuffling the data using * the join keys. This operator requires loading the associated partition in both side into memory. */ -@DeveloperApi case class ShuffledHashOuterJoin( leftKeys: Seq[Expression], rightKeys: Seq[Expression], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 70a1af6a7063a..17030947b7bbc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.joins import scala.collection.mutable.ArrayBuffer -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -28,10 +27,8 @@ import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan} import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics} /** - * :: DeveloperApi :: * Performs an sort merge join of two child relations. */ -@DeveloperApi case class SortMergeJoin( leftKeys: Seq[Expression], rightKeys: Seq[Expression], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala index c117dff9c8b1d..7e854e6702f77 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.joins import scala.collection.mutable.ArrayBuffer -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -30,10 +29,8 @@ import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan} import org.apache.spark.util.collection.BitSet /** - * :: DeveloperApi :: * Performs an sort merge outer join of two child relations. */ -@DeveloperApi case class SortMergeOuterJoin( leftKeys: Seq[Expression], rightKeys: Seq[Expression], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/package.scala index 7f2ab1765b28f..134376628ae7f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/package.scala @@ -17,21 +17,15 @@ package org.apache.spark.sql.execution -import org.apache.spark.annotation.DeveloperApi - /** - * :: DeveloperApi :: * Physical execution operators for join operations. */ package object joins { - @DeveloperApi sealed abstract class BuildSide - @DeveloperApi case object BuildRight extends BuildSide - @DeveloperApi case object BuildLeft extends BuildSide } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala index 5dbe0fc5f95c7..d4e6980967e82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala @@ -24,12 +24,11 @@ import scala.collection.JavaConverters._ import net.razorvine.pickle._ -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.python.{PythonRunner, PythonBroadcast, PythonRDD, SerDeUtil} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -320,10 +319,8 @@ object EvaluatePython { } /** - * :: DeveloperApi :: * Evaluates a [[PythonUDF]], appending the result to the end of the input tuple. */ -@DeveloperApi case class EvaluatePython( udf: PythonUDF, child: LogicalPlan, @@ -337,7 +334,6 @@ case class EvaluatePython( } /** - * :: DeveloperApi :: * Uses PythonRDD to evaluate a [[PythonUDF]], one partition of tuples at a time. * * Python evaluation works by sending the necessary (projected) input data via a socket to an @@ -347,7 +343,6 @@ case class EvaluatePython( * we drain the queue to find the original input row. Note that if the Python process is way too * slow, this could lead to the queue growing unbounded and eventually run out of memory. */ -@DeveloperApi case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan) extends SparkPlan { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala index 855555dd1d4c4..0e601cd2cab5d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.execution -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -25,10 +24,8 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule /** - * :: DeveloperApi :: * Converts Java-object-based rows into [[UnsafeRow]]s. */ -@DeveloperApi case class ConvertToUnsafe(child: SparkPlan) extends UnaryNode { require(UnsafeProjection.canSupport(child.schema), s"Cannot convert ${child.schema} to Unsafe") @@ -48,10 +45,8 @@ case class ConvertToUnsafe(child: SparkPlan) extends UnaryNode { } /** - * :: DeveloperApi :: * Converts [[UnsafeRow]]s back into Java-object-based rows. */ -@DeveloperApi case class ConvertToSafe(child: SparkPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = child.outputPartitioning diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala index 963e6030c14c8..a741a45f1c527 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala @@ -17,9 +17,6 @@ package org.apache.spark.sql.test -import java.util - -import scala.collection.JavaConverters._ import org.apache.spark.sql.types._ /** From 1baaf2b9bd7c949a8f95cd14fc1be2a56e1139b3 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Wed, 14 Oct 2015 16:29:32 -0700 Subject: [PATCH 058/139] [SPARK-10829] [SQL] Filter combine partition key and attribute doesn't work in DataSource scan ```scala withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { withTempPath { dir => val path = s"${dir.getCanonicalPath}/part=1" (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path) // If the "part = 1" filter gets pushed down, this query will throw an exception since // "part" is not a valid column in the actual Parquet file checkAnswer( sqlContext.read.parquet(path).filter("a > 0 and (part = 0 or a > 1)"), (2 to 3).map(i => Row(i, i.toString, 1))) } } ``` We expect the result to be: ``` 2,1 3,1 ``` But got ``` 1,1 2,1 3,1 ``` Author: Cheng Hao Closes #8916 from chenghao-intel/partition_filter. --- .../datasources/DataSourceStrategy.scala | 34 ++++++++++++------- .../parquet/ParquetFilterSuite.scala | 17 ++++++++++ 2 files changed, 39 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 918db8e7d083e..33181fa6c065f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -62,7 +62,22 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { // Scanning partitioned HadoopFsRelation case PhysicalOperation(projects, filters, l @ LogicalRelation(t: HadoopFsRelation, _)) if t.partitionSpec.partitionColumns.nonEmpty => - val selectedPartitions = prunePartitions(filters, t.partitionSpec).toArray + // We divide the filter expressions into 3 parts + val partitionColumnNames = t.partitionSpec.partitionColumns.map(_.name).toSet + + // TODO this is case-sensitive + // Only prunning the partition keys + val partitionFilters = + filters.filter(_.references.map(_.name).toSet.subsetOf(partitionColumnNames)) + + // Only pushes down predicates that do not reference partition keys. + val pushedFilters = + filters.filter(_.references.map(_.name).toSet.intersect(partitionColumnNames).isEmpty) + + // Predicates with both partition keys and attributes + val combineFilters = filters.toSet -- partitionFilters.toSet -- pushedFilters.toSet + + val selectedPartitions = prunePartitions(partitionFilters, t.partitionSpec).toArray logInfo { val total = t.partitionSpec.partitions.length @@ -71,21 +86,16 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { s"Selected $selected partitions out of $total, pruned $percentPruned% partitions." } - // Only pushes down predicates that do not reference partition columns. - val pushedFilters = { - val partitionColumnNames = t.partitionSpec.partitionColumns.map(_.name).toSet - filters.filter { f => - val referencedColumnNames = f.references.map(_.name).toSet - referencedColumnNames.intersect(partitionColumnNames).isEmpty - } - } - - buildPartitionedTableScan( + val scan = buildPartitionedTableScan( l, projects, pushedFilters, t.partitionSpec.partitionColumns, - selectedPartitions) :: Nil + selectedPartitions) + + combineFilters + .reduceLeftOption(expressions.And) + .map(execution.Filter(_, scan)).getOrElse(scan) :: Nil // Scanning non-partitioned HadoopFsRelation case PhysicalOperation(projects, filters, l @ LogicalRelation(t: HadoopFsRelation, _)) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 45ad3fde559c0..7a23f57f40392 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -297,4 +297,21 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } } + + test("SPARK-10829: Filter combine partition key and attribute doesn't work in DataSource scan") { + import testImplicits._ + + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { + withTempPath { dir => + val path = s"${dir.getCanonicalPath}/part=1" + (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path) + + // If the "part = 1" filter gets pushed down, this query will throw an exception since + // "part" is not a valid column in the actual Parquet file + checkAnswer( + sqlContext.read.parquet(path).filter("a > 0 and (part = 0 or a > 1)"), + (2 to 3).map(i => Row(i, i.toString, 1))) + } + } + } } From 4ace4f8a9c91beb21a0077e12b75637a4560a542 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 14 Oct 2015 17:27:27 -0700 Subject: [PATCH 059/139] [SPARK-11017] [SQL] Support ImperativeAggregates in TungstenAggregate This patch extends TungstenAggregate to support ImperativeAggregate functions. The existing TungstenAggregate operator only supported DeclarativeAggregate functions, which are defined in terms of Catalyst expressions and can be evaluated via generated projections. ImperativeAggregate functions, on the other hand, are evaluated by calling their `initialize`, `update`, `merge`, and `eval` methods. The basic strategy here is similar to how SortBasedAggregate evaluates both types of aggregate functions: use a generated projection to evaluate the expression-based declarative aggregates with dummy placeholder expressions inserted in place of the imperative aggregate function output, then invoke the imperative aggregate functions and target them against the aggregation buffer. The bulk of the diff here consists of code that was copied and adapted from SortBasedAggregate, with some key changes to handle TungstenAggregate's sort fallback path. Author: Josh Rosen Closes #9038 from JoshRosen/support-interpreted-in-tungsten-agg-final. --- .../expressions/aggregate/functions.scala | 19 +- .../expressions/aggregate/interfaces.scala | 31 +- .../aggregate/AggregationIterator.scala | 29 +- .../aggregate/TungstenAggregate.scala | 22 +- .../TungstenAggregationIterator.scala | 250 ++++++++++++---- .../spark/sql/execution/aggregate/udaf.scala | 79 +++-- .../spark/sql/execution/aggregate/utils.scala | 269 +++++++++--------- .../TungstenAggregationIteratorSuite.scala | 2 +- .../org/apache/spark/sql/hive/hiveUDFs.scala | 16 +- 9 files changed, 457 insertions(+), 260 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index 8aad0b7dee054..c0bc7ec09c34a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -472,10 +472,20 @@ case class Sum(child: Expression) extends DeclarativeAggregate { * @param relativeSD the maximum estimation error allowed. */ // scalastyle:on -case class HyperLogLogPlusPlus(child: Expression, relativeSD: Double = 0.05) - extends ImperativeAggregate { +case class HyperLogLogPlusPlus( + child: Expression, + relativeSD: Double = 0.05, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) + extends ImperativeAggregate { import HyperLogLogPlusPlus._ + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + /** * HLL++ uses 'p' bits for addressing. The more addressing bits we use, the more precise the * algorithm will be, and the more memory it will require. The 'p' value is based on the relative @@ -546,6 +556,11 @@ case class HyperLogLogPlusPlus(child: Expression, relativeSD: Double = 0.05) AttributeReference(s"MS[$i]", LongType)() } + // Note: although this simply copies aggBufferAttributes, this common code can not be placed + // in the superclass because that will lead to initialization ordering issues. + override val inputAggBufferAttributes: Seq[AttributeReference] = + aggBufferAttributes.map(_.newInstance()) + /** Fill all words with zeros. */ override def initialize(buffer: MutableRow): Unit = { var word = 0 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index 9ba3a9c980457..a2fab258fcac3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -150,6 +150,10 @@ sealed abstract class AggregateFunction2 extends Expression with ImplicitCastInp * We need to perform similar field number arithmetic when merging multiple intermediate * aggregate buffers together in `merge()` (in this case, use `inputAggBufferOffset` when accessing * the input buffer). + * + * Correct ImperativeAggregate evaluation depends on the correctness of `mutableAggBufferOffset` and + * `inputAggBufferOffset`, but not on the correctness of the attribute ids in `aggBufferAttributes` + * and `inputAggBufferAttributes`. */ abstract class ImperativeAggregate extends AggregateFunction2 { @@ -172,11 +176,13 @@ abstract class ImperativeAggregate extends AggregateFunction2 { * avg(y) mutableAggBufferOffset = 2 * */ - protected var mutableAggBufferOffset: Int = 0 + protected val mutableAggBufferOffset: Int - def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): Unit = { - mutableAggBufferOffset = newMutableAggBufferOffset - } + /** + * Returns a copy of this ImperativeAggregate with an updated mutableAggBufferOffset. + * This new copy's attributes may have different ids than the original. + */ + def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate /** * The offset of this function's start buffer value in the underlying shared input aggregation @@ -203,11 +209,17 @@ abstract class ImperativeAggregate extends AggregateFunction2 { * avg(y) inputAggBufferOffset = 3 * */ - protected var inputAggBufferOffset: Int = 0 + protected val inputAggBufferOffset: Int - def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): Unit = { - inputAggBufferOffset = newInputAggBufferOffset - } + /** + * Returns a copy of this ImperativeAggregate with an updated mutableAggBufferOffset. + * This new copy's attributes may have different ids than the original. + */ + def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate + + // Note: although all subclasses implement inputAggBufferAttributes by simply cloning + // aggBufferAttributes, that common clone code cannot be placed here in the abstract + // ImperativeAggregate class, since that will lead to initialization ordering issues. /** * Initializes the mutable aggregation buffer located in `mutableAggBuffer`. @@ -231,9 +243,6 @@ abstract class ImperativeAggregate extends AggregateFunction2 { * Use `fieldNumber + inputAggBufferOffset` to access fields of `inputAggBuffer`. */ def merge(mutableAggBuffer: MutableRow, inputAggBuffer: InternalRow): Unit - - final lazy val inputAggBufferAttributes: Seq[AttributeReference] = - aggBufferAttributes.map(_.newInstance()) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala index 8e0fbd109b413..99fb7a40b72e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala @@ -83,7 +83,7 @@ abstract class AggregationIterator( var i = 0 while (i < allAggregateExpressions.length) { val func = allAggregateExpressions(i).aggregateFunction - val funcWithBoundReferences = allAggregateExpressions(i).mode match { + val funcWithBoundReferences: AggregateFunction2 = allAggregateExpressions(i).mode match { case Partial | Complete if func.isInstanceOf[ImperativeAggregate] => // We need to create BoundReferences if the function is not an // expression-based aggregate function (it does not support code-gen) and the mode of @@ -94,24 +94,24 @@ abstract class AggregationIterator( case _ => // We only need to set inputBufferOffset for aggregate functions with mode // PartialMerge and Final. - func match { + val updatedFunc = func match { case function: ImperativeAggregate => function.withNewInputAggBufferOffset(inputBufferOffset) - case _ => + case function => function } inputBufferOffset += func.aggBufferSchema.length - func + updatedFunc } - // Set mutableBufferOffset for this function. It is important that setting - // mutableBufferOffset happens after all potential bindReference operations - // because bindReference will create a new instance of the function. - funcWithBoundReferences match { + val funcWithUpdatedAggBufferOffset = funcWithBoundReferences match { case function: ImperativeAggregate => + // Set mutableBufferOffset for this function. It is important that setting + // mutableBufferOffset happens after all potential bindReference operations + // because bindReference will create a new instance of the function. function.withNewMutableAggBufferOffset(mutableBufferOffset) - case _ => + case function => function } - mutableBufferOffset += funcWithBoundReferences.aggBufferSchema.length - functions(i) = funcWithBoundReferences + mutableBufferOffset += funcWithUpdatedAggBufferOffset.aggBufferSchema.length + functions(i) = funcWithUpdatedAggBufferOffset i += 1 } functions @@ -320,7 +320,7 @@ abstract class AggregationIterator( // Initializing the function used to generate the output row. protected val generateOutput: (InternalRow, MutableRow) => InternalRow = { val rowToBeEvaluated = new JoinedRow - val safeOutputRow = new GenericMutableRow(resultExpressions.length) + val safeOutputRow = new SpecificMutableRow(resultExpressions.map(_.dataType)) val mutableOutput = if (outputsUnsafeRows) { UnsafeProjection.create(resultExpressions.map(_.dataType).toArray).apply(safeOutputRow) } else { @@ -358,7 +358,8 @@ abstract class AggregationIterator( val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferSchemata)() val aggregateResultSchema = nonCompleteAggregateAttributes ++ completeAggregateAttributes // TODO: Use unsafe row. - val aggregateResult = new GenericMutableRow(aggregateResultSchema.length) + val aggregateResult = new SpecificMutableRow(aggregateResultSchema.map(_.dataType)) + expressionAggEvalProjection.target(aggregateResult) val resultProjection = newMutableProjection( resultExpressions, groupingKeyAttributes ++ aggregateResultSchema)() @@ -366,7 +367,7 @@ abstract class AggregationIterator( (currentGroupingKey: InternalRow, currentBuffer: MutableRow) => { // Generate results for all expression-based aggregate functions. - expressionAggEvalProjection.target(aggregateResult)(currentBuffer) + expressionAggEvalProjection(currentBuffer) // Generate results for all imperative aggregate functions. var i = 0 while (i < allImperativeAggregateFunctions.length) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 7b3d072b2e067..c342940e6e757 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -24,8 +24,9 @@ import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2 import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{UnaryNode, SparkPlan} +import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, UnaryNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.types.StructType case class TungstenAggregate( requiredChildDistributionExpressions: Option[Seq[Expression]], @@ -34,10 +35,18 @@ case class TungstenAggregate( nonCompleteAggregateAttributes: Seq[Attribute], completeAggregateExpressions: Seq[AggregateExpression2], completeAggregateAttributes: Seq[Attribute], + initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { + private[this] val aggregateBufferAttributes = { + (nonCompleteAggregateExpressions ++ completeAggregateExpressions) + .flatMap(_.aggregateFunction.aggBufferAttributes) + } + + require(TungstenAggregate.supportsAggregate(groupingExpressions, aggregateBufferAttributes)) + override private[sql] lazy val metrics = Map( "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) @@ -82,6 +91,7 @@ case class TungstenAggregate( nonCompleteAggregateAttributes, completeAggregateExpressions, completeAggregateAttributes, + initialInputBufferOffset, resultExpressions, newMutableProjection, child.output, @@ -138,3 +148,13 @@ case class TungstenAggregate( } } } + +object TungstenAggregate { + def supportsAggregate( + groupingExpressions: Seq[Expression], + aggregateBufferAttributes: Seq[Attribute]): Boolean = { + val aggregationBufferSchema = StructType.fromAttributes(aggregateBufferAttributes) + UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) && + UnsafeProjection.canSupport(groupingExpressions) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 4bb95c9eb7f3e..fe708a5f71f79 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.aggregate +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.unsafe.KVIterator import org.apache.spark.{InternalAccumulator, Logging, SparkEnv, TaskContext} import org.apache.spark.sql.catalyst.expressions._ @@ -79,6 +81,7 @@ class TungstenAggregationIterator( nonCompleteAggregateAttributes: Seq[Attribute], completeAggregateExpressions: Seq[AggregateExpression2], completeAggregateAttributes: Seq[Attribute], + initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), originalInputAttributes: Seq[Attribute], @@ -134,19 +137,74 @@ class TungstenAggregationIterator( completeAggregateExpressions.map(_.mode).distinct.headOption } - // All aggregate functions. TungstenAggregationIterator only handles expression-based aggregate. - // If there is any functions that is an ImperativeAggregateFunction, we throw an - // IllegalStateException. - private[this] val allAggregateFunctions: Array[DeclarativeAggregate] = { - if (!allAggregateExpressions.forall( - _.aggregateFunction.isInstanceOf[DeclarativeAggregate])) { - throw new IllegalStateException( - "Only ExpressionAggregateFunctions should be passed in TungstenAggregationIterator.") + // Initialize all AggregateFunctions by binding references, if necessary, + // and setting inputBufferOffset and mutableBufferOffset. + private def initializeAllAggregateFunctions( + startingInputBufferOffset: Int): Array[AggregateFunction2] = { + var mutableBufferOffset = 0 + var inputBufferOffset: Int = startingInputBufferOffset + val functions = new Array[AggregateFunction2](allAggregateExpressions.length) + var i = 0 + while (i < allAggregateExpressions.length) { + val func = allAggregateExpressions(i).aggregateFunction + val aggregateExpressionIsNonComplete = i < nonCompleteAggregateExpressions.length + // We need to use this mode instead of func.mode in order to handle aggregation mode switching + // when switching to sort-based aggregation: + val mode = if (aggregateExpressionIsNonComplete) aggregationMode._1 else aggregationMode._2 + val funcWithBoundReferences = mode match { + case Some(Partial) | Some(Complete) if func.isInstanceOf[ImperativeAggregate] => + // We need to create BoundReferences if the function is not an + // expression-based aggregate function (it does not support code-gen) and the mode of + // this function is Partial or Complete because we will call eval of this + // function's children in the update method of this aggregate function. + // Those eval calls require BoundReferences to work. + BindReferences.bindReference(func, originalInputAttributes) + case _ => + // We only need to set inputBufferOffset for aggregate functions with mode + // PartialMerge and Final. + val updatedFunc = func match { + case function: ImperativeAggregate => + function.withNewInputAggBufferOffset(inputBufferOffset) + case function => function + } + inputBufferOffset += func.aggBufferSchema.length + updatedFunc + } + val funcWithUpdatedAggBufferOffset = funcWithBoundReferences match { + case function: ImperativeAggregate => + // Set mutableBufferOffset for this function. It is important that setting + // mutableBufferOffset happens after all potential bindReference operations + // because bindReference will create a new instance of the function. + function.withNewMutableAggBufferOffset(mutableBufferOffset) + case function => function + } + mutableBufferOffset += funcWithUpdatedAggBufferOffset.aggBufferSchema.length + functions(i) = funcWithUpdatedAggBufferOffset + i += 1 } + functions + } - allAggregateExpressions - .map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) - .toArray + private[this] var allAggregateFunctions: Array[AggregateFunction2] = + initializeAllAggregateFunctions(initialInputBufferOffset) + + // Positions of those imperative aggregate functions in allAggregateFunctions. + // For example, say that we have func1, func2, func3, func4 in aggregateFunctions, and + // func2 and func3 are imperative aggregate functions. Then + // allImperativeAggregateFunctionPositions will be [1, 2]. Note that this does not need to be + // updated when falling back to sort-based aggregation because the positions of the aggregate + // functions do not change in that case. + private[this] val allImperativeAggregateFunctionPositions: Array[Int] = { + val positions = new ArrayBuffer[Int]() + var i = 0 + while (i < allAggregateFunctions.length) { + allAggregateFunctions(i) match { + case agg: DeclarativeAggregate => + case _ => positions += i + } + i += 1 + } + positions.toArray } /////////////////////////////////////////////////////////////////////////// @@ -155,25 +213,31 @@ class TungstenAggregationIterator( // rows. /////////////////////////////////////////////////////////////////////////// - // The projection used to initialize buffer values. - private[this] val initialProjection: MutableProjection = { - val initExpressions = allAggregateFunctions.flatMap(_.initialValues) + // The projection used to initialize buffer values for all expression-based aggregates. + // Note that this projection does not need to be updated when switching to sort-based aggregation + // because the schema of empty aggregation buffers does not change in that case. + private[this] val expressionAggInitialProjection: MutableProjection = { + val initExpressions = allAggregateFunctions.flatMap { + case ae: DeclarativeAggregate => ae.initialValues + // For the positions corresponding to imperative aggregate functions, we'll use special + // no-op expressions which are ignored during projection code-generation. + case i: ImperativeAggregate => Seq.fill(i.aggBufferAttributes.length)(NoOp) + } newMutableProjection(initExpressions, Nil)() } // Creates a new aggregation buffer and initializes buffer values. - // This functions should be only called at most three times (when we create the hash map, + // This function should be only called at most three times (when we create the hash map, // when we switch to sort-based aggregation, and when we create the re-used buffer for // sort-based aggregation). private def createNewAggregationBuffer(): UnsafeRow = { val bufferSchema = allAggregateFunctions.flatMap(_.aggBufferAttributes) - val bufferRowSize: Int = bufferSchema.length - - val genericMutableBuffer = new GenericMutableRow(bufferRowSize) - val unsafeProjection = - UnsafeProjection.create(bufferSchema.map(_.dataType)) - val buffer = unsafeProjection.apply(genericMutableBuffer) - initialProjection.target(buffer)(EmptyRow) + val buffer: UnsafeRow = UnsafeProjection.create(bufferSchema.map(_.dataType)) + .apply(new GenericMutableRow(bufferSchema.length)) + // Initialize declarative aggregates' buffer values + expressionAggInitialProjection.target(buffer)(EmptyRow) + // Initialize imperative aggregates' buffer values + allAggregateFunctions.collect { case f: ImperativeAggregate => f }.foreach(_.initialize(buffer)) buffer } @@ -187,72 +251,124 @@ class TungstenAggregationIterator( aggregationMode match { // Partial-only case (Some(Partial), None) => - val updateExpressions = allAggregateFunctions.flatMap(_.updateExpressions) - val updateProjection = + val updateExpressions = allAggregateFunctions.flatMap { + case ae: DeclarativeAggregate => ae.updateExpressions + case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + } + val imperativeAggregateFunctions: Array[ImperativeAggregate] = + allAggregateFunctions.collect { case func: ImperativeAggregate => func} + val expressionAggUpdateProjection = newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() (currentBuffer: UnsafeRow, row: InternalRow) => { - updateProjection.target(currentBuffer) - updateProjection(joinedRow(currentBuffer, row)) + expressionAggUpdateProjection.target(currentBuffer) + // Process all expression-based aggregate functions. + expressionAggUpdateProjection(joinedRow(currentBuffer, row)) + // Process all imperative aggregate functions + var i = 0 + while (i < imperativeAggregateFunctions.length) { + imperativeAggregateFunctions(i).update(currentBuffer, row) + i += 1 + } } // PartialMerge-only or Final-only case (Some(PartialMerge), None) | (Some(Final), None) => - val mergeExpressions = allAggregateFunctions.flatMap(_.mergeExpressions) - val mergeProjection = + val mergeExpressions = allAggregateFunctions.flatMap { + case ae: DeclarativeAggregate => ae.mergeExpressions + case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + } + val imperativeAggregateFunctions: Array[ImperativeAggregate] = + allAggregateFunctions.collect { case func: ImperativeAggregate => func} + // This projection is used to merge buffer values for all expression-based aggregates. + val expressionAggMergeProjection = newMutableProjection(mergeExpressions, aggregationBufferAttributes ++ inputAttributes)() (currentBuffer: UnsafeRow, row: InternalRow) => { - mergeProjection.target(currentBuffer) - mergeProjection(joinedRow(currentBuffer, row)) + // Process all expression-based aggregate functions. + expressionAggMergeProjection.target(currentBuffer)(joinedRow(currentBuffer, row)) + // Process all imperative aggregate functions. + var i = 0 + while (i < imperativeAggregateFunctions.length) { + imperativeAggregateFunctions(i).merge(currentBuffer, row) + i += 1 + } } // Final-Complete case (Some(Final), Some(Complete)) => - val nonCompleteAggregateFunctions: Array[DeclarativeAggregate] = - allAggregateFunctions.take(nonCompleteAggregateExpressions.length) - val completeAggregateFunctions: Array[DeclarativeAggregate] = + val completeAggregateFunctions: Array[AggregateFunction2] = allAggregateFunctions.takeRight(completeAggregateExpressions.length) + val completeImperativeAggregateFunctions: Array[ImperativeAggregate] = + completeAggregateFunctions.collect { case func: ImperativeAggregate => func } + val nonCompleteAggregateFunctions: Array[AggregateFunction2] = + allAggregateFunctions.take(nonCompleteAggregateExpressions.length) + val nonCompleteImperativeAggregateFunctions: Array[ImperativeAggregate] = + nonCompleteAggregateFunctions.collect { case func: ImperativeAggregate => func } val completeOffsetExpressions = Seq.fill(completeAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp) val mergeExpressions = - nonCompleteAggregateFunctions.flatMap(_.mergeExpressions) ++ completeOffsetExpressions + nonCompleteAggregateFunctions.flatMap { + case ae: DeclarativeAggregate => ae.mergeExpressions + case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + } ++ completeOffsetExpressions val finalMergeProjection = newMutableProjection(mergeExpressions, aggregationBufferAttributes ++ inputAttributes)() // We do not touch buffer values of aggregate functions with the Final mode. val finalOffsetExpressions = Seq.fill(nonCompleteAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp) - val updateExpressions = - finalOffsetExpressions ++ completeAggregateFunctions.flatMap(_.updateExpressions) + val updateExpressions = finalOffsetExpressions ++ completeAggregateFunctions.flatMap { + case ae: DeclarativeAggregate => ae.updateExpressions + case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + } val completeUpdateProjection = newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() (currentBuffer: UnsafeRow, row: InternalRow) => { val input = joinedRow(currentBuffer, row) - // For all aggregate functions with mode Complete, update the given currentBuffer. + // For all aggregate functions with mode Complete, update buffers. completeUpdateProjection.target(currentBuffer)(input) + var i = 0 + while (i < completeImperativeAggregateFunctions.length) { + completeImperativeAggregateFunctions(i).update(currentBuffer, row) + i += 1 + } // For all aggregate functions with mode Final, merge buffer values in row to // currentBuffer. finalMergeProjection.target(currentBuffer)(input) + i = 0 + while (i < nonCompleteImperativeAggregateFunctions.length) { + nonCompleteImperativeAggregateFunctions(i).merge(currentBuffer, row) + i += 1 + } } // Complete-only case (None, Some(Complete)) => - val completeAggregateFunctions: Array[DeclarativeAggregate] = + val completeAggregateFunctions: Array[AggregateFunction2] = allAggregateFunctions.takeRight(completeAggregateExpressions.length) + // All imperative aggregate functions with mode Complete. + val completeImperativeAggregateFunctions: Array[ImperativeAggregate] = + completeAggregateFunctions.collect { case func: ImperativeAggregate => func } - val updateExpressions = - completeAggregateFunctions.flatMap(_.updateExpressions) - val completeUpdateProjection = + val updateExpressions = completeAggregateFunctions.flatMap { + case ae: DeclarativeAggregate => ae.updateExpressions + case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + } + val completeExpressionAggUpdateProjection = newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() (currentBuffer: UnsafeRow, row: InternalRow) => { - completeUpdateProjection.target(currentBuffer) - // For all aggregate functions with mode Complete, update the given currentBuffer. - completeUpdateProjection(joinedRow(currentBuffer, row)) + // For all aggregate functions with mode Complete, update buffers. + completeExpressionAggUpdateProjection.target(currentBuffer)(joinedRow(currentBuffer, row)) + var i = 0 + while (i < completeImperativeAggregateFunctions.length) { + completeImperativeAggregateFunctions(i).update(currentBuffer, row) + i += 1 + } } // Grouping only. @@ -288,17 +404,30 @@ class TungstenAggregationIterator( val joinedRow = new JoinedRow() val evalExpressions = allAggregateFunctions.map { case ae: DeclarativeAggregate => ae.evaluateExpression - // case agg: AggregateFunction2 => Literal.create(null, agg.dataType) + case agg: AggregateFunction2 => NoOp } - val expressionAggEvalProjection = UnsafeProjection.create(evalExpressions, bufferAttributes) + val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferAttributes)() // These are the attributes of the row produced by `expressionAggEvalProjection` val aggregateResultSchema = nonCompleteAggregateAttributes ++ completeAggregateAttributes + val aggregateResult = new SpecificMutableRow(aggregateResultSchema.map(_.dataType)) + expressionAggEvalProjection.target(aggregateResult) val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes ++ aggregateResultSchema) + val allImperativeAggregateFunctions: Array[ImperativeAggregate] = + allAggregateFunctions.collect { case func: ImperativeAggregate => func} + (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => { // Generate results for all expression-based aggregate functions. - val aggregateResult = expressionAggEvalProjection.apply(currentBuffer) + expressionAggEvalProjection(currentBuffer) + // Generate results for all imperative aggregate functions. + var i = 0 + while (i < allImperativeAggregateFunctions.length) { + aggregateResult.update( + allImperativeAggregateFunctionPositions(i), + allImperativeAggregateFunctions(i).eval(currentBuffer)) + i += 1 + } resultProjection(joinedRow(currentGroupingKey, aggregateResult)) } @@ -481,10 +610,27 @@ class TungstenAggregationIterator( // When needsProcess is false, the format of input rows is groupingKey + aggregation buffer. // We need to project the aggregation buffer part from an input row. val buffer = createNewAggregationBuffer() - // The originalInputAttributes are using cloneBufferAttributes. So, we need to use - // allAggregateFunctions.flatMap(_.cloneBufferAttributes). + // In principle, we could use `allAggregateFunctions.flatMap(_.inputAggBufferAttributes)` to + // extract the aggregation buffer. In practice, however, we extract it positionally by relying + // on it being present at the end of the row. The reason for this relates to how the different + // aggregates handle input binding. + // + // ImperativeAggregate uses field numbers and field number offsets to manipulate its buffers, + // so its correctness does not rely on attribute bindings. When we fall back to sort-based + // aggregation, these field number offsets (mutableAggBufferOffset and inputAggBufferOffset) + // need to be updated and any internal state in the aggregate functions themselves must be + // reset, so we call withNewMutableAggBufferOffset and withNewInputAggBufferOffset to reset + // this state and update the offsets. + // + // The updated ImperativeAggregate will have different attribute ids for its + // aggBufferAttributes and inputAggBufferAttributes. This isn't a problem for the actual + // ImperativeAggregate evaluation, but it means that + // `allAggregateFunctions.flatMap(_.inputAggBufferAttributes)` will no longer match the + // attributes in `originalInputAttributes`, which is why we can't use those attributes here. + // + // For more details, see the discussion on PR #9038. val bufferExtractor = newMutableProjection( - allAggregateFunctions.flatMap(_.inputAggBufferAttributes), + originalInputAttributes.drop(initialInputBufferOffset), originalInputAttributes)() bufferExtractor.target(buffer) @@ -511,8 +657,10 @@ class TungstenAggregationIterator( } aggregationMode = newAggregationMode + allAggregateFunctions = initializeAllAggregateFunctions(startingInputBufferOffset = 0) + // Basically the value of the KVIterator returned by externalSorter - // will just aggregation buffer. At here, we use cloneBufferAttributes. + // will just aggregation buffer. At here, we use inputAggBufferAttributes. val newInputAttributes: Seq[Attribute] = allAggregateFunctions.flatMap(_.inputAggBufferAttributes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index fd02be1225f27..d2f56e0fc14a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -321,9 +321,17 @@ private[sql] class InputAggregationBuffer private[sql] ( */ private[sql] case class ScalaUDAF( children: Seq[Expression], - udaf: UserDefinedAggregateFunction) + udaf: UserDefinedAggregateFunction, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) extends ImperativeAggregate with Logging { + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + require( children.length == udaf.inputSchema.length, s"$udaf only accepts ${udaf.inputSchema.length} arguments, " + @@ -341,6 +349,11 @@ private[sql] case class ScalaUDAF( override val aggBufferAttributes: Seq[AttributeReference] = aggBufferSchema.toAttributes + // Note: although this simply copies aggBufferAttributes, this common code can not be placed + // in the superclass because that will lead to initialization ordering issues. + override val inputAggBufferAttributes: Seq[AttributeReference] = + aggBufferAttributes.map(_.newInstance()) + private[this] lazy val childrenSchema: StructType = { val inputFields = children.zipWithIndex.map { case (child, index) => @@ -382,51 +395,33 @@ private[sql] case class ScalaUDAF( } // This buffer is only used at executor side. - private[this] var inputAggregateBuffer: InputAggregationBuffer = null - - // This buffer is only used at executor side. - private[this] var mutableAggregateBuffer: MutableAggregationBufferImpl = null + private[this] lazy val inputAggregateBuffer: InputAggregationBuffer = { + new InputAggregationBuffer( + aggBufferSchema, + bufferValuesToCatalystConverters, + bufferValuesToScalaConverters, + inputAggBufferOffset, + null) + } // This buffer is only used at executor side. - private[this] var evalAggregateBuffer: InputAggregationBuffer = null - - /** - * Sets the inputBufferOffset to newInputBufferOffset and then create a new instance of - * `inputAggregateBuffer` based on this new inputBufferOffset. - */ - override def withNewInputAggBufferOffset(newInputBufferOffset: Int): Unit = { - super.withNewInputAggBufferOffset(newInputBufferOffset) - // inputBufferOffset has been updated. - inputAggregateBuffer = - new InputAggregationBuffer( - aggBufferSchema, - bufferValuesToCatalystConverters, - bufferValuesToScalaConverters, - inputAggBufferOffset, - null) + private[this] lazy val mutableAggregateBuffer: MutableAggregationBufferImpl = { + new MutableAggregationBufferImpl( + aggBufferSchema, + bufferValuesToCatalystConverters, + bufferValuesToScalaConverters, + mutableAggBufferOffset, + null) } - /** - * Sets the mutableBufferOffset to newMutableBufferOffset and then create a new instance of - * `mutableAggregateBuffer` and `evalAggregateBuffer` based on this new mutableBufferOffset. - */ - override def withNewMutableAggBufferOffset(newMutableBufferOffset: Int): Unit = { - super.withNewMutableAggBufferOffset(newMutableBufferOffset) - // mutableBufferOffset has been updated. - mutableAggregateBuffer = - new MutableAggregationBufferImpl( - aggBufferSchema, - bufferValuesToCatalystConverters, - bufferValuesToScalaConverters, - mutableAggBufferOffset, - null) - evalAggregateBuffer = - new InputAggregationBuffer( - aggBufferSchema, - bufferValuesToCatalystConverters, - bufferValuesToScalaConverters, - mutableAggBufferOffset, - null) + // This buffer is only used at executor side. + private[this] lazy val evalAggregateBuffer: InputAggregationBuffer = { + new InputAggregationBuffer( + aggBufferSchema, + bufferValuesToCatalystConverters, + bufferValuesToScalaConverters, + mutableAggBufferOffset, + null) } override def initialize(buffer: MutableRow): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala index cf6e7ed0d337f..eaafd83158a15 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala @@ -19,21 +19,12 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.execution.SparkPlan /** * Utility functions used by the query planner to convert our plan to new aggregation code path. */ object Utils { - def supportsTungstenAggregate( - groupingExpressions: Seq[Expression], - aggregateBufferAttributes: Seq[Attribute]): Boolean = { - val aggregationBufferSchema = StructType.fromAttributes(aggregateBufferAttributes) - - UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) && - UnsafeProjection.canSupport(groupingExpressions) - } def planAggregateWithoutPartial( groupingExpressions: Seq[NamedExpression], @@ -70,8 +61,7 @@ object Utils { // Check if we can use TungstenAggregate. val usesTungstenAggregate = child.sqlContext.conf.unsafeEnabled && - aggregateExpressions.forall(_.aggregateFunction.isInstanceOf[DeclarativeAggregate]) && - supportsTungstenAggregate( + TungstenAggregate.supportsAggregate( groupingExpressions, aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) @@ -94,6 +84,7 @@ object Utils { nonCompleteAggregateAttributes = partialAggregateAttributes, completeAggregateExpressions = Nil, completeAggregateAttributes = Nil, + initialInputBufferOffset = 0, resultExpressions = partialResultExpressions, child = child) } else { @@ -125,6 +116,7 @@ object Utils { nonCompleteAggregateAttributes = finalAggregateAttributes, completeAggregateExpressions = Nil, completeAggregateAttributes = Nil, + initialInputBufferOffset = groupingExpressions.length, resultExpressions = resultExpressions, child = partialAggregate) } else { @@ -154,143 +146,150 @@ object Utils { val aggregateExpressions = functionsWithDistinct ++ functionsWithoutDistinct val usesTungstenAggregate = child.sqlContext.conf.unsafeEnabled && - aggregateExpressions.forall( - _.aggregateFunction.isInstanceOf[DeclarativeAggregate]) && - supportsTungstenAggregate( + TungstenAggregate.supportsAggregate( groupingExpressions, aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) - // 1. Create an Aggregate Operator for partial aggregations. - val groupingAttributes = groupingExpressions.map(_.toAttribute) - - // It is safe to call head at here since functionsWithDistinct has at least one - // AggregateExpression2. - val distinctColumnExpressions = - functionsWithDistinct.head.aggregateFunction.children - val namedDistinctColumnExpressions = distinctColumnExpressions.map { - case ne: NamedExpression => ne -> ne - case other => - val withAlias = Alias(other, other.toString)() - other -> withAlias + // functionsWithDistinct is guaranteed to be non-empty. Even though it may contain more than one + // DISTINCT aggregate function, all of those functions will have the same column expression. + // For example, it would be valid for functionsWithDistinct to be + // [COUNT(DISTINCT foo), MAX(DISTINCT foo)], but [COUNT(DISTINCT bar), COUNT(DISTINCT foo)] is + // disallowed because those two distinct aggregates have different column expressions. + val distinctColumnExpression: Expression = { + val allDistinctColumnExpressions = functionsWithDistinct.head.aggregateFunction.children + assert(allDistinctColumnExpressions.length == 1) + allDistinctColumnExpressions.head + } + val namedDistinctColumnExpression: NamedExpression = distinctColumnExpression match { + case ne: NamedExpression => ne + case other => Alias(other, other.toString)() } - val distinctColumnExpressionMap = namedDistinctColumnExpressions.toMap - val distinctColumnAttributes = namedDistinctColumnExpressions.map(_._2.toAttribute) + val distinctColumnAttribute: Attribute = namedDistinctColumnExpression.toAttribute + val groupingAttributes = groupingExpressions.map(_.toAttribute) - val partialAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) - val partialAggregateAttributes = - partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) - val partialAggregateGroupingExpressions = - groupingExpressions ++ namedDistinctColumnExpressions.map(_._2) - val partialAggregateResult = + // 1. Create an Aggregate Operator for partial aggregations. + val partialAggregate: SparkPlan = { + val partialAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) + val partialAggregateAttributes = + partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) + // We will group by the original grouping expression, plus an additional expression for the + // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping + // expressions will be [key, value]. + val partialAggregateGroupingExpressions = groupingExpressions :+ namedDistinctColumnExpression + val partialAggregateResult = groupingAttributes ++ - distinctColumnAttributes ++ - partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) - val partialAggregate = if (usesTungstenAggregate) { - TungstenAggregate( - requiredChildDistributionExpressions = None, - // The grouping expressions are original groupingExpressions and - // distinct columns. For example, for avg(distinct value) ... group by key - // the grouping expressions of this Aggregate Operator will be [key, value]. - groupingExpressions = partialAggregateGroupingExpressions, - nonCompleteAggregateExpressions = partialAggregateExpressions, - nonCompleteAggregateAttributes = partialAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, - resultExpressions = partialAggregateResult, - child = child) - } else { - SortBasedAggregate( - requiredChildDistributionExpressions = None, - groupingExpressions = partialAggregateGroupingExpressions, - nonCompleteAggregateExpressions = partialAggregateExpressions, - nonCompleteAggregateAttributes = partialAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, - initialInputBufferOffset = 0, - resultExpressions = partialAggregateResult, - child = child) + Seq(distinctColumnAttribute) ++ + partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) + if (usesTungstenAggregate) { + TungstenAggregate( + requiredChildDistributionExpressions = None, + groupingExpressions = partialAggregateGroupingExpressions, + nonCompleteAggregateExpressions = partialAggregateExpressions, + nonCompleteAggregateAttributes = partialAggregateAttributes, + completeAggregateExpressions = Nil, + completeAggregateAttributes = Nil, + initialInputBufferOffset = 0, + resultExpressions = partialAggregateResult, + child = child) + } else { + SortBasedAggregate( + requiredChildDistributionExpressions = None, + groupingExpressions = partialAggregateGroupingExpressions, + nonCompleteAggregateExpressions = partialAggregateExpressions, + nonCompleteAggregateAttributes = partialAggregateAttributes, + completeAggregateExpressions = Nil, + completeAggregateAttributes = Nil, + initialInputBufferOffset = 0, + resultExpressions = partialAggregateResult, + child = child) + } } // 2. Create an Aggregate Operator for partial merge aggregations. - val partialMergeAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) - val partialMergeAggregateAttributes = - partialMergeAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) - val partialMergeAggregateResult = + val partialMergeAggregate: SparkPlan = { + val partialMergeAggregateExpressions = + functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) + val partialMergeAggregateAttributes = + partialMergeAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) + val partialMergeAggregateResult = groupingAttributes ++ - distinctColumnAttributes ++ - partialMergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) - val partialMergeAggregate = if (usesTungstenAggregate) { - TungstenAggregate( - requiredChildDistributionExpressions = Some(groupingAttributes), - groupingExpressions = groupingAttributes ++ distinctColumnAttributes, - nonCompleteAggregateExpressions = partialMergeAggregateExpressions, - nonCompleteAggregateAttributes = partialMergeAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, - resultExpressions = partialMergeAggregateResult, - child = partialAggregate) - } else { - SortBasedAggregate( - requiredChildDistributionExpressions = Some(groupingAttributes), - groupingExpressions = groupingAttributes ++ distinctColumnAttributes, - nonCompleteAggregateExpressions = partialMergeAggregateExpressions, - nonCompleteAggregateAttributes = partialMergeAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, - initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length, - resultExpressions = partialMergeAggregateResult, - child = partialAggregate) + Seq(distinctColumnAttribute) ++ + partialMergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) + if (usesTungstenAggregate) { + TungstenAggregate( + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes :+ distinctColumnAttribute, + nonCompleteAggregateExpressions = partialMergeAggregateExpressions, + nonCompleteAggregateAttributes = partialMergeAggregateAttributes, + completeAggregateExpressions = Nil, + completeAggregateAttributes = Nil, + initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length, + resultExpressions = partialMergeAggregateResult, + child = partialAggregate) + } else { + SortBasedAggregate( + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes :+ distinctColumnAttribute, + nonCompleteAggregateExpressions = partialMergeAggregateExpressions, + nonCompleteAggregateAttributes = partialMergeAggregateAttributes, + completeAggregateExpressions = Nil, + completeAggregateAttributes = Nil, + initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length, + resultExpressions = partialMergeAggregateResult, + child = partialAggregate) + } } - // 3. Create an Aggregate Operator for partial merge aggregations. - val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final)) - // The attributes of the final aggregation buffer, which is presented as input to the result - // projection: - val finalAggregateAttributes = finalAggregateExpressions.map { - expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) - } + // 3. Create an Aggregate Operator for the final aggregation. + val finalAndCompleteAggregate: SparkPlan = { + val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final)) + // The attributes of the final aggregation buffer, which is presented as input to the result + // projection: + val finalAggregateAttributes = finalAggregateExpressions.map { + expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) + } - val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map { - // Children of an AggregateFunction with DISTINCT keyword has already - // been evaluated. At here, we need to replace original children - // to AttributeReferences. - case agg @ AggregateExpression2(aggregateFunction, mode, true) => - val rewrittenAggregateFunction = aggregateFunction.transformDown { - case expr if distinctColumnExpressionMap.contains(expr) => - distinctColumnExpressionMap(expr).toAttribute - }.asInstanceOf[AggregateFunction2] - // We rewrite the aggregate function to a non-distinct aggregation because - // its input will have distinct arguments. - // We just keep the isDistinct setting to true, so when users look at the query plan, - // they still can see distinct aggregations. - val rewrittenAggregateExpression = - AggregateExpression2(rewrittenAggregateFunction, Complete, true) + val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map { + // Children of an AggregateFunction with DISTINCT keyword has already + // been evaluated. At here, we need to replace original children + // to AttributeReferences. + case agg @ AggregateExpression2(aggregateFunction, mode, true) => + val rewrittenAggregateFunction = aggregateFunction.transformDown { + case expr if expr == distinctColumnExpression => distinctColumnAttribute + }.asInstanceOf[AggregateFunction2] + // We rewrite the aggregate function to a non-distinct aggregation because + // its input will have distinct arguments. + // We just keep the isDistinct setting to true, so when users look at the query plan, + // they still can see distinct aggregations. + val rewrittenAggregateExpression = + AggregateExpression2(rewrittenAggregateFunction, Complete, isDistinct = true) - val aggregateFunctionAttribute = aggregateFunctionToAttribute(agg.aggregateFunction, true) - (rewrittenAggregateExpression, aggregateFunctionAttribute) - }.unzip - - val finalAndCompleteAggregate = if (usesTungstenAggregate) { - TungstenAggregate( - requiredChildDistributionExpressions = Some(groupingAttributes), - groupingExpressions = groupingAttributes, - nonCompleteAggregateExpressions = finalAggregateExpressions, - nonCompleteAggregateAttributes = finalAggregateAttributes, - completeAggregateExpressions = completeAggregateExpressions, - completeAggregateAttributes = completeAggregateAttributes, - resultExpressions = resultExpressions, - child = partialMergeAggregate) - } else { - SortBasedAggregate( - requiredChildDistributionExpressions = Some(groupingAttributes), - groupingExpressions = groupingAttributes, - nonCompleteAggregateExpressions = finalAggregateExpressions, - nonCompleteAggregateAttributes = finalAggregateAttributes, - completeAggregateExpressions = completeAggregateExpressions, - completeAggregateAttributes = completeAggregateAttributes, - initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length, - resultExpressions = resultExpressions, - child = partialMergeAggregate) + val aggregateFunctionAttribute = aggregateFunctionToAttribute(agg.aggregateFunction, true) + (rewrittenAggregateExpression, aggregateFunctionAttribute) + }.unzip + if (usesTungstenAggregate) { + TungstenAggregate( + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes, + nonCompleteAggregateExpressions = finalAggregateExpressions, + nonCompleteAggregateAttributes = finalAggregateAttributes, + completeAggregateExpressions = completeAggregateExpressions, + completeAggregateAttributes = completeAggregateAttributes, + initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length, + resultExpressions = resultExpressions, + child = partialMergeAggregate) + } else { + SortBasedAggregate( + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes, + nonCompleteAggregateExpressions = finalAggregateExpressions, + nonCompleteAggregateAttributes = finalAggregateAttributes, + completeAggregateExpressions = completeAggregateExpressions, + completeAggregateAttributes = completeAggregateAttributes, + initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length, + resultExpressions = resultExpressions, + child = partialMergeAggregate) + } } finalAndCompleteAggregate :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala index ed974b3a53d41..0cc4988ff681c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala @@ -39,7 +39,7 @@ class TungstenAggregationIteratorSuite extends SparkFunSuite with SharedSQLConte } val dummyAccum = SQLMetrics.createLongMetric(sparkContext, "dummy") iter = new TungstenAggregationIterator(Seq.empty, Seq.empty, Seq.empty, Seq.empty, Seq.empty, - Seq.empty, newMutableProjection, Seq.empty, None, dummyAccum, dummyAccum) + 0, Seq.empty, newMutableProjection, Seq.empty, None, dummyAccum, dummyAccum) val numPages = iter.getHashMap.getNumDataPages assert(numPages === 1) } finally { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 18bbdb9908142..a2ebf6552fd06 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -553,10 +553,16 @@ private[hive] case class HiveGenericUDTF( private[hive] case class HiveUDAFFunction( funcWrapper: HiveFunctionWrapper, children: Seq[Expression], - isUDAFBridgeRequired: Boolean = false) + isUDAFBridgeRequired: Boolean = false, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) extends ImperativeAggregate with HiveInspectors { - def this() = this(null, null) + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) @transient private lazy val resolver = @@ -614,7 +620,11 @@ private[hive] case class HiveUDAFFunction( buffer = function.getNewAggregationBuffer } - override def aggBufferAttributes: Seq[AttributeReference] = Nil + override val aggBufferAttributes: Seq[AttributeReference] = Nil + + // Note: although this simply copies aggBufferAttributes, this common code can not be placed + // in the superclass because that will lead to initialization ordering issues. + override val inputAggBufferAttributes: Seq[AttributeReference] = Nil // We rely on Hive to check the input data types, so use `AnyDataType` here to bypass our // catalyst type checking framework. From 9808052b5adfed7dafd6c1b3971b998e45b2799a Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Wed, 14 Oct 2015 20:56:08 -0700 Subject: [PATCH 060/139] [SPARK-11076] [SQL] Add decimal support for floor and ceil Actually all of the `UnaryMathExpression` doens't support the Decimal, will create follow ups for supporing it. This is the first PR which will be good to review the approach I am taking. Author: Cheng Hao Closes #9086 from chenghao-intel/ceiling. --- .../expressions/mathExpressions.scala | 48 +++++++++++++++---- .../org/apache/spark/sql/types/Decimal.scala | 32 +++++++++++-- .../expressions/LiteralGenerator.scala | 14 +++++- .../expressions/MathFunctionsSuite.scala | 10 ++++ 4 files changed, 91 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index a8164e9e29ec6..28f616fbb9ca5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -55,7 +55,7 @@ abstract class LeafMathExpression(c: Double, name: String) abstract class UnaryMathExpression(val f: Double => Double, name: String) extends UnaryExpression with Serializable with ImplicitCastInputTypes { - override def inputTypes: Seq[DataType] = Seq(DoubleType) + override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType) override def dataType: DataType = DoubleType override def nullable: Boolean = true override def toString: String = s"$name($child)" @@ -153,13 +153,28 @@ case class Atan(child: Expression) extends UnaryMathExpression(math.atan, "ATAN" case class Cbrt(child: Expression) extends UnaryMathExpression(math.cbrt, "CBRT") case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL") { - override def dataType: DataType = LongType - protected override def nullSafeEval(input: Any): Any = { - f(input.asInstanceOf[Double]).toLong + override def dataType: DataType = child.dataType match { + case dt @ DecimalType.Fixed(_, 0) => dt + case DecimalType.Fixed(precision, scale) => + DecimalType.bounded(precision - scale + 1, 0) + case _ => LongType + } + + override def inputTypes: Seq[AbstractDataType] = + Seq(TypeCollection(DoubleType, DecimalType)) + + protected override def nullSafeEval(input: Any): Any = child.dataType match { + case DoubleType => f(input.asInstanceOf[Double]).toLong + case DecimalType.Fixed(precision, scale) => input.asInstanceOf[Decimal].ceil } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))") + child.dataType match { + case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c") + case DecimalType.Fixed(precision, scale) => + defineCodeGen(ctx, ev, c => s"$c.ceil()") + case _ => defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))") + } } } @@ -205,13 +220,28 @@ case class Exp(child: Expression) extends UnaryMathExpression(math.exp, "EXP") case class Expm1(child: Expression) extends UnaryMathExpression(math.expm1, "EXPM1") case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLOOR") { - override def dataType: DataType = LongType - protected override def nullSafeEval(input: Any): Any = { - f(input.asInstanceOf[Double]).toLong + override def dataType: DataType = child.dataType match { + case dt @ DecimalType.Fixed(_, 0) => dt + case DecimalType.Fixed(precision, scale) => + DecimalType.bounded(precision - scale + 1, 0) + case _ => LongType + } + + override def inputTypes: Seq[AbstractDataType] = + Seq(TypeCollection(DoubleType, DecimalType)) + + protected override def nullSafeEval(input: Any): Any = child.dataType match { + case DoubleType => f(input.asInstanceOf[Double]).toLong + case DecimalType.Fixed(precision, scale) => input.asInstanceOf[Decimal].floor } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))") + child.dataType match { + case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c") + case DecimalType.Fixed(precision, scale) => + defineCodeGen(ctx, ev, c => s"$c.floor()") + case _ => defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))") + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index c11dab35cdf6f..c7a1a2e7469ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -107,7 +107,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { * Set this Decimal to the given BigDecimal value, with a given precision and scale. */ def set(decimal: BigDecimal, precision: Int, scale: Int): Decimal = { - this.decimalVal = decimal.setScale(scale, ROUNDING_MODE) + this.decimalVal = decimal.setScale(scale, ROUND_HALF_UP) require( decimalVal.precision <= precision, s"Decimal precision ${decimalVal.precision} exceeds max precision $precision") @@ -198,6 +198,16 @@ final class Decimal extends Ordered[Decimal] with Serializable { * @return true if successful, false if overflow would occur */ def changePrecision(precision: Int, scale: Int): Boolean = { + changePrecision(precision, scale, ROUND_HALF_UP) + } + + /** + * Update precision and scale while keeping our value the same, and return true if successful. + * + * @return true if successful, false if overflow would occur + */ + private[sql] def changePrecision(precision: Int, scale: Int, + roundMode: BigDecimal.RoundingMode.Value): Boolean = { // fast path for UnsafeProjection if (precision == this.precision && scale == this.scale) { return true @@ -231,7 +241,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (decimalVal.ne(null)) { // We get here if either we started with a BigDecimal, or we switched to one because we would // have overflowed our Long; in either case we must rescale decimalVal to the new scale. - val newVal = decimalVal.setScale(scale, ROUNDING_MODE) + val newVal = decimalVal.setScale(scale, roundMode) if (newVal.precision > precision) { return false } @@ -309,10 +319,26 @@ final class Decimal extends Ordered[Decimal] with Serializable { } def abs: Decimal = if (this.compare(Decimal.ZERO) < 0) this.unary_- else this + + def floor: Decimal = if (scale == 0) this else { + val value = this.clone() + value.changePrecision( + DecimalType.bounded(precision - scale + 1, 0).precision, 0, ROUND_FLOOR) + value + } + + def ceil: Decimal = if (scale == 0) this else { + val value = this.clone() + value.changePrecision( + DecimalType.bounded(precision - scale + 1, 0).precision, 0, ROUND_CEILING) + value + } } object Decimal { - private val ROUNDING_MODE = BigDecimal.RoundingMode.HALF_UP + val ROUND_HALF_UP = BigDecimal.RoundingMode.HALF_UP + val ROUND_CEILING = BigDecimal.RoundingMode.CEILING + val ROUND_FLOOR = BigDecimal.RoundingMode.FLOOR /** Maximum number of decimal digits a Long can represent */ val MAX_LONG_DIGITS = 18 diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala index ee6d25157fc08..d9c91415e249d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala @@ -78,7 +78,18 @@ object LiteralGenerator { Double.NaN, Double.PositiveInfinity, Double.NegativeInfinity) } yield Literal.create(f, DoubleType) - // TODO: decimal type + // TODO cache the generated data + def decimalLiteralGen(precision: Int, scale: Int): Gen[Literal] = { + assert(scale >= 0) + assert(precision >= scale) + Arbitrary.arbBigInt.arbitrary.map { s => + val a = (s % BigInt(10).pow(precision - scale)).toString() + val b = (s % BigInt(10).pow(scale)).abs.toString() + Literal.create( + Decimal(BigDecimal(s"$a.$b"), precision, scale), + DecimalType(precision, scale)) + } + } lazy val stringLiteralGen: Gen[Literal] = for { s <- Arbitrary.arbString.arbitrary } yield Literal.create(s, StringType) @@ -122,6 +133,7 @@ object LiteralGenerator { case StringType => stringLiteralGen case BinaryType => binaryLiteralGen case CalendarIntervalType => calendarIntervalLiterGen + case DecimalType.Fixed(precision, scale) => decimalLiteralGen(precision, scale) case dt => throw new IllegalArgumentException(s"not supported type $dt") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 1b2a9163a3d09..88ed9fdd6465f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -246,11 +246,21 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("ceil") { testUnary(Ceil, (d: Double) => math.ceil(d).toLong) checkConsistencyBetweenInterpretedAndCodegen(Ceil, DoubleType) + + testUnary(Ceil, (d: Decimal) => d.ceil, (-20 to 20).map(x => Decimal(x * 0.1))) + checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(25, 3)) + checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(25, 0)) + checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(5, 0)) } test("floor") { testUnary(Floor, (d: Double) => math.floor(d).toLong) checkConsistencyBetweenInterpretedAndCodegen(Floor, DoubleType) + + testUnary(Floor, (d: Decimal) => d.floor, (-20 to 20).map(x => Decimal(x * 0.1))) + checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(25, 3)) + checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(25, 0)) + checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(5, 0)) } test("factorial") { From 0f62c2282bb30cb4fb6eea9d28b198d557a79b22 Mon Sep 17 00:00:00 2001 From: Adam Lewandowski Date: Thu, 15 Oct 2015 09:45:54 -0700 Subject: [PATCH 061/139] [SPARK-11093] [CORE] ChildFirstURLClassLoader#getResources should return all found resources, not just those in the child classloader Author: Adam Lewandowski Closes #9106 from alewando/childFirstFix. --- .../spark/util/MutableURLClassLoader.scala | 13 +++--- .../util/MutableURLClassLoaderSuite.scala | 40 ++++++++++++++++++- 2 files changed, 44 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala b/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala index a1c33212cdb2b..945217203be72 100644 --- a/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala +++ b/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala @@ -21,6 +21,8 @@ import java.net.{URLClassLoader, URL} import java.util.Enumeration import java.util.concurrent.ConcurrentHashMap +import scala.collection.JavaConverters._ + /** * URL class loader that exposes the `addURL` and `getURLs` methods in URLClassLoader. */ @@ -82,14 +84,9 @@ private[spark] class ChildFirstURLClassLoader(urls: Array[URL], parent: ClassLoa } override def getResources(name: String): Enumeration[URL] = { - val urls = super.findResources(name) - val res = - if (urls != null && urls.hasMoreElements()) { - urls - } else { - parentClassLoader.getResources(name) - } - res + val childUrls = super.findResources(name).asScala + val parentUrls = parentClassLoader.getResources(name).asScala + (childUrls ++ parentUrls).asJavaEnumeration } override def addURL(url: URL) { diff --git a/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala b/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala index d3d464e84ffd7..8b53d4f14a6a4 100644 --- a/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala @@ -19,9 +19,14 @@ package org.apache.spark.util import java.net.URLClassLoader +import scala.collection.JavaConverters._ + +import org.scalatest.Matchers +import org.scalatest.Matchers._ + import org.apache.spark.{SparkContext, SparkException, SparkFunSuite, TestUtils} -class MutableURLClassLoaderSuite extends SparkFunSuite { +class MutableURLClassLoaderSuite extends SparkFunSuite with Matchers { val urls2 = List(TestUtils.createJarWithClasses( classNames = Seq("FakeClass1", "FakeClass2", "FakeClass3"), @@ -32,6 +37,12 @@ class MutableURLClassLoaderSuite extends SparkFunSuite { toStringValue = "1", classpathUrls = urls2)).toArray + val fileUrlsChild = List(TestUtils.createJarWithFiles(Map( + "resource1" -> "resource1Contents-child", + "resource2" -> "resource2Contents"))).toArray + val fileUrlsParent = List(TestUtils.createJarWithFiles(Map( + "resource1" -> "resource1Contents-parent"))).toArray + test("child first") { val parentLoader = new URLClassLoader(urls2, null) val classLoader = new ChildFirstURLClassLoader(urls, parentLoader) @@ -68,6 +79,33 @@ class MutableURLClassLoaderSuite extends SparkFunSuite { } } + test("default JDK classloader get resources") { + val parentLoader = new URLClassLoader(fileUrlsParent, null) + val classLoader = new URLClassLoader(fileUrlsChild, parentLoader) + assert(classLoader.getResources("resource1").asScala.size === 2) + assert(classLoader.getResources("resource2").asScala.size === 1) + } + + test("parent first get resources") { + val parentLoader = new URLClassLoader(fileUrlsParent, null) + val classLoader = new MutableURLClassLoader(fileUrlsChild, parentLoader) + assert(classLoader.getResources("resource1").asScala.size === 2) + assert(classLoader.getResources("resource2").asScala.size === 1) + } + + test("child first get resources") { + val parentLoader = new URLClassLoader(fileUrlsParent, null) + val classLoader = new ChildFirstURLClassLoader(fileUrlsChild, parentLoader) + + val res1 = classLoader.getResources("resource1").asScala.toList + assert(res1.size === 2) + assert(classLoader.getResources("resource2").asScala.size === 1) + + res1.map(scala.io.Source.fromURL(_).mkString) should contain inOrderOnly + ("resource1Contents-child", "resource1Contents-parent") + } + + test("driver sets context class loader in local mode") { // Test the case where the driver program sets a context classloader and then runs a job // in local mode. This is what happens when ./spark-submit is called with "local" as the From aec4400beffc569c13cceea2d0c481dfa3f34175 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Thu, 15 Oct 2015 09:49:19 -0700 Subject: [PATCH 062/139] =?UTF-8?q?[SPARK-11099]=20[SPARK=20SHELL]=20[SPAR?= =?UTF-8?q?K=20SUBMIT]=20Default=20conf=20property=20file=20i=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Please help review it. Thanks Author: Jeff Zhang Closes #9114 from zjffdu/SPARK-11099. --- .../launcher/AbstractCommandBuilder.java | 14 ++++------ .../SparkSubmitCommandBuilderSuite.java | 28 +++++++++++++------ .../src/test/resources/spark-defaults.conf | 21 ++++++++++++++ 3 files changed, 45 insertions(+), 18 deletions(-) create mode 100644 launcher/src/test/resources/spark-defaults.conf diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java index cf3729b7febc3..3ee6bd92e47fc 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -272,15 +272,11 @@ void setPropertiesFile(String path) { Map getEffectiveConfig() throws IOException { if (effectiveConfig == null) { - if (propertiesFile == null) { - effectiveConfig = conf; - } else { - effectiveConfig = new HashMap<>(conf); - Properties p = loadPropertiesFile(); - for (String key : p.stringPropertyNames()) { - if (!effectiveConfig.containsKey(key)) { - effectiveConfig.put(key, p.getProperty(key)); - } + effectiveConfig = new HashMap<>(conf); + Properties p = loadPropertiesFile(); + for (String key : p.stringPropertyNames()) { + if (!effectiveConfig.containsKey(key)) { + effectiveConfig.put(key, p.getProperty(key)); } } } diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java index d5397b0685046..6aad47adbcc82 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java @@ -48,12 +48,14 @@ public static void cleanUp() throws Exception { @Test public void testDriverCmdBuilder() throws Exception { - testCmdBuilder(true); + testCmdBuilder(true, true); + testCmdBuilder(true, false); } @Test public void testClusterCmdBuilder() throws Exception { - testCmdBuilder(false); + testCmdBuilder(false, true); + testCmdBuilder(false, false); } @Test @@ -149,7 +151,7 @@ public void testPySparkFallback() throws Exception { assertEquals("arg1", cmd.get(cmd.size() - 1)); } - private void testCmdBuilder(boolean isDriver) throws Exception { + private void testCmdBuilder(boolean isDriver, boolean useDefaultPropertyFile) throws Exception { String deployMode = isDriver ? "client" : "cluster"; SparkSubmitCommandBuilder launcher = @@ -161,14 +163,20 @@ private void testCmdBuilder(boolean isDriver) throws Exception { launcher.appResource = "/foo"; launcher.appName = "MyApp"; launcher.mainClass = "my.Class"; - launcher.setPropertiesFile(dummyPropsFile.getAbsolutePath()); launcher.appArgs.add("foo"); launcher.appArgs.add("bar"); - launcher.conf.put(SparkLauncher.DRIVER_MEMORY, "1g"); - launcher.conf.put(SparkLauncher.DRIVER_EXTRA_CLASSPATH, "/driver"); - launcher.conf.put(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS, "-Ddriver -XX:MaxPermSize=256m"); - launcher.conf.put(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH, "/native"); launcher.conf.put("spark.foo", "foo"); + // either set the property through "--conf" or through default property file + if (!useDefaultPropertyFile) { + launcher.setPropertiesFile(dummyPropsFile.getAbsolutePath()); + launcher.conf.put(SparkLauncher.DRIVER_MEMORY, "1g"); + launcher.conf.put(SparkLauncher.DRIVER_EXTRA_CLASSPATH, "/driver"); + launcher.conf.put(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS, "-Ddriver -XX:MaxPermSize=256m"); + launcher.conf.put(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH, "/native"); + } else { + launcher.childEnv.put("SPARK_CONF_DIR", System.getProperty("spark.test.home") + + "/launcher/src/test/resources"); + } Map env = new HashMap(); List cmd = launcher.buildCommand(env); @@ -216,7 +224,9 @@ private void testCmdBuilder(boolean isDriver) throws Exception { } // Checks below are the same for both driver and non-driver mode. - assertEquals(dummyPropsFile.getAbsolutePath(), findArgValue(cmd, parser.PROPERTIES_FILE)); + if (!useDefaultPropertyFile) { + assertEquals(dummyPropsFile.getAbsolutePath(), findArgValue(cmd, parser.PROPERTIES_FILE)); + } assertEquals("yarn", findArgValue(cmd, parser.MASTER)); assertEquals(deployMode, findArgValue(cmd, parser.DEPLOY_MODE)); assertEquals("my.Class", findArgValue(cmd, parser.CLASS)); diff --git a/launcher/src/test/resources/spark-defaults.conf b/launcher/src/test/resources/spark-defaults.conf new file mode 100644 index 0000000000000..239fc57883e98 --- /dev/null +++ b/launcher/src/test/resources/spark-defaults.conf @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +spark.driver.memory=1g +spark.driver.extraClassPath=/driver +spark.driver.extraJavaOptions=-Ddriver -XX:MaxPermSize=256m +spark.driver.extraLibraryPath=/native \ No newline at end of file From 523adc24a683930304f408d477607edfe9de7b76 Mon Sep 17 00:00:00 2001 From: shellberg Date: Thu, 15 Oct 2015 18:07:10 +0100 Subject: [PATCH 063/139] [SPARK-11066] Update DAGScheduler's "misbehaved ResultHandler" Restrict tasks (of job) to only 1 to ensure that the causing Exception asserted for job failure is the deliberately thrown DAGSchedulerSuiteDummyException intended, not an UnsupportedOperationException from any second/subsequent tasks that can propagate from a race condition during code execution. Author: shellberg Closes #9076 from shellberg/shellberg-DAGSchedulerSuite-misbehavedResultHandlerTest-patch-1. --- .../apache/spark/scheduler/DAGSchedulerSuite.scala | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 697c195e4ad1f..5b01ddb298c39 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -1375,18 +1375,27 @@ class DAGSchedulerSuite assert(sc.parallelize(1 to 10, 2).count() === 10) } + /** + * The job will be failed on first task throwing a DAGSchedulerSuiteDummyException. + * Any subsequent task WILL throw a legitimate java.lang.UnsupportedOperationException. + * If multiple tasks, there exists a race condition between the SparkDriverExecutionExceptions + * and their differing causes as to which will represent result for job... + */ test("misbehaved resultHandler should not crash DAGScheduler and SparkContext") { val e = intercept[SparkDriverExecutionException] { + // Number of parallelized partitions implies number of tasks of job val rdd = sc.parallelize(1 to 10, 2) sc.runJob[Int, Int]( rdd, (context: TaskContext, iter: Iterator[Int]) => iter.size, - Seq(0, 1), + // For a robust test assertion, limit number of job tasks to 1; that is, + // if multiple RDD partitions, use id of any one partition, say, first partition id=0 + Seq(0), (part: Int, result: Int) => throw new DAGSchedulerSuiteDummyException) } assert(e.getCause.isInstanceOf[DAGSchedulerSuiteDummyException]) - // Make sure we can still run commands + // Make sure we can still run commands on our SparkContext assert(sc.parallelize(1 to 10, 2).count() === 10) } From d45a0d3ca23df86cf0a95508ccc3b4b98f1b611c Mon Sep 17 00:00:00 2001 From: Carson Wang Date: Thu, 15 Oct 2015 10:36:54 -0700 Subject: [PATCH 064/139] [SPARK-11047] Internal accumulators miss the internal flag when replaying events in the history server Internal accumulators don't write the internal flag to event log. So on the history server Web UI, all accumulators are not internal. This causes incorrect peak execution memory and unwanted accumulator table displayed on the stage page. To fix it, I add the "internal" property of AccumulableInfo when writing the event log. Author: Carson Wang Closes #9061 from carsonwang/accumulableBug. --- .../spark/scheduler/AccumulableInfo.scala | 9 ++ .../org/apache/spark/util/JsonProtocol.scala | 6 +- .../apache/spark/util/JsonProtocolSuite.scala | 96 +++++++++++++------ 3 files changed, 79 insertions(+), 32 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala index b6bff64ee368e..146cfb9ba8037 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala @@ -46,6 +46,15 @@ class AccumulableInfo private[spark] ( } object AccumulableInfo { + def apply( + id: Long, + name: String, + update: Option[String], + value: String, + internal: Boolean): AccumulableInfo = { + new AccumulableInfo(id, name, update, value, internal) + } + def apply(id: Long, name: String, update: Option[String], value: String): AccumulableInfo = { new AccumulableInfo(id, name, update, value, internal = false) } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 40729fa5a4ffe..a06dc6f709d33 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -282,7 +282,8 @@ private[spark] object JsonProtocol { ("ID" -> accumulableInfo.id) ~ ("Name" -> accumulableInfo.name) ~ ("Update" -> accumulableInfo.update.map(new JString(_)).getOrElse(JNothing)) ~ - ("Value" -> accumulableInfo.value) + ("Value" -> accumulableInfo.value) ~ + ("Internal" -> accumulableInfo.internal) } def taskMetricsToJson(taskMetrics: TaskMetrics): JValue = { @@ -696,7 +697,8 @@ private[spark] object JsonProtocol { val name = (json \ "Name").extract[String] val update = Utils.jsonOption(json \ "Update").map(_.extract[String]) val value = (json \ "Value").extract[String] - AccumulableInfo(id, name, update, value) + val internal = (json \ "Internal").extractOpt[Boolean].getOrElse(false) + AccumulableInfo(id, name, update, value, internal) } def taskMetricsFromJson(json: JValue): TaskMetrics = { diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index a24bf2931cca0..f9572921f43cb 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -364,6 +364,15 @@ class JsonProtocolSuite extends SparkFunSuite { assertEquals(expectedDenied, JsonProtocol.taskEndReasonFromJson(oldDenied)) } + test("AccumulableInfo backward compatibility") { + // "Internal" property of AccumulableInfo were added after 1.5.1. + val accumulableInfo = makeAccumulableInfo(1) + val oldJson = JsonProtocol.accumulableInfoToJson(accumulableInfo) + .removeField({ _._1 == "Internal" }) + val oldInfo = JsonProtocol.accumulableInfoFromJson(oldJson) + assert(false === oldInfo.internal) + } + /** -------------------------- * | Helper test running methods | * --------------------------- */ @@ -723,15 +732,15 @@ class JsonProtocolSuite extends SparkFunSuite { val taskInfo = new TaskInfo(a, b, c, d, "executor", "your kind sir", TaskLocality.NODE_LOCAL, speculative) val (acc1, acc2, acc3) = - (makeAccumulableInfo(1), makeAccumulableInfo(2), makeAccumulableInfo(3)) + (makeAccumulableInfo(1), makeAccumulableInfo(2), makeAccumulableInfo(3, internal = true)) taskInfo.accumulables += acc1 taskInfo.accumulables += acc2 taskInfo.accumulables += acc3 taskInfo } - private def makeAccumulableInfo(id: Int): AccumulableInfo = - AccumulableInfo(id, " Accumulable " + id, Some("delta" + id), "val" + id) + private def makeAccumulableInfo(id: Int, internal: Boolean = false): AccumulableInfo = + AccumulableInfo(id, " Accumulable " + id, Some("delta" + id), "val" + id, internal) /** * Creates a TaskMetrics object describing a task that read data from Hadoop (if hasHadoopInput is @@ -812,13 +821,15 @@ class JsonProtocolSuite extends SparkFunSuite { | "ID": 2, | "Name": "Accumulable2", | "Update": "delta2", - | "Value": "val2" + | "Value": "val2", + | "Internal": false | }, | { | "ID": 1, | "Name": "Accumulable1", | "Update": "delta1", - | "Value": "val1" + | "Value": "val1", + | "Internal": false | } | ] | }, @@ -866,13 +877,15 @@ class JsonProtocolSuite extends SparkFunSuite { | "ID": 2, | "Name": "Accumulable2", | "Update": "delta2", - | "Value": "val2" + | "Value": "val2", + | "Internal": false | }, | { | "ID": 1, | "Name": "Accumulable1", | "Update": "delta1", - | "Value": "val1" + | "Value": "val1", + | "Internal": false | } | ] | } @@ -902,19 +915,22 @@ class JsonProtocolSuite extends SparkFunSuite { | "ID": 1, | "Name": "Accumulable1", | "Update": "delta1", - | "Value": "val1" + | "Value": "val1", + | "Internal": false | }, | { | "ID": 2, | "Name": "Accumulable2", | "Update": "delta2", - | "Value": "val2" + | "Value": "val2", + | "Internal": false | }, | { | "ID": 3, | "Name": "Accumulable3", | "Update": "delta3", - | "Value": "val3" + | "Value": "val3", + | "Internal": true | } | ] | } @@ -942,19 +958,22 @@ class JsonProtocolSuite extends SparkFunSuite { | "ID": 1, | "Name": "Accumulable1", | "Update": "delta1", - | "Value": "val1" + | "Value": "val1", + | "Internal": false | }, | { | "ID": 2, | "Name": "Accumulable2", | "Update": "delta2", - | "Value": "val2" + | "Value": "val2", + | "Internal": false | }, | { | "ID": 3, | "Name": "Accumulable3", | "Update": "delta3", - | "Value": "val3" + | "Value": "val3", + | "Internal": true | } | ] | } @@ -988,19 +1007,22 @@ class JsonProtocolSuite extends SparkFunSuite { | "ID": 1, | "Name": "Accumulable1", | "Update": "delta1", - | "Value": "val1" + | "Value": "val1", + | "Internal": false | }, | { | "ID": 2, | "Name": "Accumulable2", | "Update": "delta2", - | "Value": "val2" + | "Value": "val2", + | "Internal": false | }, | { | "ID": 3, | "Name": "Accumulable3", | "Update": "delta3", - | "Value": "val3" + | "Value": "val3", + | "Internal": true | } | ] | }, @@ -1074,19 +1096,22 @@ class JsonProtocolSuite extends SparkFunSuite { | "ID": 1, | "Name": "Accumulable1", | "Update": "delta1", - | "Value": "val1" + | "Value": "val1", + | "Internal": false | }, | { | "ID": 2, | "Name": "Accumulable2", | "Update": "delta2", - | "Value": "val2" + | "Value": "val2", + | "Internal": false | }, | { | "ID": 3, | "Name": "Accumulable3", | "Update": "delta3", - | "Value": "val3" + | "Value": "val3", + | "Internal": true | } | ] | }, @@ -1157,19 +1182,22 @@ class JsonProtocolSuite extends SparkFunSuite { | "ID": 1, | "Name": "Accumulable1", | "Update": "delta1", - | "Value": "val1" + | "Value": "val1", + | "Internal": false | }, | { | "ID": 2, | "Name": "Accumulable2", | "Update": "delta2", - | "Value": "val2" + | "Value": "val2", + | "Internal": false | }, | { | "ID": 3, | "Name": "Accumulable3", | "Update": "delta3", - | "Value": "val3" + | "Value": "val3", + | "Internal": true | } | ] | }, @@ -1251,13 +1279,15 @@ class JsonProtocolSuite extends SparkFunSuite { | "ID": 2, | "Name": " Accumulable 2", | "Update": "delta2", - | "Value": "val2" + | "Value": "val2", + | "Internal": false | }, | { | "ID": 1, | "Name": " Accumulable 1", | "Update": "delta1", - | "Value": "val1" + | "Value": "val1", + | "Internal": false | } | ] | }, @@ -1309,13 +1339,15 @@ class JsonProtocolSuite extends SparkFunSuite { | "ID": 2, | "Name": " Accumulable 2", | "Update": "delta2", - | "Value": "val2" + | "Value": "val2", + | "Internal": false | }, | { | "ID": 1, | "Name": " Accumulable 1", | "Update": "delta1", - | "Value": "val1" + | "Value": "val1", + | "Internal": false | } | ] | }, @@ -1384,13 +1416,15 @@ class JsonProtocolSuite extends SparkFunSuite { | "ID": 2, | "Name": " Accumulable 2", | "Update": "delta2", - | "Value": "val2" + | "Value": "val2", + | "Internal": false | }, | { | "ID": 1, | "Name": " Accumulable 1", | "Update": "delta1", - | "Value": "val1" + | "Value": "val1", + | "Internal": false | } | ] | }, @@ -1476,13 +1510,15 @@ class JsonProtocolSuite extends SparkFunSuite { | "ID": 2, | "Name": " Accumulable 2", | "Update": "delta2", - | "Value": "val2" + | "Value": "val2", + | "Internal": false | }, | { | "ID": 1, | "Name": " Accumulable 1", | "Update": "delta1", - | "Value": "val1" + | "Value": "val1", + | "Internal": false | } | ] | } From b591de7c07ba8e71092f71e34001520bec995a8a Mon Sep 17 00:00:00 2001 From: Nick Pritchard Date: Thu, 15 Oct 2015 12:45:37 -0700 Subject: [PATCH 065/139] [SPARK-11039][Documentation][Web UI] Document additional ui configurations Add documentation for configuration: - spark.sql.ui.retainedExecutions - spark.streaming.ui.retainedBatches Author: Nick Pritchard Closes #9052 from pnpritchard/SPARK-11039. --- docs/configuration.md | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/docs/configuration.md b/docs/configuration.md index 771d93be04b06..46d92ceb762d6 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -554,6 +554,20 @@ Apart from these, the following properties are also available, and may be useful How many finished drivers the Spark UI and status APIs remember before garbage collecting. + + + + + + + + + +
Input
spark.sql.ui.retainedExecutions1000 + How many finished executions the Spark UI and status APIs remember before garbage collecting. +
spark.streaming.ui.retainedBatches1000 + How many finished batches the Spark UI and status APIs remember before garbage collecting. +
#### Compression and Serialization From a5719804c5ed99ce36bd0dd230ab8b3b7a3b92e3 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 15 Oct 2015 14:46:40 -0700 Subject: [PATCH 066/139] [SPARK-11071] [LAUNCHER] Fix flakiness in LauncherServerSuite::timeout. The test could fail depending on scheduling of the various threads involved; the change removes some sources of races, while making the test a little more resilient by trying a few times before giving up. Author: Marcelo Vanzin Closes #9079 from vanzin/SPARK-11071. --- .../apache/spark/launcher/LauncherServer.java | 9 ++++- .../spark/launcher/LauncherServerSuite.java | 35 ++++++++++++++----- 2 files changed, 34 insertions(+), 10 deletions(-) diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java index c5fd40816d62f..d099ee9aa9dae 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java @@ -242,7 +242,14 @@ public void run() { synchronized (clients) { clients.add(clientConnection); } - timeoutTimer.schedule(timeout, getConnectionTimeout()); + long timeoutMs = getConnectionTimeout(); + // 0 is used for testing to avoid issues with clock resolution / thread scheduling, + // and force an immediate timeout. + if (timeoutMs > 0) { + timeoutTimer.schedule(timeout, getConnectionTimeout()); + } else { + timeout.run(); + } } } } catch (IOException ioe) { diff --git a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java index 27cd1061a15b3..dc8fbb58d880b 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java @@ -121,12 +121,12 @@ private void wakeUp() { @Test public void testTimeout() throws Exception { - final long TEST_TIMEOUT = 10L; - ChildProcAppHandle handle = null; TestClient client = null; try { - SparkLauncher.setConfig(SparkLauncher.CHILD_CONNECTION_TIMEOUT, String.valueOf(TEST_TIMEOUT)); + // LauncherServer will immediately close the server-side socket when the timeout is set + // to 0. + SparkLauncher.setConfig(SparkLauncher.CHILD_CONNECTION_TIMEOUT, "0"); handle = LauncherServer.newAppHandle(); @@ -134,12 +134,29 @@ public void testTimeout() throws Exception { LauncherServer.getServerInstance().getPort()); client = new TestClient(s); - Thread.sleep(TEST_TIMEOUT * 10); - try { - client.send(new Hello(handle.getSecret(), "1.4.0")); - fail("Expected exception caused by connection timeout."); - } catch (IllegalStateException e) { - // Expected. + // Try a few times since the client-side socket may not reflect the server-side close + // immediately. + boolean helloSent = false; + int maxTries = 10; + for (int i = 0; i < maxTries; i++) { + try { + if (!helloSent) { + client.send(new Hello(handle.getSecret(), "1.4.0")); + helloSent = true; + } else { + client.send(new SetAppId("appId")); + } + fail("Expected exception caused by connection timeout."); + } catch (IllegalStateException | IOException e) { + // Expected. + break; + } catch (AssertionError e) { + if (i < maxTries - 1) { + Thread.sleep(100); + } else { + throw new AssertionError("Test failed after " + maxTries + " attempts.", e); + } + } } } finally { SparkLauncher.launcherConfig.remove(SparkLauncher.CHILD_CONNECTION_TIMEOUT); From 723aa75a9d566c698aa49597f4f655396fef77bd Mon Sep 17 00:00:00 2001 From: Britta Weber Date: Thu, 15 Oct 2015 14:47:11 -0700 Subject: [PATCH 067/139] fix typo bellow -> below Author: Britta Weber Closes #9136 from brwe/typo-bellow. --- docs/mllib-collaborative-filtering.md | 2 +- docs/mllib-linear-methods.md | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index b3fd51dca5c90..1ad52123c74aa 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -119,7 +119,7 @@ All of MLlib's methods use Java-friendly types, so you can import and call them way you do in Scala. The only caveat is that the methods take Scala RDD objects, while the Spark Java API uses a separate `JavaRDD` class. You can convert a Java RDD to a Scala one by calling `.rdd()` on your `JavaRDD` object. A self-contained application example -that is equivalent to the provided example in Scala is given bellow: +that is equivalent to the provided example in Scala is given below: Refer to the [`ALS` Java docs](api/java/org/apache/spark/mllib/recommendation/ALS.html) for details on the API. diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index a3e1620c778ff..0c76e6e999465 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -230,7 +230,7 @@ All of MLlib's methods use Java-friendly types, so you can import and call them way you do in Scala. The only caveat is that the methods take Scala RDD objects, while the Spark Java API uses a separate `JavaRDD` class. You can convert a Java RDD to a Scala one by calling `.rdd()` on your `JavaRDD` object. A self-contained application example -that is equivalent to the provided example in Scala is given bellow: +that is equivalent to the provided example in Scala is given below: Refer to the [`SVMWithSGD` Java docs](api/java/org/apache/spark/mllib/classification/SVMWithSGD.html) and [`SVMModel` Java docs](api/java/org/apache/spark/mllib/classification/SVMModel.html) for details on the API. @@ -612,7 +612,7 @@ All of MLlib's methods use Java-friendly types, so you can import and call them way you do in Scala. The only caveat is that the methods take Scala RDD objects, while the Spark Java API uses a separate `JavaRDD` class. You can convert a Java RDD to a Scala one by calling `.rdd()` on your `JavaRDD` object. The corresponding Java example to -the Scala snippet provided, is presented bellow: +the Scala snippet provided, is presented below: Refer to the [`LinearRegressionWithSGD` Java docs](api/java/org/apache/spark/mllib/regression/LinearRegressionWithSGD.html) and [`LinearRegressionModel` Java docs](api/java/org/apache/spark/mllib/regression/LinearRegressionModel.html) for details on the API. From 2d000124b72d0ff9e3ecefa03923405642516c4c Mon Sep 17 00:00:00 2001 From: KaiXinXiaoLei Date: Thu, 15 Oct 2015 14:48:01 -0700 Subject: [PATCH 068/139] [SPARK-10515] When killing executor, the pending replacement executors should not be lost If the heartbeat receiver kills executors (and new ones are not registered to replace them), the idle timeout for the old executors will be lost (and then change a total number of executors requested by Driver), So new ones will be not to asked to replace them. For example, executorsPendingToRemove=Set(1), and executor 2 is idle timeout before a new executor is asked to replace executor 1. Then driver kill executor 2, and sending RequestExecutors to AM. But executorsPendingToRemove=Set(1,2), So AM doesn't allocate a executor to replace 1. see: https://github.com/apache/spark/pull/8668 Author: KaiXinXiaoLei Author: huleilei Closes #8945 from KaiXinXiaoLei/pendingexecutor. --- .../CoarseGrainedSchedulerBackend.scala | 2 ++ .../StandaloneDynamicAllocationSuite.scala | 35 +++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 18771f79b44bb..55a564b5c8eac 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -438,6 +438,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp if (!replace) { doRequestTotalExecutors( numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size) + } else { + numPendingExecutors += knownExecutors.size } doKillExecutors(executorsToKill) diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala index 2e2fa22eb4772..d145e78834b1b 100644 --- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala @@ -369,6 +369,41 @@ class StandaloneDynamicAllocationSuite assert(apps.head.getExecutorLimit === 1) } + test("the pending replacement executors should not be lost (SPARK-10515)") { + sc = new SparkContext(appConf) + val appId = sc.applicationId + eventually(timeout(10.seconds), interval(10.millis)) { + val apps = getApplications() + assert(apps.size === 1) + assert(apps.head.id === appId) + assert(apps.head.executors.size === 2) + assert(apps.head.getExecutorLimit === Int.MaxValue) + } + // sync executors between the Master and the driver, needed because + // the driver refuses to kill executors it does not know about + syncExecutors(sc) + val executors = getExecutorIds(sc) + assert(executors.size === 2) + // kill executor 1, and replace it + assert(sc.killAndReplaceExecutor(executors.head)) + eventually(timeout(10.seconds), interval(10.millis)) { + val apps = getApplications() + assert(apps.head.executors.size === 2) + } + + var apps = getApplications() + // kill executor 1 + assert(sc.killExecutor(executors.head)) + apps = getApplications() + assert(apps.head.executors.size === 2) + assert(apps.head.getExecutorLimit === 2) + // kill executor 2 + assert(sc.killExecutor(executors(1))) + apps = getApplications() + assert(apps.head.executors.size === 1) + assert(apps.head.getExecutorLimit === 1) + } + // =============================== // | Utility methods for testing | // =============================== From 3b364ff0a4f38c2b8023429a55623de32be5f329 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 15 Oct 2015 14:50:01 -0700 Subject: [PATCH 069/139] [SPARK-11078] Ensure spilling tests actually spill #9084 uncovered that many tests that test spilling don't actually spill. This is a follow-up patch to fix that to ensure our unit tests actually catch potential bugs in spilling. The size of this patch is inflated by the refactoring of `ExternalSorterSuite`, which had a lot of duplicate code and logic. Author: Andrew Or Closes #9124 from andrewor14/spilling-tests. --- .../scala/org/apache/spark/TestUtils.scala | 51 + .../spark/shuffle/ShuffleMemoryManager.scala | 6 +- .../collection/ExternalAppendOnlyMap.scala | 6 + .../spark/util/collection/Spillable.scala | 37 +- .../org/apache/spark/DistributedSuite.scala | 39 +- .../ExternalAppendOnlyMapSuite.scala | 103 ++- .../util/collection/ExternalSorterSuite.scala | 871 ++++++++---------- .../execution/TestShuffleMemoryManager.scala | 2 + 8 files changed, 534 insertions(+), 581 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index 888763a3e8ebf..acfe751f6c746 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -24,10 +24,14 @@ import java.util.Arrays import java.util.jar.{JarEntry, JarOutputStream} import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import com.google.common.io.{ByteStreams, Files} import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider} +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.scheduler._ import org.apache.spark.util.Utils /** @@ -154,4 +158,51 @@ private[spark] object TestUtils { " @Override public String toString() { return \"" + toStringValue + "\"; }}") createCompiledClass(className, destDir, sourceFile, classpathUrls) } + + /** + * Run some code involving jobs submitted to the given context and assert that the jobs spilled. + */ + def assertSpilled[T](sc: SparkContext, identifier: String)(body: => T): Unit = { + val spillListener = new SpillListener + sc.addSparkListener(spillListener) + body + assert(spillListener.numSpilledStages > 0, s"expected $identifier to spill, but did not") + } + + /** + * Run some code involving jobs submitted to the given context and assert that the jobs + * did not spill. + */ + def assertNotSpilled[T](sc: SparkContext, identifier: String)(body: => T): Unit = { + val spillListener = new SpillListener + sc.addSparkListener(spillListener) + body + assert(spillListener.numSpilledStages == 0, s"expected $identifier to not spill, but did") + } + +} + + +/** + * A [[SparkListener]] that detects whether spills have occurred in Spark jobs. + */ +private class SpillListener extends SparkListener { + private val stageIdToTaskMetrics = new mutable.HashMap[Int, ArrayBuffer[TaskMetrics]] + private val spilledStageIds = new mutable.HashSet[Int] + + def numSpilledStages: Int = spilledStageIds.size + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + stageIdToTaskMetrics.getOrElseUpdate( + taskEnd.stageId, new ArrayBuffer[TaskMetrics]) += taskEnd.taskMetrics + } + + override def onStageCompleted(stageComplete: SparkListenerStageCompleted): Unit = { + val stageId = stageComplete.stageInfo.stageId + val metrics = stageIdToTaskMetrics.remove(stageId).toSeq.flatten + val spilled = metrics.map(_.memoryBytesSpilled).sum > 0 + if (spilled) { + spilledStageIds += stageId + } + } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala index aaf543ce9232a..9bd18da47f1a2 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala @@ -139,8 +139,10 @@ class ShuffleMemoryManager protected ( throw new SparkException( s"Internal error: release called on $numBytes bytes but task only has $curMem") } - taskMemory(taskAttemptId) -= numBytes - memoryManager.releaseExecutionMemory(numBytes) + if (taskMemory.contains(taskAttemptId)) { + taskMemory(taskAttemptId) -= numBytes + memoryManager.releaseExecutionMemory(numBytes) + } memoryManager.notifyAll() // Notify waiters in tryToAcquire that memory has been freed } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 6a96b5dc12684..cfa58f5ef408a 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -95,6 +95,12 @@ class ExternalAppendOnlyMap[K, V, C]( private val keyComparator = new HashComparator[K] private val ser = serializer.newInstance() + /** + * Number of files this map has spilled so far. + * Exposed for testing. + */ + private[collection] def numSpills: Int = spilledMaps.size + /** * Insert the given key and value into the map. */ diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala index 747ecf075a397..d2a68ca7a3b4c 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala @@ -43,10 +43,15 @@ private[spark] trait Spillable[C] extends Logging { private[this] val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager // Initial threshold for the size of a collection before we start tracking its memory usage - // Exposed for testing + // For testing only private[this] val initialMemoryThreshold: Long = SparkEnv.get.conf.getLong("spark.shuffle.spill.initialMemoryThreshold", 5 * 1024 * 1024) + // Force this collection to spill when there are this many elements in memory + // For testing only + private[this] val numElementsForceSpillThreshold: Long = + SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MaxValue) + // Threshold for this collection's size in bytes before we start tracking its memory usage // To avoid a large number of small spills, initialize this to a value orders of magnitude > 0 private[this] var myMemoryThreshold = initialMemoryThreshold @@ -69,27 +74,27 @@ private[spark] trait Spillable[C] extends Logging { * @return true if `collection` was spilled to disk; false otherwise */ protected def maybeSpill(collection: C, currentMemory: Long): Boolean = { + var shouldSpill = false if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) { // Claim up to double our current memory from the shuffle memory pool val amountToRequest = 2 * currentMemory - myMemoryThreshold val granted = shuffleMemoryManager.tryToAcquire(amountToRequest) myMemoryThreshold += granted - if (myMemoryThreshold <= currentMemory) { - // We were granted too little memory to grow further (either tryToAcquire returned 0, - // or we already had more memory than myMemoryThreshold); spill the current collection - _spillCount += 1 - logSpillage(currentMemory) - - spill(collection) - - _elementsRead = 0 - // Keep track of spills, and release memory - _memoryBytesSpilled += currentMemory - releaseMemoryForThisThread() - return true - } + // If we were granted too little memory to grow further (either tryToAcquire returned 0, + // or we already had more memory than myMemoryThreshold), spill the current collection + shouldSpill = currentMemory >= myMemoryThreshold + } + shouldSpill = shouldSpill || _elementsRead > numElementsForceSpillThreshold + // Actually spill + if (shouldSpill) { + _spillCount += 1 + logSpillage(currentMemory) + spill(collection) + _elementsRead = 0 + _memoryBytesSpilled += currentMemory + releaseMemoryForThisThread() } - false + shouldSpill } /** diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 34a4bb968e732..1c3f2bc315ddc 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -203,22 +203,35 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex } test("compute without caching when no partitions fit in memory") { - sc = new SparkContext(clusterUrl, "test") - // data will be 4 million * 4 bytes = 16 MB in size, but our memoryFraction set the cache - // to only 50 KB (0.0001 of 512 MB), so no partitions should fit in memory - val data = sc.parallelize(1 to 4000000, 2).persist(StorageLevel.MEMORY_ONLY_SER) - assert(data.count() === 4000000) - assert(data.count() === 4000000) - assert(data.count() === 4000000) + val size = 10000 + val conf = new SparkConf() + .set("spark.storage.unrollMemoryThreshold", "1024") + .set("spark.testing.memory", (size / 2).toString) + sc = new SparkContext(clusterUrl, "test", conf) + val data = sc.parallelize(1 to size, 2).persist(StorageLevel.MEMORY_ONLY) + assert(data.count() === size) + assert(data.count() === size) + assert(data.count() === size) + // ensure only a subset of partitions were cached + val rddBlocks = sc.env.blockManager.master.getMatchingBlockIds(_.isRDD, askSlaves = true) + assert(rddBlocks.size === 0, s"expected no RDD blocks, found ${rddBlocks.size}") } test("compute when only some partitions fit in memory") { - sc = new SparkContext(clusterUrl, "test", new SparkConf) - // TODO: verify that only a subset of partitions fit in memory (SPARK-11078) - val data = sc.parallelize(1 to 4000000, 20).persist(StorageLevel.MEMORY_ONLY_SER) - assert(data.count() === 4000000) - assert(data.count() === 4000000) - assert(data.count() === 4000000) + val size = 10000 + val numPartitions = 10 + val conf = new SparkConf() + .set("spark.storage.unrollMemoryThreshold", "1024") + .set("spark.testing.memory", (size * numPartitions).toString) + sc = new SparkContext(clusterUrl, "test", conf) + val data = sc.parallelize(1 to size, numPartitions).persist(StorageLevel.MEMORY_ONLY) + assert(data.count() === size) + assert(data.count() === size) + assert(data.count() === size) + // ensure only a subset of partitions were cached + val rddBlocks = sc.env.blockManager.master.getMatchingBlockIds(_.isRDD, askSlaves = true) + assert(rddBlocks.size > 0, "no RDD blocks found") + assert(rddBlocks.size < numPartitions, s"too many RDD blocks found, expected <$numPartitions") } test("passing environment variables to cluster") { diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index 0a03c32c647ae..5cb506ea2164e 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -22,9 +22,10 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark._ import org.apache.spark.io.CompressionCodec -// TODO: some of these spilling tests probably aren't actually spilling (SPARK-11078) class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { + import TestUtils.{assertNotSpilled, assertSpilled} + private val allCompressionCodecs = CompressionCodec.ALL_COMPRESSION_CODECS private def createCombiner[T](i: T) = ArrayBuffer[T](i) private def mergeValue[T](buffer: ArrayBuffer[T], i: T): ArrayBuffer[T] = buffer += i @@ -244,54 +245,53 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { * If a compression codec is provided, use it. Otherwise, do not compress spills. */ private def testSimpleSpilling(codec: Option[String] = None): Unit = { + val size = 1000 val conf = createSparkConf(loadDefaults = true, codec) // Load defaults for Spark home + conf.set("spark.shuffle.manager", "hash") // avoid using external sorter + conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 4).toString) sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) - // reduceByKey - should spill ~8 times - val rddA = sc.parallelize(0 until 100000).map(i => (i/2, i)) - val resultA = rddA.reduceByKey(math.max).collect() - assert(resultA.length === 50000) - resultA.foreach { case (k, v) => - assert(v === k * 2 + 1, s"Value for $k was wrong: expected ${k * 2 + 1}, got $v") + assertSpilled(sc, "reduceByKey") { + val result = sc.parallelize(0 until size) + .map { i => (i / 2, i) }.reduceByKey(math.max).collect() + assert(result.length === size / 2) + result.foreach { case (k, v) => + val expected = k * 2 + 1 + assert(v === expected, s"Value for $k was wrong: expected $expected, got $v") + } } - // groupByKey - should spill ~17 times - val rddB = sc.parallelize(0 until 100000).map(i => (i/4, i)) - val resultB = rddB.groupByKey().collect() - assert(resultB.length === 25000) - resultB.foreach { case (i, seq) => - val expected = Set(i * 4, i * 4 + 1, i * 4 + 2, i * 4 + 3) - assert(seq.toSet === expected, - s"Value for $i was wrong: expected $expected, got ${seq.toSet}") + assertSpilled(sc, "groupByKey") { + val result = sc.parallelize(0 until size).map { i => (i / 2, i) }.groupByKey().collect() + assert(result.length == size / 2) + result.foreach { case (i, seq) => + val actual = seq.toSet + val expected = Set(i * 2, i * 2 + 1) + assert(actual === expected, s"Value for $i was wrong: expected $expected, got $actual") + } } - // cogroup - should spill ~7 times - val rddC1 = sc.parallelize(0 until 10000).map(i => (i, i)) - val rddC2 = sc.parallelize(0 until 10000).map(i => (i%1000, i)) - val resultC = rddC1.cogroup(rddC2).collect() - assert(resultC.length === 10000) - resultC.foreach { case (i, (seq1, seq2)) => - i match { - case 0 => - assert(seq1.toSet === Set[Int](0)) - assert(seq2.toSet === Set[Int](0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000)) - case 1 => - assert(seq1.toSet === Set[Int](1)) - assert(seq2.toSet === Set[Int](1, 1001, 2001, 3001, 4001, 5001, 6001, 7001, 8001, 9001)) - case 5000 => - assert(seq1.toSet === Set[Int](5000)) - assert(seq2.toSet === Set[Int]()) - case 9999 => - assert(seq1.toSet === Set[Int](9999)) - assert(seq2.toSet === Set[Int]()) - case _ => + assertSpilled(sc, "cogroup") { + val rdd1 = sc.parallelize(0 until size).map { i => (i / 2, i) } + val rdd2 = sc.parallelize(0 until size).map { i => (i / 2, i) } + val result = rdd1.cogroup(rdd2).collect() + assert(result.length === size / 2) + result.foreach { case (i, (seq1, seq2)) => + val actual1 = seq1.toSet + val actual2 = seq2.toSet + val expected = Set(i * 2, i * 2 + 1) + assert(actual1 === expected, s"Value 1 for $i was wrong: expected $expected, got $actual1") + assert(actual2 === expected, s"Value 2 for $i was wrong: expected $expected, got $actual2") } } + sc.stop() } test("spilling with hash collisions") { + val size = 1000 val conf = createSparkConf(loadDefaults = true) + conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString) sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) val map = createExternalMap[String] @@ -315,11 +315,12 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { assert(w1.hashCode === w2.hashCode) } - map.insertAll((1 to 100000).iterator.map(_.toString).map(i => (i, i))) + map.insertAll((1 to size).iterator.map(_.toString).map(i => (i, i))) collisionPairs.foreach { case (w1, w2) => map.insert(w1, w2) map.insert(w2, w1) } + assert(map.numSpills > 0, "map did not spill") // A map of collision pairs in both directions val collisionPairsMap = (collisionPairs ++ collisionPairs.map(_.swap)).toMap @@ -334,22 +335,25 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { assert(kv._2.equals(expectedValue)) count += 1 } - assert(count === 100000 + collisionPairs.size * 2) + assert(count === size + collisionPairs.size * 2) sc.stop() } test("spilling with many hash collisions") { + val size = 1000 val conf = createSparkConf(loadDefaults = true) + conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString) sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) val map = new ExternalAppendOnlyMap[FixedHashObject, Int, Int](_ => 1, _ + _, _ + _) // Insert 10 copies each of lots of objects whose hash codes are either 0 or 1. This causes // problems if the map fails to group together the objects with the same code (SPARK-2043). for (i <- 1 to 10) { - for (j <- 1 to 10000) { + for (j <- 1 to size) { map.insert(FixedHashObject(j, j % 2), 1) } } + assert(map.numSpills > 0, "map did not spill") val it = map.iterator var count = 0 @@ -358,17 +362,20 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { assert(kv._2 === 10) count += 1 } - assert(count === 10000) + assert(count === size) sc.stop() } test("spilling with hash collisions using the Int.MaxValue key") { + val size = 1000 val conf = createSparkConf(loadDefaults = true) + conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString) sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) val map = createExternalMap[Int] - (1 to 100000).foreach { i => map.insert(i, i) } + (1 to size).foreach { i => map.insert(i, i) } map.insert(Int.MaxValue, Int.MaxValue) + assert(map.numSpills > 0, "map did not spill") val it = map.iterator while (it.hasNext) { @@ -379,14 +386,17 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { } test("spilling with null keys and values") { + val size = 1000 val conf = createSparkConf(loadDefaults = true) + conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString) sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) val map = createExternalMap[Int] - map.insertAll((1 to 100000).iterator.map(i => (i, i))) + map.insertAll((1 to size).iterator.map(i => (i, i))) map.insert(null.asInstanceOf[Int], 1) map.insert(1, null.asInstanceOf[Int]) map.insert(null.asInstanceOf[Int], null.asInstanceOf[Int]) + assert(map.numSpills > 0, "map did not spill") val it = map.iterator while (it.hasNext) { @@ -397,17 +407,22 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { } test("external aggregation updates peak execution memory") { + val spillThreshold = 1000 val conf = createSparkConf(loadDefaults = false) .set("spark.shuffle.manager", "hash") // make sure we're not also using ExternalSorter - .set("spark.testing.memory", (10 * 1024 * 1024).toString) + .set("spark.shuffle.spill.numElementsForceSpillThreshold", spillThreshold.toString) sc = new SparkContext("local", "test", conf) // No spilling AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external map without spilling") { - sc.parallelize(1 to 10, 2).map { i => (i, i) }.reduceByKey(_ + _).count() + assertNotSpilled(sc, "verify peak memory") { + sc.parallelize(1 to spillThreshold / 2, 2).map { i => (i, i) }.reduceByKey(_ + _).count() + } } // With spilling AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external map with spilling") { - sc.parallelize(1 to 1000 * 1000, 2).map { i => (i, i) }.reduceByKey(_ + _).count() + assertSpilled(sc, "verify peak memory") { + sc.parallelize(1 to spillThreshold * 3, 2).map { i => (i, i) }.reduceByKey(_ + _).count() + } } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index 651c7eaa65ff5..e2cb791771d99 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -18,535 +18,92 @@ package org.apache.spark.util.collection import scala.collection.mutable.ArrayBuffer - import scala.util.Random import org.apache.spark._ import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} -// TODO: some of these spilling tests probably aren't actually spilling (SPARK-11078) class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { - private def createSparkConf(loadDefaults: Boolean, kryo: Boolean): SparkConf = { - val conf = new SparkConf(loadDefaults) - if (kryo) { - conf.set("spark.serializer", classOf[KryoSerializer].getName) - } else { - // Make the Java serializer write a reset instruction (TC_RESET) after each object to test - // for a bug we had with bytes written past the last object in a batch (SPARK-2792) - conf.set("spark.serializer.objectStreamReset", "1") - conf.set("spark.serializer", classOf[JavaSerializer].getName) - } - conf.set("spark.shuffle.sort.bypassMergeThreshold", "0") - // Ensure that we actually have multiple batches per spill file - conf.set("spark.shuffle.spill.batchSize", "10") - conf.set("spark.testing.memory", "2000000") - conf - } - - test("empty data stream with kryo ser") { - emptyDataStream(createSparkConf(false, true)) - } - - test("empty data stream with java ser") { - emptyDataStream(createSparkConf(false, false)) - } - - def emptyDataStream(conf: SparkConf) { - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local", "test", conf) - - val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) - val ord = implicitly[Ordering[Int]] - - // Both aggregator and ordering - val sorter = new ExternalSorter[Int, Int, Int]( - Some(agg), Some(new HashPartitioner(3)), Some(ord), None) - assert(sorter.iterator.toSeq === Seq()) - sorter.stop() - - // Only aggregator - val sorter2 = new ExternalSorter[Int, Int, Int]( - Some(agg), Some(new HashPartitioner(3)), None, None) - assert(sorter2.iterator.toSeq === Seq()) - sorter2.stop() - - // Only ordering - val sorter3 = new ExternalSorter[Int, Int, Int]( - None, Some(new HashPartitioner(3)), Some(ord), None) - assert(sorter3.iterator.toSeq === Seq()) - sorter3.stop() - - // Neither aggregator nor ordering - val sorter4 = new ExternalSorter[Int, Int, Int]( - None, Some(new HashPartitioner(3)), None, None) - assert(sorter4.iterator.toSeq === Seq()) - sorter4.stop() - } + import TestUtils.{assertNotSpilled, assertSpilled} - test("few elements per partition with kryo ser") { - fewElementsPerPartition(createSparkConf(false, true)) - } + testWithMultipleSer("empty data stream")(emptyDataStream) - test("few elements per partition with java ser") { - fewElementsPerPartition(createSparkConf(false, false)) - } + testWithMultipleSer("few elements per partition")(fewElementsPerPartition) - def fewElementsPerPartition(conf: SparkConf) { - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local", "test", conf) - - val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) - val ord = implicitly[Ordering[Int]] - val elements = Set((1, 1), (2, 2), (5, 5)) - val expected = Set( - (0, Set()), (1, Set((1, 1))), (2, Set((2, 2))), (3, Set()), (4, Set()), - (5, Set((5, 5))), (6, Set())) - - // Both aggregator and ordering - val sorter = new ExternalSorter[Int, Int, Int]( - Some(agg), Some(new HashPartitioner(7)), Some(ord), None) - sorter.insertAll(elements.iterator) - assert(sorter.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) - sorter.stop() - - // Only aggregator - val sorter2 = new ExternalSorter[Int, Int, Int]( - Some(agg), Some(new HashPartitioner(7)), None, None) - sorter2.insertAll(elements.iterator) - assert(sorter2.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) - sorter2.stop() + testWithMultipleSer("empty partitions with spilling")(emptyPartitionsWithSpilling) - // Only ordering - val sorter3 = new ExternalSorter[Int, Int, Int]( - None, Some(new HashPartitioner(7)), Some(ord), None) - sorter3.insertAll(elements.iterator) - assert(sorter3.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) - sorter3.stop() - - // Neither aggregator nor ordering - val sorter4 = new ExternalSorter[Int, Int, Int]( - None, Some(new HashPartitioner(7)), None, None) - sorter4.insertAll(elements.iterator) - assert(sorter4.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) - sorter4.stop() - } - - test("empty partitions with spilling with kryo ser") { - emptyPartitionsWithSpilling(createSparkConf(false, true)) + // Load defaults, otherwise SPARK_HOME is not found + testWithMultipleSer("spilling in local cluster", loadDefaults = true) { + (conf: SparkConf) => testSpillingInLocalCluster(conf, 2) } - test("empty partitions with spilling with java ser") { - emptyPartitionsWithSpilling(createSparkConf(false, false)) - } - - def emptyPartitionsWithSpilling(conf: SparkConf) { - conf.set("spark.shuffle.spill.initialMemoryThreshold", "512") - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local", "test", conf) - - val ord = implicitly[Ordering[Int]] - val elements = Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2)) - - val sorter = new ExternalSorter[Int, Int, Int]( - None, Some(new HashPartitioner(7)), Some(ord), None) - sorter.insertAll(elements) - assert(sc.env.blockManager.diskBlockManager.getAllFiles().length > 0) // Make sure it spilled - val iter = sorter.partitionedIterator.map(p => (p._1, p._2.toList)) - assert(iter.next() === (0, Nil)) - assert(iter.next() === (1, List((1, 1)))) - assert(iter.next() === (2, (0 until 100000).map(x => (2, 2)).toList)) - assert(iter.next() === (3, Nil)) - assert(iter.next() === (4, Nil)) - assert(iter.next() === (5, List((5, 5)))) - assert(iter.next() === (6, Nil)) - sorter.stop() - } - - test("spilling in local cluster with kryo ser") { - // Load defaults, otherwise SPARK_HOME is not found - testSpillingInLocalCluster(createSparkConf(true, true)) - } - - test("spilling in local cluster with java ser") { - // Load defaults, otherwise SPARK_HOME is not found - testSpillingInLocalCluster(createSparkConf(true, false)) - } - - def testSpillingInLocalCluster(conf: SparkConf) { - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) - - // reduceByKey - should spill ~8 times - val rddA = sc.parallelize(0 until 100000).map(i => (i/2, i)) - val resultA = rddA.reduceByKey(math.max).collect() - assert(resultA.length == 50000) - resultA.foreach { case(k, v) => - if (v != k * 2 + 1) { - fail(s"Value for ${k} was wrong: expected ${k * 2 + 1}, got ${v}") - } - } - - // groupByKey - should spill ~17 times - val rddB = sc.parallelize(0 until 100000).map(i => (i/4, i)) - val resultB = rddB.groupByKey().collect() - assert(resultB.length == 25000) - resultB.foreach { case(i, seq) => - val expected = Set(i * 4, i * 4 + 1, i * 4 + 2, i * 4 + 3) - if (seq.toSet != expected) { - fail(s"Value for ${i} was wrong: expected ${expected}, got ${seq.toSet}") - } - } - - // cogroup - should spill ~7 times - val rddC1 = sc.parallelize(0 until 10000).map(i => (i, i)) - val rddC2 = sc.parallelize(0 until 10000).map(i => (i%1000, i)) - val resultC = rddC1.cogroup(rddC2).collect() - assert(resultC.length == 10000) - resultC.foreach { case(i, (seq1, seq2)) => - i match { - case 0 => - assert(seq1.toSet == Set[Int](0)) - assert(seq2.toSet == Set[Int](0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000)) - case 1 => - assert(seq1.toSet == Set[Int](1)) - assert(seq2.toSet == Set[Int](1, 1001, 2001, 3001, 4001, 5001, 6001, 7001, 8001, 9001)) - case 5000 => - assert(seq1.toSet == Set[Int](5000)) - assert(seq2.toSet == Set[Int]()) - case 9999 => - assert(seq1.toSet == Set[Int](9999)) - assert(seq2.toSet == Set[Int]()) - case _ => - } - } - - // larger cogroup - should spill ~7 times - val rddD1 = sc.parallelize(0 until 10000).map(i => (i/2, i)) - val rddD2 = sc.parallelize(0 until 10000).map(i => (i/2, i)) - val resultD = rddD1.cogroup(rddD2).collect() - assert(resultD.length == 5000) - resultD.foreach { case(i, (seq1, seq2)) => - val expected = Set(i * 2, i * 2 + 1) - if (seq1.toSet != expected) { - fail(s"Value 1 for ${i} was wrong: expected ${expected}, got ${seq1.toSet}") - } - if (seq2.toSet != expected) { - fail(s"Value 2 for ${i} was wrong: expected ${expected}, got ${seq2.toSet}") - } - } - - // sortByKey - should spill ~17 times - val rddE = sc.parallelize(0 until 100000).map(i => (i/4, i)) - val resultE = rddE.sortByKey().collect().toSeq - assert(resultE === (0 until 100000).map(i => (i/4, i)).toSeq) - } - - test("spilling in local cluster with many reduce tasks with kryo ser") { - spillingInLocalClusterWithManyReduceTasks(createSparkConf(true, true)) - } - - test("spilling in local cluster with many reduce tasks with java ser") { - spillingInLocalClusterWithManyReduceTasks(createSparkConf(true, false)) - } - - def spillingInLocalClusterWithManyReduceTasks(conf: SparkConf) { - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) - - // reduceByKey - should spill ~4 times per executor - val rddA = sc.parallelize(0 until 100000).map(i => (i/2, i)) - val resultA = rddA.reduceByKey(math.max _, 100).collect() - assert(resultA.length == 50000) - resultA.foreach { case(k, v) => - if (v != k * 2 + 1) { - fail(s"Value for ${k} was wrong: expected ${k * 2 + 1}, got ${v}") - } - } - - // groupByKey - should spill ~8 times per executor - val rddB = sc.parallelize(0 until 100000).map(i => (i/4, i)) - val resultB = rddB.groupByKey(100).collect() - assert(resultB.length == 25000) - resultB.foreach { case(i, seq) => - val expected = Set(i * 4, i * 4 + 1, i * 4 + 2, i * 4 + 3) - if (seq.toSet != expected) { - fail(s"Value for ${i} was wrong: expected ${expected}, got ${seq.toSet}") - } - } - - // cogroup - should spill ~4 times per executor - val rddC1 = sc.parallelize(0 until 10000).map(i => (i, i)) - val rddC2 = sc.parallelize(0 until 10000).map(i => (i%1000, i)) - val resultC = rddC1.cogroup(rddC2, 100).collect() - assert(resultC.length == 10000) - resultC.foreach { case(i, (seq1, seq2)) => - i match { - case 0 => - assert(seq1.toSet == Set[Int](0)) - assert(seq2.toSet == Set[Int](0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000)) - case 1 => - assert(seq1.toSet == Set[Int](1)) - assert(seq2.toSet == Set[Int](1, 1001, 2001, 3001, 4001, 5001, 6001, 7001, 8001, 9001)) - case 5000 => - assert(seq1.toSet == Set[Int](5000)) - assert(seq2.toSet == Set[Int]()) - case 9999 => - assert(seq1.toSet == Set[Int](9999)) - assert(seq2.toSet == Set[Int]()) - case _ => - } - } - - // larger cogroup - should spill ~4 times per executor - val rddD1 = sc.parallelize(0 until 10000).map(i => (i/2, i)) - val rddD2 = sc.parallelize(0 until 10000).map(i => (i/2, i)) - val resultD = rddD1.cogroup(rddD2).collect() - assert(resultD.length == 5000) - resultD.foreach { case(i, (seq1, seq2)) => - val expected = Set(i * 2, i * 2 + 1) - if (seq1.toSet != expected) { - fail(s"Value 1 for ${i} was wrong: expected ${expected}, got ${seq1.toSet}") - } - if (seq2.toSet != expected) { - fail(s"Value 2 for ${i} was wrong: expected ${expected}, got ${seq2.toSet}") - } - } - - // sortByKey - should spill ~8 times per executor - val rddE = sc.parallelize(0 until 100000).map(i => (i/4, i)) - val resultE = rddE.sortByKey().collect().toSeq - assert(resultE === (0 until 100000).map(i => (i/4, i)).toSeq) + testWithMultipleSer("spilling in local cluster with many reduce tasks", loadDefaults = true) { + (conf: SparkConf) => testSpillingInLocalCluster(conf, 100) } test("cleanup of intermediate files in sorter") { - val conf = createSparkConf(true, false) // Load defaults, otherwise SPARK_HOME is not found - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local", "test", conf) - val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager - - val ord = implicitly[Ordering[Int]] - - val sorter = new ExternalSorter[Int, Int, Int]( - None, Some(new HashPartitioner(3)), Some(ord), None) - sorter.insertAll((0 until 120000).iterator.map(i => (i, i))) - assert(diskBlockManager.getAllFiles().length > 0) - sorter.stop() - assert(diskBlockManager.getAllBlocks().length === 0) - - val sorter2 = new ExternalSorter[Int, Int, Int]( - None, Some(new HashPartitioner(3)), Some(ord), None) - sorter2.insertAll((0 until 120000).iterator.map(i => (i, i))) - assert(diskBlockManager.getAllFiles().length > 0) - assert(sorter2.iterator.toSet === (0 until 120000).map(i => (i, i)).toSet) - sorter2.stop() - assert(diskBlockManager.getAllBlocks().length === 0) + cleanupIntermediateFilesInSorter(withFailures = false) } - test("cleanup of intermediate files in sorter if there are errors") { - val conf = createSparkConf(true, false) // Load defaults, otherwise SPARK_HOME is not found - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local", "test", conf) - val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager - - val ord = implicitly[Ordering[Int]] - - val sorter = new ExternalSorter[Int, Int, Int]( - None, Some(new HashPartitioner(3)), Some(ord), None) - intercept[SparkException] { - sorter.insertAll((0 until 120000).iterator.map(i => { - if (i == 119990) { - throw new SparkException("Intentional failure") - } - (i, i) - })) - } - assert(diskBlockManager.getAllFiles().length > 0) - sorter.stop() - assert(diskBlockManager.getAllBlocks().length === 0) + test("cleanup of intermediate files in sorter with failures") { + cleanupIntermediateFilesInSorter(withFailures = true) } test("cleanup of intermediate files in shuffle") { - val conf = createSparkConf(false, false) - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local", "test", conf) - val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager - - val data = sc.parallelize(0 until 100000, 2).map(i => (i, i)) - assert(data.reduceByKey(_ + _).count() === 100000) - - // After the shuffle, there should be only 4 files on disk: our two map output files and - // their index files. All other intermediate files should've been deleted. - assert(diskBlockManager.getAllFiles().length === 4) - } - - test("cleanup of intermediate files in shuffle with errors") { - val conf = createSparkConf(false, false) - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local", "test", conf) - val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager - - val data = sc.parallelize(0 until 100000, 2).map(i => { - if (i == 99990) { - throw new Exception("Intentional failure") - } - (i, i) - }) - intercept[SparkException] { - data.reduceByKey(_ + _).count() - } - - // After the shuffle, there should be only 2 files on disk: the output of task 1 and its index. - // All other files (map 2's output and intermediate merge files) should've been deleted. - assert(diskBlockManager.getAllFiles().length === 2) - } - - test("no partial aggregation or sorting with kryo ser") { - noPartialAggregationOrSorting(createSparkConf(false, true)) - } - - test("no partial aggregation or sorting with java ser") { - noPartialAggregationOrSorting(createSparkConf(false, false)) - } - - def noPartialAggregationOrSorting(conf: SparkConf) { - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local", "test", conf) - - val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None) - sorter.insertAll((0 until 100000).iterator.map(i => (i / 4, i))) - val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet - val expected = (0 until 3).map(p => { - (p, (0 until 100000).map(i => (i / 4, i)).filter(_._1 % 3 == p).toSet) - }).toSet - assert(results === expected) - } - - test("partial aggregation without spill with kryo ser") { - partialAggregationWithoutSpill(createSparkConf(false, true)) - } - - test("partial aggregation without spill with java ser") { - partialAggregationWithoutSpill(createSparkConf(false, false)) - } - - def partialAggregationWithoutSpill(conf: SparkConf) { - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local", "test", conf) - - val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) - val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), None, None) - sorter.insertAll((0 until 100).iterator.map(i => (i / 2, i))) - val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet - val expected = (0 until 3).map(p => { - (p, (0 until 50).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet) - }).toSet - assert(results === expected) + cleanupIntermediateFilesInShuffle(withFailures = false) } - test("partial aggregation with spill, no ordering with kryo ser") { - partialAggregationWIthSpillNoOrdering(createSparkConf(false, true)) + test("cleanup of intermediate files in shuffle with failures") { + cleanupIntermediateFilesInShuffle(withFailures = true) } - test("partial aggregation with spill, no ordering with java ser") { - partialAggregationWIthSpillNoOrdering(createSparkConf(false, false)) + testWithMultipleSer("no sorting or partial aggregation") { (conf: SparkConf) => + basicSorterTest(conf, withPartialAgg = false, withOrdering = false, withSpilling = false) } - def partialAggregationWIthSpillNoOrdering(conf: SparkConf) { - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local", "test", conf) - - val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) - val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), None, None) - sorter.insertAll((0 until 100000).iterator.map(i => (i / 2, i))) - val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet - val expected = (0 until 3).map(p => { - (p, (0 until 50000).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet) - }).toSet - assert(results === expected) + testWithMultipleSer("no sorting or partial aggregation with spilling") { (conf: SparkConf) => + basicSorterTest(conf, withPartialAgg = false, withOrdering = false, withSpilling = true) } - test("partial aggregation with spill, with ordering with kryo ser") { - partialAggregationWithSpillWithOrdering(createSparkConf(false, true)) + testWithMultipleSer("sorting, no partial aggregation") { (conf: SparkConf) => + basicSorterTest(conf, withPartialAgg = false, withOrdering = true, withSpilling = false) } - - test("partial aggregation with spill, with ordering with java ser") { - partialAggregationWithSpillWithOrdering(createSparkConf(false, false)) + testWithMultipleSer("sorting, no partial aggregation with spilling") { (conf: SparkConf) => + basicSorterTest(conf, withPartialAgg = false, withOrdering = true, withSpilling = true) } - def partialAggregationWithSpillWithOrdering(conf: SparkConf) { - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local", "test", conf) - - val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) - val ord = implicitly[Ordering[Int]] - val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), Some(ord), None) - - // avoid combine before spill - sorter.insertAll((0 until 50000).iterator.map(i => (i , 2 * i))) - sorter.insertAll((0 until 50000).iterator.map(i => (i, 2 * i + 1))) - val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet - val expected = (0 until 3).map(p => { - (p, (0 until 50000).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet) - }).toSet - assert(results === expected) + testWithMultipleSer("partial aggregation, no sorting") { (conf: SparkConf) => + basicSorterTest(conf, withPartialAgg = true, withOrdering = false, withSpilling = false) } - test("sorting without aggregation, no spill with kryo ser") { - sortingWithoutAggregationNoSpill(createSparkConf(false, true)) + testWithMultipleSer("partial aggregation, no sorting with spilling") { (conf: SparkConf) => + basicSorterTest(conf, withPartialAgg = true, withOrdering = false, withSpilling = true) } - test("sorting without aggregation, no spill with java ser") { - sortingWithoutAggregationNoSpill(createSparkConf(false, false)) + testWithMultipleSer("partial aggregation and sorting") { (conf: SparkConf) => + basicSorterTest(conf, withPartialAgg = true, withOrdering = true, withSpilling = false) } - def sortingWithoutAggregationNoSpill(conf: SparkConf) { - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local", "test", conf) - - val ord = implicitly[Ordering[Int]] - val sorter = new ExternalSorter[Int, Int, Int]( - None, Some(new HashPartitioner(3)), Some(ord), None) - sorter.insertAll((0 until 100).iterator.map(i => (i, i))) - val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSeq)}.toSeq - val expected = (0 until 3).map(p => { - (p, (0 until 100).map(i => (i, i)).filter(_._1 % 3 == p).toSeq) - }).toSeq - assert(results === expected) - } - - test("sorting without aggregation, with spill with kryo ser") { - sortingWithoutAggregationWithSpill(createSparkConf(false, true)) - } - - test("sorting without aggregation, with spill with java ser") { - sortingWithoutAggregationWithSpill(createSparkConf(false, false)) + testWithMultipleSer("partial aggregation and sorting with spilling") { (conf: SparkConf) => + basicSorterTest(conf, withPartialAgg = true, withOrdering = true, withSpilling = true) } - def sortingWithoutAggregationWithSpill(conf: SparkConf) { - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local", "test", conf) - - val ord = implicitly[Ordering[Int]] - val sorter = new ExternalSorter[Int, Int, Int]( - None, Some(new HashPartitioner(3)), Some(ord), None) - sorter.insertAll((0 until 100000).iterator.map(i => (i, i))) - val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSeq)}.toSeq - val expected = (0 until 3).map(p => { - (p, (0 until 100000).map(i => (i, i)).filter(_._1 % 3 == p).toSeq) - }).toSeq - assert(results === expected) - } + testWithMultipleSer("sort without breaking sorting contracts", loadDefaults = true)( + sortWithoutBreakingSortingContracts) test("spilling with hash collisions") { - val conf = createSparkConf(true, false) + val size = 1000 + val conf = createSparkConf(loadDefaults = true, kryo = false) + conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString) sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) def createCombiner(i: String): ArrayBuffer[String] = ArrayBuffer[String](i) def mergeValue(buffer: ArrayBuffer[String], i: String): ArrayBuffer[String] = buffer += i - def mergeCombiners(buffer1: ArrayBuffer[String], buffer2: ArrayBuffer[String]) - : ArrayBuffer[String] = buffer1 ++= buffer2 + def mergeCombiners( + buffer1: ArrayBuffer[String], + buffer2: ArrayBuffer[String]): ArrayBuffer[String] = buffer1 ++= buffer2 val agg = new Aggregator[String, String, ArrayBuffer[String]]( createCombiner _, mergeValue _, mergeCombiners _) @@ -574,10 +131,11 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { assert(w1.hashCode === w2.hashCode) } - val toInsert = (1 to 100000).iterator.map(_.toString).map(s => (s, s)) ++ + val toInsert = (1 to size).iterator.map(_.toString).map(s => (s, s)) ++ collisionPairs.iterator ++ collisionPairs.iterator.map(_.swap) sorter.insertAll(toInsert) + assert(sorter.numSpills > 0, "sorter did not spill") // A map of collision pairs in both directions val collisionPairsMap = (collisionPairs ++ collisionPairs.map(_.swap)).toMap @@ -592,21 +150,21 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { assert(kv._2.equals(expectedValue)) count += 1 } - assert(count === 100000 + collisionPairs.size * 2) + assert(count === size + collisionPairs.size * 2) } test("spilling with many hash collisions") { - val conf = createSparkConf(true, false) + val size = 1000 + val conf = createSparkConf(loadDefaults = true, kryo = false) + conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString) sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) - val agg = new Aggregator[FixedHashObject, Int, Int](_ => 1, _ + _, _ + _) val sorter = new ExternalSorter[FixedHashObject, Int, Int](Some(agg), None, None, None) - // Insert 10 copies each of lots of objects whose hash codes are either 0 or 1. This causes // problems if the map fails to group together the objects with the same code (SPARK-2043). - val toInsert = for (i <- 1 to 10; j <- 1 to 10000) yield (FixedHashObject(j, j % 2), 1) + val toInsert = for (i <- 1 to 10; j <- 1 to size) yield (FixedHashObject(j, j % 2), 1) sorter.insertAll(toInsert.iterator) - + assert(sorter.numSpills > 0, "sorter did not spill") val it = sorter.iterator var count = 0 while (it.hasNext) { @@ -614,11 +172,13 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { assert(kv._2 === 10) count += 1 } - assert(count === 10000) + assert(count === size) } test("spilling with hash collisions using the Int.MaxValue key") { - val conf = createSparkConf(true, false) + val size = 1000 + val conf = createSparkConf(loadDefaults = true, kryo = false) + conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString) sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) def createCombiner(i: Int): ArrayBuffer[Int] = ArrayBuffer[Int](i) @@ -629,10 +189,9 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { val agg = new Aggregator[Int, Int, ArrayBuffer[Int]](createCombiner, mergeValue, mergeCombiners) val sorter = new ExternalSorter[Int, Int, ArrayBuffer[Int]](Some(agg), None, None, None) - sorter.insertAll( - (1 to 100000).iterator.map(i => (i, i)) ++ Iterator((Int.MaxValue, Int.MaxValue))) - + (1 to size).iterator.map(i => (i, i)) ++ Iterator((Int.MaxValue, Int.MaxValue))) + assert(sorter.numSpills > 0, "sorter did not spill") val it = sorter.iterator while (it.hasNext) { // Should not throw NoSuchElementException @@ -641,7 +200,9 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { } test("spilling with null keys and values") { - val conf = createSparkConf(true, false) + val size = 1000 + val conf = createSparkConf(loadDefaults = true, kryo = false) + conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString) sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) def createCombiner(i: String): ArrayBuffer[String] = ArrayBuffer[String](i) @@ -655,12 +216,12 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { val sorter = new ExternalSorter[String, String, ArrayBuffer[String]]( Some(agg), None, None, None) - sorter.insertAll((1 to 100000).iterator.map(i => (i.toString, i.toString)) ++ Iterator( + sorter.insertAll((1 to size).iterator.map(i => (i.toString, i.toString)) ++ Iterator( (null.asInstanceOf[String], "1"), ("1", null.asInstanceOf[String]), (null.asInstanceOf[String], null.asInstanceOf[String]) )) - + assert(sorter.numSpills > 0, "sorter did not spill") val it = sorter.iterator while (it.hasNext) { // Should not throw NullPointerException @@ -668,16 +229,301 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { } } - test("sort without breaking sorting contracts with kryo ser") { - sortWithoutBreakingSortingContracts(createSparkConf(true, true)) + /* ============================= * + | Helper test utility methods | + * ============================= */ + + private def createSparkConf(loadDefaults: Boolean, kryo: Boolean): SparkConf = { + val conf = new SparkConf(loadDefaults) + if (kryo) { + conf.set("spark.serializer", classOf[KryoSerializer].getName) + } else { + // Make the Java serializer write a reset instruction (TC_RESET) after each object to test + // for a bug we had with bytes written past the last object in a batch (SPARK-2792) + conf.set("spark.serializer.objectStreamReset", "1") + conf.set("spark.serializer", classOf[JavaSerializer].getName) + } + conf.set("spark.shuffle.sort.bypassMergeThreshold", "0") + // Ensure that we actually have multiple batches per spill file + conf.set("spark.shuffle.spill.batchSize", "10") + conf.set("spark.shuffle.spill.initialMemoryThreshold", "512") + conf + } + + /** + * Run a test multiple times, each time with a different serializer. + */ + private def testWithMultipleSer( + name: String, + loadDefaults: Boolean = false)(body: (SparkConf => Unit)): Unit = { + test(name + " with kryo ser") { + body(createSparkConf(loadDefaults, kryo = true)) + } + test(name + " with java ser") { + body(createSparkConf(loadDefaults, kryo = false)) + } } - test("sort without breaking sorting contracts with java ser") { - sortWithoutBreakingSortingContracts(createSparkConf(true, false)) + /* =========================================== * + | Helper methods that contain the test body | + * =========================================== */ + + private def emptyDataStream(conf: SparkConf) { + conf.set("spark.shuffle.manager", "sort") + sc = new SparkContext("local", "test", conf) + + val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) + val ord = implicitly[Ordering[Int]] + + // Both aggregator and ordering + val sorter = new ExternalSorter[Int, Int, Int]( + Some(agg), Some(new HashPartitioner(3)), Some(ord), None) + assert(sorter.iterator.toSeq === Seq()) + sorter.stop() + + // Only aggregator + val sorter2 = new ExternalSorter[Int, Int, Int]( + Some(agg), Some(new HashPartitioner(3)), None, None) + assert(sorter2.iterator.toSeq === Seq()) + sorter2.stop() + + // Only ordering + val sorter3 = new ExternalSorter[Int, Int, Int]( + None, Some(new HashPartitioner(3)), Some(ord), None) + assert(sorter3.iterator.toSeq === Seq()) + sorter3.stop() + + // Neither aggregator nor ordering + val sorter4 = new ExternalSorter[Int, Int, Int]( + None, Some(new HashPartitioner(3)), None, None) + assert(sorter4.iterator.toSeq === Seq()) + sorter4.stop() + } + + private def fewElementsPerPartition(conf: SparkConf) { + conf.set("spark.shuffle.manager", "sort") + sc = new SparkContext("local", "test", conf) + + val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) + val ord = implicitly[Ordering[Int]] + val elements = Set((1, 1), (2, 2), (5, 5)) + val expected = Set( + (0, Set()), (1, Set((1, 1))), (2, Set((2, 2))), (3, Set()), (4, Set()), + (5, Set((5, 5))), (6, Set())) + + // Both aggregator and ordering + val sorter = new ExternalSorter[Int, Int, Int]( + Some(agg), Some(new HashPartitioner(7)), Some(ord), None) + sorter.insertAll(elements.iterator) + assert(sorter.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) + sorter.stop() + + // Only aggregator + val sorter2 = new ExternalSorter[Int, Int, Int]( + Some(agg), Some(new HashPartitioner(7)), None, None) + sorter2.insertAll(elements.iterator) + assert(sorter2.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) + sorter2.stop() + + // Only ordering + val sorter3 = new ExternalSorter[Int, Int, Int]( + None, Some(new HashPartitioner(7)), Some(ord), None) + sorter3.insertAll(elements.iterator) + assert(sorter3.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) + sorter3.stop() + + // Neither aggregator nor ordering + val sorter4 = new ExternalSorter[Int, Int, Int]( + None, Some(new HashPartitioner(7)), None, None) + sorter4.insertAll(elements.iterator) + assert(sorter4.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) + sorter4.stop() + } + + private def emptyPartitionsWithSpilling(conf: SparkConf) { + val size = 1000 + conf.set("spark.shuffle.manager", "sort") + conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString) + sc = new SparkContext("local", "test", conf) + + val ord = implicitly[Ordering[Int]] + val elements = Iterator((1, 1), (5, 5)) ++ (0 until size).iterator.map(x => (2, 2)) + + val sorter = new ExternalSorter[Int, Int, Int]( + None, Some(new HashPartitioner(7)), Some(ord), None) + sorter.insertAll(elements) + assert(sorter.numSpills > 0, "sorter did not spill") + val iter = sorter.partitionedIterator.map(p => (p._1, p._2.toList)) + assert(iter.next() === (0, Nil)) + assert(iter.next() === (1, List((1, 1)))) + assert(iter.next() === (2, (0 until 1000).map(x => (2, 2)).toList)) + assert(iter.next() === (3, Nil)) + assert(iter.next() === (4, Nil)) + assert(iter.next() === (5, List((5, 5)))) + assert(iter.next() === (6, Nil)) + sorter.stop() + } + + private def testSpillingInLocalCluster(conf: SparkConf, numReduceTasks: Int) { + val size = 5000 + conf.set("spark.shuffle.manager", "sort") + conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 4).toString) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) + + assertSpilled(sc, "reduceByKey") { + val result = sc.parallelize(0 until size) + .map { i => (i / 2, i) } + .reduceByKey(math.max _, numReduceTasks) + .collect() + assert(result.length === size / 2) + result.foreach { case (k, v) => + val expected = k * 2 + 1 + assert(v === expected, s"Value for $k was wrong: expected $expected, got $v") + } + } + + assertSpilled(sc, "groupByKey") { + val result = sc.parallelize(0 until size) + .map { i => (i / 2, i) } + .groupByKey(numReduceTasks) + .collect() + assert(result.length == size / 2) + result.foreach { case (i, seq) => + val actual = seq.toSet + val expected = Set(i * 2, i * 2 + 1) + assert(actual === expected, s"Value for $i was wrong: expected $expected, got $actual") + } + } + + assertSpilled(sc, "cogroup") { + val rdd1 = sc.parallelize(0 until size).map { i => (i / 2, i) } + val rdd2 = sc.parallelize(0 until size).map { i => (i / 2, i) } + val result = rdd1.cogroup(rdd2, numReduceTasks).collect() + assert(result.length === size / 2) + result.foreach { case (i, (seq1, seq2)) => + val actual1 = seq1.toSet + val actual2 = seq2.toSet + val expected = Set(i * 2, i * 2 + 1) + assert(actual1 === expected, s"Value 1 for $i was wrong: expected $expected, got $actual1") + assert(actual2 === expected, s"Value 2 for $i was wrong: expected $expected, got $actual2") + } + } + + assertSpilled(sc, "sortByKey") { + val result = sc.parallelize(0 until size) + .map { i => (i / 2, i) } + .sortByKey(numPartitions = numReduceTasks) + .collect() + val expected = (0 until size).map { i => (i / 2, i) }.toArray + assert(result.length === size) + result.zipWithIndex.foreach { case ((k, _), i) => + val (expectedKey, _) = expected(i) + assert(k === expectedKey, s"Value for $i was wrong: expected $expectedKey, got $k") + } + } + } + + private def cleanupIntermediateFilesInSorter(withFailures: Boolean): Unit = { + val size = 1200 + val conf = createSparkConf(loadDefaults = false, kryo = false) + conf.set("spark.shuffle.manager", "sort") + conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 4).toString) + sc = new SparkContext("local", "test", conf) + val diskBlockManager = sc.env.blockManager.diskBlockManager + val ord = implicitly[Ordering[Int]] + val expectedSize = if (withFailures) size - 1 else size + val sorter = new ExternalSorter[Int, Int, Int]( + None, Some(new HashPartitioner(3)), Some(ord), None) + if (withFailures) { + intercept[SparkException] { + sorter.insertAll((0 until size).iterator.map { i => + if (i == size - 1) { throw new SparkException("intentional failure") } + (i, i) + }) + } + } else { + sorter.insertAll((0 until size).iterator.map(i => (i, i))) + } + assert(sorter.iterator.toSet === (0 until expectedSize).map(i => (i, i)).toSet) + assert(sorter.numSpills > 0, "sorter did not spill") + assert(diskBlockManager.getAllFiles().nonEmpty, "sorter did not spill") + sorter.stop() + assert(diskBlockManager.getAllFiles().isEmpty, "spilled files were not cleaned up") + } + + private def cleanupIntermediateFilesInShuffle(withFailures: Boolean): Unit = { + val size = 1200 + val conf = createSparkConf(loadDefaults = false, kryo = false) + conf.set("spark.shuffle.manager", "sort") + conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 4).toString) + sc = new SparkContext("local", "test", conf) + val diskBlockManager = sc.env.blockManager.diskBlockManager + val data = sc.parallelize(0 until size, 2).map { i => + if (withFailures && i == size - 1) { + throw new SparkException("intentional failure") + } + (i, i) + } + + assertSpilled(sc, "test shuffle cleanup") { + if (withFailures) { + intercept[SparkException] { + data.reduceByKey(_ + _).count() + } + // After the shuffle, there should be only 2 files on disk: the output of task 1 and + // its index. All other files (map 2's output and intermediate merge files) should + // have been deleted. + assert(diskBlockManager.getAllFiles().length === 2) + } else { + assert(data.reduceByKey(_ + _).count() === size) + // After the shuffle, there should be only 4 files on disk: the output of both tasks + // and their indices. All intermediate merge files should have been deleted. + assert(diskBlockManager.getAllFiles().length === 4) + } + } + } + + private def basicSorterTest( + conf: SparkConf, + withPartialAgg: Boolean, + withOrdering: Boolean, + withSpilling: Boolean) { + val size = 1000 + if (withSpilling) { + conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString) + } + conf.set("spark.shuffle.manager", "sort") + sc = new SparkContext("local", "test", conf) + val agg = + if (withPartialAgg) { + Some(new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)) + } else { + None + } + val ord = if (withOrdering) Some(implicitly[Ordering[Int]]) else None + val sorter = new ExternalSorter[Int, Int, Int](agg, Some(new HashPartitioner(3)), ord, None) + sorter.insertAll((0 until size).iterator.map { i => (i / 4, i) }) + if (withSpilling) { + assert(sorter.numSpills > 0, "sorter did not spill") + } else { + assert(sorter.numSpills === 0, "sorter spilled") + } + val results = sorter.partitionedIterator.map { case (p, vs) => (p, vs.toSet) }.toSet + val expected = (0 until 3).map { p => + var v = (0 until size).map { i => (i / 4, i) }.filter { case (k, _) => k % 3 == p }.toSet + if (withPartialAgg) { + v = v.groupBy(_._1).mapValues { s => s.map(_._2).sum }.toSet + } + (p, v.toSet) + }.toSet + assert(results === expected) } private def sortWithoutBreakingSortingContracts(conf: SparkConf) { + val size = 100000 + val conf = createSparkConf(loadDefaults = true, kryo = false) conf.set("spark.shuffle.manager", "sort") + conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString) sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) // Using wrongOrdering to show integer overflow introduced exception. @@ -690,17 +536,18 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { } } - val testData = Array.tabulate(100000) { _ => rand.nextInt().toString } + val testData = Array.tabulate(size) { _ => rand.nextInt().toString } val sorter1 = new ExternalSorter[String, String, String]( None, None, Some(wrongOrdering), None) val thrown = intercept[IllegalArgumentException] { sorter1.insertAll(testData.iterator.map(i => (i, i))) + assert(sorter1.numSpills > 0, "sorter did not spill") sorter1.iterator } - assert(thrown.getClass() === classOf[IllegalArgumentException]) - assert(thrown.getMessage().contains("Comparison method violates its general contract")) + assert(thrown.getClass === classOf[IllegalArgumentException]) + assert(thrown.getMessage.contains("Comparison method violates its general contract")) sorter1.stop() // Using aggregation and external spill to make sure ExternalSorter using @@ -716,6 +563,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { val sorter2 = new ExternalSorter[String, String, ArrayBuffer[String]]( Some(agg), None, None, None) sorter2.insertAll(testData.iterator.map(i => (i, i))) + assert(sorter2.numSpills > 0, "sorter did not spill") // To validate the hash ordering of key var minKey = Int.MinValue @@ -729,12 +577,23 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { } test("sorting updates peak execution memory") { + val spillThreshold = 1000 val conf = createSparkConf(loadDefaults = false, kryo = false) .set("spark.shuffle.manager", "sort") + .set("spark.shuffle.spill.numElementsForceSpillThreshold", spillThreshold.toString) sc = new SparkContext("local", "test", conf) // Avoid aggregating here to make sure we're not also using ExternalAppendOnlyMap - AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external sorter") { - sc.parallelize(1 to 1000, 2).repartition(100).count() + // No spilling + AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external sorter without spilling") { + assertNotSpilled(sc, "verify peak memory") { + sc.parallelize(1 to spillThreshold / 2, 2).repartition(100).count() + } + } + // With spilling + AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external sorter with spilling") { + assertSpilled(sc, "verify peak memory") { + sc.parallelize(1 to spillThreshold * 3, 2).repartition(100).count() + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala index 835f52fa566a2..c4358f409b6ef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala @@ -68,6 +68,8 @@ private class GrantEverythingMemoryManager extends MemoryManager { blockId: BlockId, numBytes: Long, evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = true + override def releaseExecutionMemory(numBytes: Long): Unit = { } + override def releaseStorageMemory(numBytes: Long): Unit = { } override def maxExecutionMemory: Long = Long.MaxValue override def maxStorageMemory: Long = Long.MaxValue } From 6a2359ff1f7ad2233af2c530313d6ec2ecf70d19 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 15 Oct 2015 14:50:58 -0700 Subject: [PATCH 070/139] [SPARK-10412] [SQL] report memory usage for tungsten sql physical operator https://issues.apache.org/jira/browse/SPARK-10412 some screenshots: ### aggregate: ![screen shot 2015-10-12 at 2 23 11 pm](https://cloud.githubusercontent.com/assets/3182036/10439534/618320a4-70ef-11e5-94d8-62ea7f2d1531.png) ### join ![screen shot 2015-10-12 at 2 23 29 pm](https://cloud.githubusercontent.com/assets/3182036/10439537/6724797c-70ef-11e5-8f75-0cf5cbd42048.png) Author: Wenchen Fan Author: Wenchen Fan Closes #8931 from cloud-fan/viz. --- .../aggregate/TungstenAggregate.scala | 10 ++- .../TungstenAggregationIterator.scala | 10 ++- .../sql/execution/metric/SQLMetrics.scala | 72 +++++++++++++------ .../org/apache/spark/sql/execution/sort.scala | 16 +++++ .../sql/execution/ui/ExecutionPage.scala | 2 +- .../spark/sql/execution/ui/SQLListener.scala | 9 ++- .../sql/execution/ui/SparkPlanGraph.scala | 4 +- .../TungstenAggregationIteratorSuite.scala | 3 +- .../execution/metric/SQLMetricsSuite.scala | 13 +++- .../sql/execution/ui/SQLListenerSuite.scala | 20 +++--- 10 files changed, 116 insertions(+), 43 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index c342940e6e757..0d3a4b36c161b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -49,7 +49,9 @@ case class TungstenAggregate( override private[sql] lazy val metrics = Map( "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), - "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"), + "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"), + "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size")) override def outputsUnsafeRows: Boolean = true @@ -79,6 +81,8 @@ case class TungstenAggregate( protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { val numInputRows = longMetric("numInputRows") val numOutputRows = longMetric("numOutputRows") + val dataSize = longMetric("dataSize") + val spillSize = longMetric("spillSize") /** * Set up the underlying unsafe data structures used before computing the parent partition. @@ -97,7 +101,9 @@ case class TungstenAggregate( child.output, testFallbackStartsAt, numInputRows, - numOutputRows) + numOutputRows, + dataSize, + spillSize) } /** Compute a partition using the iterator already set up previously. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index fe708a5f71f79..7cd0f7b81e46c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -87,7 +87,9 @@ class TungstenAggregationIterator( originalInputAttributes: Seq[Attribute], testFallbackStartsAt: Option[Int], numInputRows: LongSQLMetric, - numOutputRows: LongSQLMetric) + numOutputRows: LongSQLMetric, + dataSize: LongSQLMetric, + spillSize: LongSQLMetric) extends Iterator[UnsafeRow] with Logging { // The parent partition iterator, to be initialized later in `start` @@ -110,6 +112,10 @@ class TungstenAggregationIterator( s"$allAggregateExpressions should have no more than 2 kinds of modes.") } + // Remember spill data size of this task before execute this operator so that we can + // figure out how many bytes we spilled for this operator. + private val spillSizeBefore = TaskContext.get().taskMetrics().memoryBytesSpilled + // // The modes of AggregateExpressions. Right now, we can handle the following mode: // - Partial-only: @@ -842,6 +848,8 @@ class TungstenAggregationIterator( val mapMemory = hashMap.getPeakMemoryUsedBytes val sorterMemory = Option(externalSorter).map(_.getPeakMemoryUsedBytes).getOrElse(0L) val peakMemory = Math.max(mapMemory, sorterMemory) + dataSize += peakMemory + spillSize += TaskContext.get().taskMetrics().memoryBytesSpilled - spillSizeBefore TaskContext.get().internalMetricsToAccumulators( InternalAccumulator.PEAK_EXECUTION_MEMORY).add(peakMemory) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 7a2a98ec18cb8..075b7ad881112 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.metric +import org.apache.spark.util.Utils import org.apache.spark.{Accumulable, AccumulableParam, SparkContext} /** @@ -35,6 +36,12 @@ private[sql] abstract class SQLMetric[R <: SQLMetricValue[T], T]( */ private[sql] trait SQLMetricParam[R <: SQLMetricValue[T], T] extends AccumulableParam[R, T] { + /** + * A function that defines how we aggregate the final accumulator results among all tasks, + * and represent it in string for a SQL physical operator. + */ + val stringValue: Seq[T] => String + def zero: R } @@ -63,26 +70,12 @@ private[sql] class LongSQLMetricValue(private var _value : Long) extends SQLMetr override def value: Long = _value } -/** - * A wrapper of Int to avoid boxing and unboxing when using Accumulator - */ -private[sql] class IntSQLMetricValue(private var _value: Int) extends SQLMetricValue[Int] { - - def add(term: Int): IntSQLMetricValue = { - _value += term - this - } - - // Although there is a boxing here, it's fine because it's only called in SQLListener - override def value: Int = _value -} - /** * A specialized long Accumulable to avoid boxing and unboxing when using Accumulator's * `+=` and `add`. */ -private[sql] class LongSQLMetric private[metric](name: String) - extends SQLMetric[LongSQLMetricValue, Long](name, LongSQLMetricParam) { +private[sql] class LongSQLMetric private[metric](name: String, param: LongSQLMetricParam) + extends SQLMetric[LongSQLMetricValue, Long](name, param) { override def +=(term: Long): Unit = { localValue.add(term) @@ -93,7 +86,8 @@ private[sql] class LongSQLMetric private[metric](name: String) } } -private object LongSQLMetricParam extends SQLMetricParam[LongSQLMetricValue, Long] { +private class LongSQLMetricParam(val stringValue: Seq[Long] => String, initialValue: Long) + extends SQLMetricParam[LongSQLMetricValue, Long] { override def addAccumulator(r: LongSQLMetricValue, t: Long): LongSQLMetricValue = r.add(t) @@ -102,20 +96,56 @@ private object LongSQLMetricParam extends SQLMetricParam[LongSQLMetricValue, Lon override def zero(initialValue: LongSQLMetricValue): LongSQLMetricValue = zero - override def zero: LongSQLMetricValue = new LongSQLMetricValue(0L) + override def zero: LongSQLMetricValue = new LongSQLMetricValue(initialValue) } private[sql] object SQLMetrics { - def createLongMetric(sc: SparkContext, name: String): LongSQLMetric = { - val acc = new LongSQLMetric(name) + private def createLongMetric( + sc: SparkContext, + name: String, + stringValue: Seq[Long] => String, + initialValue: Long): LongSQLMetric = { + val param = new LongSQLMetricParam(stringValue, initialValue) + val acc = new LongSQLMetric(name, param) sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc)) acc } + def createLongMetric(sc: SparkContext, name: String): LongSQLMetric = { + createLongMetric(sc, name, _.sum.toString, 0L) + } + + /** + * Create a metric to report the size information (including total, min, med, max) like data size, + * spill size, etc. + */ + def createSizeMetric(sc: SparkContext, name: String): LongSQLMetric = { + val stringValue = (values: Seq[Long]) => { + // This is a workaround for SPARK-11013. + // We use -1 as initial value of the accumulator, if the accumulator is valid, we will update + // it at the end of task and the value will be at least 0. + val validValues = values.filter(_ >= 0) + val Seq(sum, min, med, max) = { + val metric = if (validValues.length == 0) { + Seq.fill(4)(0L) + } else { + val sorted = validValues.sorted + Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1)) + } + metric.map(Utils.bytesToString) + } + s"\n$sum ($min, $med, $max)" + } + // The final result of this metric in physical operator UI may looks like: + // data size total (min, med, max): + // 100GB (100MB, 1GB, 10GB) + createLongMetric(sc, s"$name total (min, med, max)", stringValue, -1L) + } + /** * A metric that its value will be ignored. Use this one when we need a metric parameter but don't * care about the value. */ - val nullLongMetric = new LongSQLMetric("null") + val nullLongMetric = new LongSQLMetric("null", new LongSQLMetricParam(_.sum.toString, 0L)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala index 27f26245a5ef0..9385e5734db5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{Distribution, OrderedDistribution, UnspecifiedDistribution} +import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.StructType import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter @@ -93,10 +94,17 @@ case class TungstenSort( override def requiredChildDistribution: Seq[Distribution] = if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil + override private[sql] lazy val metrics = Map( + "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"), + "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size")) + protected override def doExecute(): RDD[InternalRow] = { val schema = child.schema val childOutput = child.output + val dataSize = longMetric("dataSize") + val spillSize = longMetric("spillSize") + /** * Set up the sorter in each partition before computing the parent partition. * This makes sure our sorter is not starved by other sorters used in the same task. @@ -131,7 +139,15 @@ case class TungstenSort( partitionIndex: Int, sorter: UnsafeExternalRowSorter, parentIterator: Iterator[InternalRow]): Iterator[InternalRow] = { + // Remember spill data size of this task before execute this operator so that we can + // figure out how many bytes we spilled for this operator. + val spillSizeBefore = TaskContext.get().taskMetrics().memoryBytesSpilled + val sortedIterator = sorter.sort(parentIterator.asInstanceOf[Iterator[UnsafeRow]]) + + dataSize += sorter.getPeakMemoryUsage + spillSize += TaskContext.get().taskMetrics().memoryBytesSpilled - spillSizeBefore + taskContext.internalMetricsToAccumulators( InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.getPeakMemoryUsage) sortedIterator diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala index a4dbd2e1978d0..e74d6fb396e1c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala @@ -100,7 +100,7 @@ private[sql] class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") // scalastyle:on } - private def planVisualization(metrics: Map[Long, Any], graph: SparkPlanGraph): Seq[Node] = { + private def planVisualization(metrics: Map[Long, String], graph: SparkPlanGraph): Seq[Node] = { val metadata = graph.nodes.flatMap { node => val nodeId = s"plan-meta-data-${node.id}"
{node.desc}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index d6472400a6a21..b302b519998ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -252,7 +252,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi /** * Get all accumulator updates from all tasks which belong to this execution and merge them. */ - def getExecutionMetrics(executionId: Long): Map[Long, Any] = synchronized { + def getExecutionMetrics(executionId: Long): Map[Long, String] = synchronized { _executionIdToData.get(executionId) match { case Some(executionUIData) => val accumulatorUpdates = { @@ -264,8 +264,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi } }.filter { case (id, _) => executionUIData.accumulatorMetrics.contains(id) } mergeAccumulatorUpdates(accumulatorUpdates, accumulatorId => - executionUIData.accumulatorMetrics(accumulatorId).metricParam). - mapValues(_.asInstanceOf[SQLMetricValue[_]].value) + executionUIData.accumulatorMetrics(accumulatorId).metricParam) case None => // This execution has been dropped Map.empty @@ -274,11 +273,11 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi private def mergeAccumulatorUpdates( accumulatorUpdates: Seq[(Long, Any)], - paramFunc: Long => SQLMetricParam[SQLMetricValue[Any], Any]): Map[Long, Any] = { + paramFunc: Long => SQLMetricParam[SQLMetricValue[Any], Any]): Map[Long, String] = { accumulatorUpdates.groupBy(_._1).map { case (accumulatorId, values) => val param = paramFunc(accumulatorId) (accumulatorId, - values.map(_._2.asInstanceOf[SQLMetricValue[Any]]).foldLeft(param.zero)(param.addInPlace)) + param.stringValue(values.map(_._2.asInstanceOf[SQLMetricValue[Any]].value))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala index ae3d752dde348..f1fce5478a3fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.execution.metric.{SQLMetricParam, SQLMetricValue} private[ui] case class SparkPlanGraph( nodes: Seq[SparkPlanGraphNode], edges: Seq[SparkPlanGraphEdge]) { - def makeDotFile(metrics: Map[Long, Any]): String = { + def makeDotFile(metrics: Map[Long, String]): String = { val dotFile = new StringBuilder dotFile.append("digraph G {\n") nodes.foreach(node => dotFile.append(node.makeDotNode(metrics) + "\n")) @@ -87,7 +87,7 @@ private[sql] object SparkPlanGraph { private[ui] case class SparkPlanGraphNode( id: Long, name: String, desc: String, metrics: Seq[SQLPlanMetric]) { - def makeDotNode(metricsValue: Map[Long, Any]): String = { + def makeDotNode(metricsValue: Map[Long, String]): String = { val values = { for (metric <- metrics; value <- metricsValue.get(metric.accumulatorId)) yield { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala index 0cc4988ff681c..cc0ac1b07c21a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala @@ -39,7 +39,8 @@ class TungstenAggregationIteratorSuite extends SparkFunSuite with SharedSQLConte } val dummyAccum = SQLMetrics.createLongMetric(sparkContext, "dummy") iter = new TungstenAggregationIterator(Seq.empty, Seq.empty, Seq.empty, Seq.empty, Seq.empty, - 0, Seq.empty, newMutableProjection, Seq.empty, None, dummyAccum, dummyAccum) + 0, Seq.empty, newMutableProjection, Seq.empty, None, + dummyAccum, dummyAccum, dummyAccum, dummyAccum) val numPages = iter.getHashMap.getNumDataPages assert(numPages === 1) } finally { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 6afffae161ef6..cdd885ba14203 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -93,7 +93,16 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { }.toMap (node.id, node.name -> nodeMetrics) }.toMap - assert(expectedMetrics === actualMetrics) + + assert(expectedMetrics.keySet === actualMetrics.keySet) + for (nodeId <- expectedMetrics.keySet) { + val (expectedNodeName, expectedMetricsMap) = expectedMetrics(nodeId) + val (actualNodeName, actualMetricsMap) = actualMetrics(nodeId) + assert(expectedNodeName === actualNodeName) + for (metricName <- expectedMetricsMap.keySet) { + assert(expectedMetricsMap(metricName).toString === actualMetricsMap(metricName)) + } + } } else { // TODO Remove this "else" once we fix the race condition that missing the JobStarted event. // Since we cannot track all jobs, the metric values could be wrong and we should not check @@ -489,7 +498,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { val metricValues = sqlContext.listener.getExecutionMetrics(executionId) // Because "save" will create a new DataFrame internally, we cannot get the real metric id. // However, we still can check the value. - assert(metricValues.values.toSeq === Seq(2L)) + assert(metricValues.values.toSeq === Seq("2")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index 727cf3665a871..cc1c1e10e98c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -74,6 +74,10 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { } test("basic") { + def checkAnswer(actual: Map[Long, String], expected: Map[Long, Long]): Unit = { + assert(actual === expected.mapValues(_.toString)) + } + val listener = new SQLListener(sqlContext.sparkContext.conf) val executionId = 0 val df = createTestDataFrame @@ -114,7 +118,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { (1L, 0, 0, createTaskMetrics(accumulatorUpdates)) ))) - assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 2)) + checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2)) listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq( // (task id, stage id, stage attempt, metrics) @@ -122,7 +126,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { (1L, 0, 0, createTaskMetrics(accumulatorUpdates.mapValues(_ * 2))) ))) - assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 3)) + checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 3)) // Retrying a stage should reset the metrics listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(0, 1))) @@ -133,7 +137,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { (1L, 0, 1, createTaskMetrics(accumulatorUpdates)) ))) - assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 2)) + checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2)) // Ignore the task end for the first attempt listener.onTaskEnd(SparkListenerTaskEnd( @@ -144,7 +148,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { createTaskInfo(0, 0), createTaskMetrics(accumulatorUpdates.mapValues(_ * 100)))) - assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 2)) + checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2)) // Finish two tasks listener.onTaskEnd(SparkListenerTaskEnd( @@ -162,7 +166,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { createTaskInfo(1, 0), createTaskMetrics(accumulatorUpdates.mapValues(_ * 3)))) - assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 5)) + checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 5)) // Summit a new stage listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(1, 0))) @@ -173,7 +177,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { (1L, 1, 0, createTaskMetrics(accumulatorUpdates)) ))) - assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 7)) + checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 7)) // Finish two tasks listener.onTaskEnd(SparkListenerTaskEnd( @@ -191,7 +195,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { createTaskInfo(1, 0), createTaskMetrics(accumulatorUpdates.mapValues(_ * 3)))) - assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 11)) + checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 11)) assert(executionUIData.runningJobs === Seq(0)) assert(executionUIData.succeededJobs.isEmpty) @@ -208,7 +212,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { assert(executionUIData.succeededJobs === Seq(0)) assert(executionUIData.failedJobs.isEmpty) - assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 11)) + checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 11)) } test("onExecutionEnd happens before onJobEnd(JobSucceeded)") { From eb0b4d6e2ddfb765f082d0d88472626336ad2609 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 15 Oct 2015 17:36:55 -0700 Subject: [PATCH 071/139] [SPARK-11135] [SQL] Exchange incorrectly skips sorts when existing ordering is non-empty subset of required ordering In Spark SQL, the Exchange planner tries to avoid unnecessary sorts in cases where the data has already been sorted by a superset of the requested sorting columns. For instance, let's say that a query calls for an operator's input to be sorted by `a.asc` and the input happens to already be sorted by `[a.asc, b.asc]`. In this case, we do not need to re-sort the input. The converse, however, is not true: if the query calls for `[a.asc, b.asc]`, then `a.asc` alone will not satisfy the ordering requirements, requiring an additional sort to be planned by Exchange. However, the current Exchange code gets this wrong and incorrectly skips sorting when the existing output ordering is a subset of the required ordering. This is simple to fix, however. This bug was introduced in https://github.com/apache/spark/pull/7458, so it affects 1.5.0+. This patch fixes the bug and significantly improves the unit test coverage of Exchange's sort-planning logic. Author: Josh Rosen Closes #9140 from JoshRosen/SPARK-11135. --- .../apache/spark/sql/execution/Exchange.scala | 5 +- .../spark/sql/execution/PlannerSuite.scala | 49 +++++++++++++++++++ 2 files changed, 52 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 289453753f18d..1d3379a5e2d91 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -219,6 +219,8 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution val requiredChildOrderings: Seq[Seq[SortOrder]] = operator.requiredChildOrdering var children: Seq[SparkPlan] = operator.children + assert(requiredChildDistributions.length == children.length) + assert(requiredChildOrderings.length == children.length) // Ensure that the operator's children satisfy their output distribution requirements: children = children.zip(requiredChildDistributions).map { case (child, distribution) => @@ -248,8 +250,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ children = children.zip(requiredChildOrderings).map { case (child, requiredOrdering) => if (requiredOrdering.nonEmpty) { // If child.outputOrdering is [a, b] and requiredOrdering is [a], we do not need to sort. - val minSize = Seq(requiredOrdering.size, child.outputOrdering.size).min - if (minSize == 0 || requiredOrdering.take(minSize) != child.outputOrdering.take(minSize)) { + if (requiredOrdering != child.outputOrdering.take(requiredOrdering.length)) { sqlContext.planner.BasicOperators.getSortOperator(requiredOrdering, global = false, child) } else { child diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index cafa1d5154788..ebdab1c26d7bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -354,6 +354,55 @@ class PlannerSuite extends SharedSQLContext { } } + test("EnsureRequirements adds sort when there is no existing ordering") { + val orderingA = SortOrder(Literal(1), Ascending) + val orderingB = SortOrder(Literal(2), Ascending) + assert(orderingA != orderingB) + val inputPlan = DummySparkPlan( + children = DummySparkPlan(outputOrdering = Seq.empty) :: Nil, + requiredChildOrdering = Seq(Seq(orderingB)), + requiredChildDistribution = Seq(UnspecifiedDistribution) + ) + val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + assertDistributionRequirementsAreSatisfied(outputPlan) + if (outputPlan.collect { case s: TungstenSort => true; case s: Sort => true }.isEmpty) { + fail(s"Sort should have been added:\n$outputPlan") + } + } + + test("EnsureRequirements skips sort when required ordering is prefix of existing ordering") { + val orderingA = SortOrder(Literal(1), Ascending) + val orderingB = SortOrder(Literal(2), Ascending) + assert(orderingA != orderingB) + val inputPlan = DummySparkPlan( + children = DummySparkPlan(outputOrdering = Seq(orderingA, orderingB)) :: Nil, + requiredChildOrdering = Seq(Seq(orderingA)), + requiredChildDistribution = Seq(UnspecifiedDistribution) + ) + val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + assertDistributionRequirementsAreSatisfied(outputPlan) + if (outputPlan.collect { case s: TungstenSort => true; case s: Sort => true }.nonEmpty) { + fail(s"No sorts should have been added:\n$outputPlan") + } + } + + // This is a regression test for SPARK-11135 + test("EnsureRequirements adds sort when required ordering isn't a prefix of existing ordering") { + val orderingA = SortOrder(Literal(1), Ascending) + val orderingB = SortOrder(Literal(2), Ascending) + assert(orderingA != orderingB) + val inputPlan = DummySparkPlan( + children = DummySparkPlan(outputOrdering = Seq(orderingA)) :: Nil, + requiredChildOrdering = Seq(Seq(orderingA, orderingB)), + requiredChildDistribution = Seq(UnspecifiedDistribution) + ) + val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + assertDistributionRequirementsAreSatisfied(outputPlan) + if (outputPlan.collect { case s: TungstenSort => true; case s: Sort => true }.isEmpty) { + fail(s"Sort should have been added:\n$outputPlan") + } + } + // --------------------------------------------------------------------------------------------- } From 43f5d1f326d7a2a4a78fe94853d0d05237568203 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Fri, 16 Oct 2015 11:53:47 +0100 Subject: [PATCH 072/139] [SPARK-11060] [STREAMING] Fix some potential NPE in DStream transformation This patch fixes: 1. Guard out against NPEs in `TransformedDStream` when parent DStream returns None instead of empty RDD. 2. Verify some input streams which will potentially return None. 3. Add unit test to verify the behavior when input stream returns None. cc tdas , please help to review, thanks a lot :). Author: jerryshao Closes #9070 from jerryshao/SPARK-11060. --- .../dstream/ConstantInputDStream.scala | 6 +- .../streaming/dstream/QueueInputDStream.scala | 2 +- .../dstream/TransformedDStream.scala | 7 +- .../streaming/dstream/UnionDStream.scala | 11 ++-- .../streaming/BasicOperationsSuite.scala | 66 +++++++++++++++++++ 5 files changed, 83 insertions(+), 9 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala index f396c347581ce..4eb92dd8b1053 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala @@ -17,9 +17,10 @@ package org.apache.spark.streaming.dstream +import scala.reflect.ClassTag + import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Time, StreamingContext} -import scala.reflect.ClassTag /** * An input stream that always returns the same RDD on each timestep. Useful for testing. @@ -27,6 +28,9 @@ import scala.reflect.ClassTag class ConstantInputDStream[T: ClassTag](ssc_ : StreamingContext, rdd: RDD[T]) extends InputDStream[T](ssc_) { + require(rdd != null, + "parameter rdd null is illegal, which will lead to NPE in the following transformation") + override def start() {} override def stop() {} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala index a2685046e03d4..cd073646370d0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala @@ -62,7 +62,7 @@ class QueueInputDStream[T: ClassTag]( } else if (defaultRDD != null) { Some(defaultRDD) } else { - None + Some(ssc.sparkContext.emptyRDD) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala index ab01f47d5cf99..5eabdf63dc8d7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala @@ -20,7 +20,7 @@ package org.apache.spark.streaming.dstream import scala.reflect.ClassTag import org.apache.spark.SparkException -import org.apache.spark.rdd.{PairRDDFunctions, RDD} +import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Duration, Time} private[streaming] @@ -39,7 +39,10 @@ class TransformedDStream[U: ClassTag] ( override def slideDuration: Duration = parents.head.slideDuration override def compute(validTime: Time): Option[RDD[U]] = { - val parentRDDs = parents.map(_.getOrCompute(validTime).orNull).toSeq + val parentRDDs = parents.map { parent => parent.getOrCompute(validTime).getOrElse( + // Guard out against parent DStream that return None instead of Some(rdd) to avoid NPE + throw new SparkException(s"Couldn't generate RDD from parent at time $validTime")) + } val transformedRDD = transformFunc(parentRDDs, validTime) if (transformedRDD == null) { throw new SparkException("Transform function must not return null. " + diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala index 9405dbaa12329..d73ffdfd84d2d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala @@ -17,13 +17,14 @@ package org.apache.spark.streaming.dstream +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + +import org.apache.spark.SparkException import org.apache.spark.streaming.{Duration, Time} import org.apache.spark.rdd.RDD import org.apache.spark.rdd.UnionRDD -import scala.collection.mutable.ArrayBuffer -import scala.reflect.ClassTag - private[streaming] class UnionDStream[T: ClassTag](parents: Array[DStream[T]]) extends DStream[T](parents.head.ssc) { @@ -41,8 +42,8 @@ class UnionDStream[T: ClassTag](parents: Array[DStream[T]]) val rdds = new ArrayBuffer[RDD[T]]() parents.map(_.getOrCompute(validTime)).foreach { case Some(rdd) => rdds += rdd - case None => throw new Exception("Could not generate RDD from a parent for unifying at time " - + validTime) + case None => throw new SparkException("Could not generate RDD from a parent for unifying at" + + s" time $validTime") } if (rdds.size > 0) { Some(new UnionRDD(ssc.sc, rdds)) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index 9988f410f0bc1..9d296c6d3ef8b 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -191,6 +191,20 @@ class BasicOperationsSuite extends TestSuiteBase { ) } + test("union with input stream return None") { + val input = Seq(1 to 4, 101 to 104, 201 to 204, null) + val output = Seq(1 to 8, 101 to 108, 201 to 208) + intercept[SparkException] { + testOperation( + input, + (s: DStream[Int]) => s.union(s.map(_ + 4)), + output, + input.length, + false + ) + } + } + test("StreamingContext.union") { val input = Seq(1 to 4, 101 to 104, 201 to 204) val output = Seq(1 to 12, 101 to 112, 201 to 212) @@ -224,6 +238,19 @@ class BasicOperationsSuite extends TestSuiteBase { } } + test("transform with input stream return None") { + val input = Seq(1 to 4, 5 to 8, null) + intercept[SparkException] { + testOperation( + input, + (r: DStream[Int]) => r.transform(rdd => rdd.map(_.toString)), + input.filterNot(_ == null).map(_.map(_.toString)), + input.length, + false + ) + } + } + test("transformWith") { val inputData1 = Seq( Seq("a", "b"), Seq("a", ""), Seq(""), Seq() ) val inputData2 = Seq( Seq("a", "b"), Seq("b", ""), Seq(), Seq("") ) @@ -244,6 +271,27 @@ class BasicOperationsSuite extends TestSuiteBase { testOperation(inputData1, inputData2, operation, outputData, true) } + test("transformWith with input stream return None") { + val inputData1 = Seq( Seq("a", "b"), Seq("a", ""), Seq(""), null ) + val inputData2 = Seq( Seq("a", "b"), Seq("b", ""), Seq(), null ) + val outputData = Seq( + Seq("a", "b", "a", "b"), + Seq("a", "b", "", ""), + Seq("") + ) + + val operation = (s1: DStream[String], s2: DStream[String]) => { + s1.transformWith( // RDD.join in transform + s2, + (rdd1: RDD[String], rdd2: RDD[String]) => rdd1.union(rdd2) + ) + } + + intercept[SparkException] { + testOperation(inputData1, inputData2, operation, outputData, inputData1.length, true) + } + } + test("StreamingContext.transform") { val input = Seq(1 to 4, 101 to 104, 201 to 204) val output = Seq(1 to 12, 101 to 112, 201 to 212) @@ -260,6 +308,24 @@ class BasicOperationsSuite extends TestSuiteBase { testOperation(input, operation, output) } + test("StreamingContext.transform with input stream return None") { + val input = Seq(1 to 4, 101 to 104, 201 to 204, null) + val output = Seq(1 to 12, 101 to 112, 201 to 212) + + // transform over 3 DStreams by doing union of the 3 RDDs + val operation = (s: DStream[Int]) => { + s.context.transform( + Seq(s, s.map(_ + 4), s.map(_ + 8)), // 3 DStreams + (rdds: Seq[RDD[_]], time: Time) => + rdds.head.context.union(rdds.map(_.asInstanceOf[RDD[Int]])) // union of RDDs + ) + } + + intercept[SparkException] { + testOperation(input, operation, output, input.length, false) + } + } + test("cogroup") { val inputData1 = Seq( Seq("a", "a", "b"), Seq("a", ""), Seq(""), Seq() ) val inputData2 = Seq( Seq("a", "a", "b"), Seq("b", ""), Seq(), Seq() ) From ed775042cceb61a0566502e1306ac3c70f4a6a5f Mon Sep 17 00:00:00 2001 From: Jakob Odersky Date: Fri, 16 Oct 2015 12:03:05 +0100 Subject: [PATCH 073/139] [SPARK-11092] [DOCS] Add source links to scaladoc generation Modify the SBT build script to include GitHub source links for generated Scaladocs, on releases only (no snapshots). Author: Jakob Odersky Closes #9110 from jodersky/unidoc. --- project/SparkBuild.scala | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 1339980c38800..8f0f310ddd24e 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -156,6 +156,10 @@ object SparkBuild extends PomBuild { javacOptions in Compile ++= Seq("-encoding", "UTF-8"), + scalacOptions in Compile ++= Seq( + "-sourcepath", (baseDirectory in ThisBuild).value.getAbsolutePath // Required for relative source links in scaladoc + ), + // Implements -Xfatal-warnings, ignoring deprecation warnings. // Code snippet taken from https://issues.scala-lang.org/browse/SI-8410. compile in Compile := { @@ -489,6 +493,8 @@ object Unidoc { .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/hive/test"))) } + val unidocSourceBase = settingKey[String]("Base URL of source links in Scaladoc.") + lazy val settings = scalaJavaUnidocSettings ++ Seq ( publish := {}, @@ -531,8 +537,19 @@ object Unidoc { "-noqualifier", "java.lang" ), - // Group similar methods together based on the @group annotation. - scalacOptions in (ScalaUnidoc, unidoc) ++= Seq("-groups") + // Use GitHub repository for Scaladoc source linke + unidocSourceBase := s"https://github.com/apache/spark/tree/v${version.value}", + + scalacOptions in (ScalaUnidoc, unidoc) ++= Seq( + "-groups" // Group similar methods together based on the @group annotation. + ) ++ ( + // Add links to sources when generating Scaladoc for a non-snapshot release + if (!isSnapshot.value) { + Opts.doc.sourceUrl(unidocSourceBase.value + "€{FILE_PATH}.scala") + } else { + Seq() + } + ) ) } From 08698ee1d6f29b2c999416f18a074d5193cdacd5 Mon Sep 17 00:00:00 2001 From: Jakob Odersky Date: Fri, 16 Oct 2015 14:26:34 +0100 Subject: [PATCH 074/139] [SPARK-11094] Strip extra strings from Java version in test runner Removes any extra strings from the Java version, fixing subsequent integer parsing. This is required since some OpenJDK versions (specifically in Debian testing), append an extra "-internal" string to the version field. Author: Jakob Odersky Closes #9111 from jodersky/fixtestrunner. --- dev/run-tests.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/dev/run-tests.py b/dev/run-tests.py index 1a816585187d9..d4d6880491bc8 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -176,17 +176,14 @@ def determine_java_version(java_exe): # find raw version string, eg 'java version "1.8.0_25"' raw_version_str = next(x for x in raw_output_lines if " version " in x) - version_str = raw_version_str.split()[-1].strip('"') # eg '1.8.0_25' - version, update = version_str.split('_') # eg ['1.8.0', '25'] + match = re.search('(\d+)\.(\d+)\.(\d+)_(\d+)', raw_version_str) - # map over the values and convert them to integers - version_info = [int(x) for x in version.split('.') + [update]] - - return JavaVersion(major=version_info[0], - minor=version_info[1], - patch=version_info[2], - update=version_info[3]) + major = int(match.group(1)) + minor = int(match.group(2)) + patch = int(match.group(3)) + update = int(match.group(4)) + return JavaVersion(major, minor, patch, update) # ------------------------------------------------------------------------------------------------- # Functions for running the other build and test scripts From 4ee2cea2a43f7d04ab8511d9c029f80c5dadd48e Mon Sep 17 00:00:00 2001 From: Jakob Odersky Date: Fri, 16 Oct 2015 17:24:18 +0100 Subject: [PATCH 075/139] [SPARK-11122] [BUILD] [WARN] Add tag to fatal warnings Shows that an error is actually due to a fatal warning. Author: Jakob Odersky Closes #9128 from jodersky/fatalwarnings. --- project/SparkBuild.scala | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 8f0f310ddd24e..766edd9500c30 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -164,7 +164,7 @@ object SparkBuild extends PomBuild { // Code snippet taken from https://issues.scala-lang.org/browse/SI-8410. compile in Compile := { val analysis = (compile in Compile).value - val s = streams.value + val out = streams.value def logProblem(l: (=> String) => Unit, f: File, p: xsbti.Problem) = { l(f.toString + ":" + p.position.line.fold("")(_ + ":") + " " + p.message) @@ -181,7 +181,14 @@ object SparkBuild extends PomBuild { failed = failed + 1 } - logProblem(if (deprecation) s.log.warn else s.log.error, k, p) + val printer: (=> String) => Unit = s => if (deprecation) { + out.log.warn(s) + } else { + out.log.error("[warn] " + s) + } + + logProblem(printer, k, p) + } } From b9c5e5d4ac4c9fe29e880f4ee562a9c552e81d29 Mon Sep 17 00:00:00 2001 From: "navis.ryu" Date: Fri, 16 Oct 2015 11:19:37 -0700 Subject: [PATCH 076/139] [SPARK-11124] JsonParser/Generator should be closed for resource recycle Some json parsers are not closed. parser in JacksonParser#parseJson, for example. Author: navis.ryu Closes #9130 from navis/SPARK-11124. --- .../scala/org/apache/spark/util/Utils.scala | 4 ++ .../expressions/jsonExpressions.scala | 56 +++++++++---------- .../datasources/json/InferSchema.scala | 8 ++- .../datasources/json/JacksonParser.scala | 41 +++++++------- 4 files changed, 57 insertions(+), 52 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index bd7e51c3b5100..22c05a2479422 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2153,6 +2153,10 @@ private[spark] object Utils extends Logging { conf.getInt("spark.executor.instances", 0) == 0 } + def tryWithResource[R <: Closeable, T](createResource: => R)(f: R => T): T = { + val resource = createResource + try f.apply(resource) finally resource.close() + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 0770fab0ae901..8c9853e628d2c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.types.{StructField, StructType, StringType, DataType} import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.Utils import scala.util.parsing.combinator.RegexParsers @@ -134,16 +135,18 @@ case class GetJsonObject(json: Expression, path: Expression) if (parsed.isDefined) { try { - val parser = jsonFactory.createParser(jsonStr.getBytes) - val output = new ByteArrayOutputStream() - val generator = jsonFactory.createGenerator(output, JsonEncoding.UTF8) - parser.nextToken() - val matched = evaluatePath(parser, generator, RawStyle, parsed.get) - generator.close() - if (matched) { - UTF8String.fromBytes(output.toByteArray) - } else { - null + Utils.tryWithResource(jsonFactory.createParser(jsonStr.getBytes)) { parser => + val output = new ByteArrayOutputStream() + val matched = Utils.tryWithResource( + jsonFactory.createGenerator(output, JsonEncoding.UTF8)) { generator => + parser.nextToken() + evaluatePath(parser, generator, RawStyle, parsed.get) + } + if (matched) { + UTF8String.fromBytes(output.toByteArray) + } else { + null + } } } catch { case _: JsonProcessingException => null @@ -250,17 +253,18 @@ case class GetJsonObject(json: Expression, path: Expression) // temporarily buffer child matches, the emitted json will need to be // modified slightly if there is only a single element written val buffer = new StringWriter() - val flattenGenerator = jsonFactory.createGenerator(buffer) - flattenGenerator.writeStartArray() var dirty = 0 - while (p.nextToken() != END_ARRAY) { - // track the number of array elements and only emit an outer array if - // we've written more than one element, this matches Hive's behavior - dirty += (if (evaluatePath(p, flattenGenerator, nextStyle, xs)) 1 else 0) + Utils.tryWithResource(jsonFactory.createGenerator(buffer)) { flattenGenerator => + flattenGenerator.writeStartArray() + + while (p.nextToken() != END_ARRAY) { + // track the number of array elements and only emit an outer array if + // we've written more than one element, this matches Hive's behavior + dirty += (if (evaluatePath(p, flattenGenerator, nextStyle, xs)) 1 else 0) + } + flattenGenerator.writeEndArray() } - flattenGenerator.writeEndArray() - flattenGenerator.close() val buf = buffer.getBuffer if (dirty > 1) { @@ -370,12 +374,8 @@ case class JsonTuple(children: Seq[Expression]) } try { - val parser = jsonFactory.createParser(json.getBytes) - - try { - parseRow(parser, input) - } finally { - parser.close() + Utils.tryWithResource(jsonFactory.createParser(json.getBytes)) { + parser => parseRow(parser, input) } } catch { case _: JsonProcessingException => @@ -420,12 +420,8 @@ case class JsonTuple(children: Seq[Expression]) // write the output directly to UTF8 encoded byte array if (parser.nextToken() != JsonToken.VALUE_NULL) { - val generator = jsonFactory.createGenerator(output, JsonEncoding.UTF8) - - try { - copyCurrentStructure(generator, parser) - } finally { - generator.close() + Utils.tryWithResource(jsonFactory.createGenerator(output, JsonEncoding.UTF8)) { + generator => copyCurrentStructure(generator, parser) } row(idx) = UTF8String.fromBytes(output.toByteArray) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala index b6f3410bad690..d0780028dacb1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala @@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion import org.apache.spark.sql.execution.datasources.json.JacksonUtils.nextUntil import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils private[sql] object InferSchema { /** @@ -47,9 +48,10 @@ private[sql] object InferSchema { val factory = new JsonFactory() iter.map { row => try { - val parser = factory.createParser(row) - parser.nextToken() - inferField(parser) + Utils.tryWithResource(factory.createParser(row)) { parser => + parser.nextToken() + inferField(parser) + } } catch { case _: JsonParseException => StructType(Seq(StructField(columnNameOfCorruptRecords, StringType))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala index c51140749c8e6..09b8a9e936a1d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.json.JacksonUtils.nextUntil import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.Utils private[sql] object JacksonParser { def apply( @@ -86,9 +87,9 @@ private[sql] object JacksonParser { case (_, StringType) => val writer = new ByteArrayOutputStream() - val generator = factory.createGenerator(writer, JsonEncoding.UTF8) - generator.copyCurrentStructure(parser) - generator.close() + Utils.tryWithResource(factory.createGenerator(writer, JsonEncoding.UTF8)) { + generator => generator.copyCurrentStructure(parser) + } UTF8String.fromBytes(writer.toByteArray) case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, FloatType) => @@ -245,22 +246,24 @@ private[sql] object JacksonParser { iter.flatMap { record => try { - val parser = factory.createParser(record) - parser.nextToken() - - convertField(factory, parser, schema) match { - case null => failedRecord(record) - case row: InternalRow => row :: Nil - case array: ArrayData => - if (array.numElements() == 0) { - Nil - } else { - array.toArray[InternalRow](schema) - } - case _ => - sys.error( - s"Failed to parse record $record. Please make sure that each line of the file " + - "(or each string in the RDD) is a valid JSON object or an array of JSON objects.") + Utils.tryWithResource(factory.createParser(record)) { parser => + parser.nextToken() + + convertField(factory, parser, schema) match { + case null => failedRecord(record) + case row: InternalRow => row :: Nil + case array: ArrayData => + if (array.numElements() == 0) { + Nil + } else { + array.toArray[InternalRow](schema) + } + case _ => + sys.error( + s"Failed to parse record $record. Please make sure that each line of the file " + + "(or each string in the RDD) is a valid JSON object or " + + "an array of JSON objects.") + } } } catch { case _: JsonProcessingException => From 3d683a139b333456a6bd8801ac5f113d1ac3fd18 Mon Sep 17 00:00:00 2001 From: Pravin Gadakh Date: Fri, 16 Oct 2015 13:38:50 -0700 Subject: [PATCH 077/139] [SPARK-10581] [DOCS] Groups are not resolved in scaladoc in sql classes Groups are not resolved properly in scaladoc in following classes: sql/core/src/main/scala/org/apache/spark/sql/Column.scala sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala sql/core/src/main/scala/org/apache/spark/sql/functions.scala Author: Pravin Gadakh Closes #9148 from pravingadakh/master. --- sql/core/src/main/scala/org/apache/spark/sql/Column.scala | 8 ++++---- .../src/main/scala/org/apache/spark/sql/SQLContext.scala | 2 +- .../src/main/scala/org/apache/spark/sql/functions.scala | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 807bc8c30c12d..1f826887ac774 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -41,10 +41,10 @@ private[sql] object Column { * :: Experimental :: * A column in a [[DataFrame]]. * - * @groupname java_expr_ops Java-specific expression operators. - * @groupname expr_ops Expression operators. - * @groupname df_ops DataFrame functions. - * @groupname Ungrouped Support functions for DataFrames. + * @groupname java_expr_ops Java-specific expression operators + * @groupname expr_ops Expression operators + * @groupname df_ops DataFrame functions + * @groupname Ungrouped Support functions for DataFrames * * @since 1.3.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 361eb576c567a..e83657a60558d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -59,7 +59,7 @@ import org.apache.spark.util.Utils * @groupname specificdata Specific Data Sources * @groupname config Configuration * @groupname dataframes Custom DataFrame Creation - * @groupname Ungrouped Support functions for language integrated queries. + * @groupname Ungrouped Support functions for language integrated queries * * @since 1.0.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 2467b4e48415b..15c864a8ab641 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -43,7 +43,7 @@ import org.apache.spark.util.Utils * @groupname window_funcs Window functions * @groupname string_funcs String functions * @groupname collection_funcs Collection functions - * @groupname Ungrouped Support functions for DataFrames. + * @groupname Ungrouped Support functions for DataFrames * @since 1.3.0 */ @Experimental From 369d786f58580e7df73e7e23f27390d37269d0de Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 16 Oct 2015 13:53:06 -0700 Subject: [PATCH 078/139] [SPARK-10974] [STREAMING] Add progress bar for output operation column and use red dots for failed batches Screenshot: 1 Also fixed the description and duration for output operations that don't have spark jobs. 2 Author: zsxwing Closes #9010 from zsxwing/output-op-progress-bar. --- .../streaming/ui/static/streaming-page.js | 26 ++- .../apache/spark/streaming/DStreamGraph.scala | 2 +- .../spark/streaming/scheduler/BatchInfo.scala | 23 +-- .../spark/streaming/scheduler/Job.scala | 30 ++- .../streaming/scheduler/JobScheduler.scala | 12 +- .../spark/streaming/scheduler/JobSet.scala | 17 +- .../scheduler/OutputOperationInfo.scala | 6 +- .../spark/streaming/ui/AllBatchesTable.scala | 40 ++-- .../apache/spark/streaming/ui/BatchPage.scala | 174 +++++++----------- .../spark/streaming/ui/BatchUIData.scala | 67 ++++++- .../ui/StreamingJobProgressListener.scala | 14 ++ .../streaming/StreamingListenerSuite.scala | 16 +- .../spark/streaming/UISeleniumSuite.scala | 2 +- .../StreamingJobProgressListenerSuite.scala | 30 +-- 14 files changed, 258 insertions(+), 201 deletions(-) diff --git a/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.js b/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.js index 4886b68eeaf76..f82323a1cdd94 100644 --- a/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.js +++ b/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.js @@ -154,34 +154,40 @@ function drawTimeline(id, data, minX, maxX, minY, maxY, unitY, batchInterval) { var lastClickedBatch = null; var lastTimeout = null; + function isFailedBatch(batchTime) { + return $("#batch-" + batchTime).attr("isFailed") == "true"; + } + // Add points to the line. However, we make it invisible at first. But when the user moves mouse // over a point, it will be displayed with its detail. svg.selectAll(".point") .data(data) .enter().append("circle") - .attr("stroke", "white") // white and opacity = 0 make it invisible - .attr("fill", "white") - .attr("opacity", "0") + .attr("stroke", function(d) { return isFailedBatch(d.x) ? "red" : "white";}) // white and opacity = 0 make it invisible + .attr("fill", function(d) { return isFailedBatch(d.x) ? "red" : "white";}) + .attr("opacity", function(d) { return isFailedBatch(d.x) ? "1" : "0";}) .style("cursor", "pointer") .attr("cx", function(d) { return x(d.x); }) .attr("cy", function(d) { return y(d.y); }) - .attr("r", function(d) { return 3; }) + .attr("r", function(d) { return isFailedBatch(d.x) ? "2" : "0";}) .on('mouseover', function(d) { var tip = formatYValue(d.y) + " " + unitY + " at " + timeFormat[d.x]; showBootstrapTooltip(d3.select(this).node(), tip); // show the point d3.select(this) - .attr("stroke", "steelblue") - .attr("fill", "steelblue") - .attr("opacity", "1"); + .attr("stroke", function(d) { return isFailedBatch(d.x) ? "red" : "steelblue";}) + .attr("fill", function(d) { return isFailedBatch(d.x) ? "red" : "steelblue";}) + .attr("opacity", "1") + .attr("r", "3"); }) .on('mouseout', function() { hideBootstrapTooltip(d3.select(this).node()); // hide the point d3.select(this) - .attr("stroke", "white") - .attr("fill", "white") - .attr("opacity", "0"); + .attr("stroke", function(d) { return isFailedBatch(d.x) ? "red" : "white";}) + .attr("fill", function(d) { return isFailedBatch(d.x) ? "red" : "white";}) + .attr("opacity", function(d) { return isFailedBatch(d.x) ? "1" : "0";}) + .attr("r", function(d) { return isFailedBatch(d.x) ? "2" : "0";}); }) .on("click", function(d) { if (lastTimeout != null) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala index de79c9ef1abfa..1b0b7890b3b00 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala @@ -113,7 +113,7 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { val jobs = this.synchronized { outputStreams.flatMap { outputStream => val jobOption = outputStream.generateJob(time) - jobOption.foreach(_.setCallSite(outputStream.creationSite.longForm)) + jobOption.foreach(_.setCallSite(outputStream.creationSite)) jobOption } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala index 463f899dc249b..436eb0a566141 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala @@ -29,6 +29,7 @@ import org.apache.spark.streaming.Time * the streaming scheduler queue * @param processingStartTime Clock time of when the first job of this batch started processing * @param processingEndTime Clock time of when the last job of this batch finished processing + * @param outputOperationInfos The output operations in this batch */ @DeveloperApi case class BatchInfo( @@ -36,13 +37,10 @@ case class BatchInfo( streamIdToInputInfo: Map[Int, StreamInputInfo], submissionTime: Long, processingStartTime: Option[Long], - processingEndTime: Option[Long] + processingEndTime: Option[Long], + outputOperationInfos: Map[Int, OutputOperationInfo] ) { - private var _failureReasons: Map[Int, String] = Map.empty - - private var _numOutputOp: Int = 0 - @deprecated("Use streamIdToInputInfo instead", "1.5.0") def streamIdToNumRecords: Map[Int, Long] = streamIdToInputInfo.mapValues(_.numRecords) @@ -72,19 +70,4 @@ case class BatchInfo( */ def numRecords: Long = streamIdToInputInfo.values.map(_.numRecords).sum - /** Set the failure reasons corresponding to every output ops in the batch */ - private[streaming] def setFailureReason(reasons: Map[Int, String]): Unit = { - _failureReasons = reasons - } - - /** Failure reasons corresponding to every output ops in the batch */ - private[streaming] def failureReasons = _failureReasons - - /** Set the number of output operations in this batch */ - private[streaming] def setNumOutputOp(numOutputOp: Int): Unit = { - _numOutputOp = numOutputOp - } - - /** Return the number of output operations in this batch */ - private[streaming] def numOutputOp: Int = _numOutputOp } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala index 1373053f064f3..ab1b3565fcc19 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala @@ -17,8 +17,10 @@ package org.apache.spark.streaming.scheduler +import scala.util.{Failure, Try} + import org.apache.spark.streaming.Time -import scala.util.Try +import org.apache.spark.util.{Utils, CallSite} /** * Class representing a Spark computation. It may contain multiple Spark jobs. @@ -29,7 +31,9 @@ class Job(val time: Time, func: () => _) { private var _outputOpId: Int = _ private var isSet = false private var _result: Try[_] = null - private var _callSite: String = "Unknown" + private var _callSite: CallSite = null + private var _startTime: Option[Long] = None + private var _endTime: Option[Long] = None def run() { _result = Try(func()) @@ -71,11 +75,29 @@ class Job(val time: Time, func: () => _) { _outputOpId = outputOpId } - def setCallSite(callSite: String): Unit = { + def setCallSite(callSite: CallSite): Unit = { _callSite = callSite } - def callSite: String = _callSite + def callSite: CallSite = _callSite + + def setStartTime(startTime: Long): Unit = { + _startTime = Some(startTime) + } + + def setEndTime(endTime: Long): Unit = { + _endTime = Some(endTime) + } + + def toOutputOperationInfo: OutputOperationInfo = { + val failureReason = if (_result != null && _result.isFailure) { + Some(Utils.exceptionString(_result.asInstanceOf[Failure[_]].exception)) + } else { + None + } + OutputOperationInfo( + time, outputOpId, callSite.shortForm, callSite.longForm, _startTime, _endTime, failureReason) + } override def toString: String = id } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 0a4a396a0f498..2480b4ec093e2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -20,13 +20,13 @@ package org.apache.spark.streaming.scheduler import java.util.concurrent.{ConcurrentHashMap, TimeUnit} import scala.collection.JavaConverters._ -import scala.util.{Failure, Success} +import scala.util.Failure import org.apache.spark.Logging import org.apache.spark.rdd.PairRDDFunctions import org.apache.spark.streaming._ import org.apache.spark.streaming.ui.UIUtils -import org.apache.spark.util.{EventLoop, ThreadUtils} +import org.apache.spark.util.{EventLoop, ThreadUtils, Utils} private[scheduler] sealed trait JobSchedulerEvent @@ -162,16 +162,16 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { // correct "jobSet.processingStartTime". listenerBus.post(StreamingListenerBatchStarted(jobSet.toBatchInfo)) } - listenerBus.post(StreamingListenerOutputOperationStarted( - OutputOperationInfo(job.time, job.outputOpId, job.callSite, Some(startTime), None))) + job.setStartTime(startTime) + listenerBus.post(StreamingListenerOutputOperationStarted(job.toOutputOperationInfo)) logInfo("Starting job " + job.id + " from job set of time " + jobSet.time) } private def handleJobCompletion(job: Job, completedTime: Long) { val jobSet = jobSets.get(job.time) jobSet.handleJobCompletion(job) - listenerBus.post(StreamingListenerOutputOperationCompleted( - OutputOperationInfo(job.time, job.outputOpId, job.callSite, None, Some(completedTime)))) + job.setEndTime(completedTime) + listenerBus.post(StreamingListenerOutputOperationCompleted(job.toOutputOperationInfo)) logInfo("Finished job " + job.id + " from job set of time " + jobSet.time) if (jobSet.hasCompleted) { jobSets.remove(jobSet.time) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala index 08f63cc99268f..f76300351e3c0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala @@ -64,24 +64,13 @@ case class JobSet( } def toBatchInfo: BatchInfo = { - val failureReasons: Map[Int, String] = { - if (hasCompleted) { - jobs.filter(_.result.isFailure).map { job => - (job.outputOpId, Utils.exceptionString(job.result.asInstanceOf[Failure[_]].exception)) - }.toMap - } else { - Map.empty - } - } - val binfo = new BatchInfo( + BatchInfo( time, streamIdToInputInfo, submissionTime, if (processingStartTime >= 0) Some(processingStartTime) else None, - if (processingEndTime >= 0) Some(processingEndTime) else None + if (processingEndTime >= 0) Some(processingEndTime) else None, + jobs.map { job => (job.outputOpId, job.toOutputOperationInfo) }.toMap ) - binfo.setFailureReason(failureReasons) - binfo.setNumOutputOp(jobs.size) - binfo } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/OutputOperationInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/OutputOperationInfo.scala index d5614b343912b..137e512a670da 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/OutputOperationInfo.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/OutputOperationInfo.scala @@ -25,17 +25,21 @@ import org.apache.spark.streaming.Time * Class having information on output operations. * @param batchTime Time of the batch * @param id Id of this output operation. Different output operations have different ids in a batch. + * @param name The name of this output operation. * @param description The description of this output operation. * @param startTime Clock time of when the output operation started processing * @param endTime Clock time of when the output operation started processing + * @param failureReason Failure reason if this output operation fails */ @DeveloperApi case class OutputOperationInfo( batchTime: Time, id: Int, + name: String, description: String, startTime: Option[Long], - endTime: Option[Long]) { + endTime: Option[Long], + failureReason: Option[String]) { /** * Return the duration of this output operation. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala index 3e6590d66f587..125cafd41b8af 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala @@ -17,9 +17,6 @@ package org.apache.spark.streaming.ui -import java.text.SimpleDateFormat -import java.util.Date - import scala.xml.Node import org.apache.spark.ui.{UIUtils => SparkUIUtils} @@ -46,7 +43,8 @@ private[ui] abstract class BatchTableBase(tableId: String, batchInterval: Long) val formattedProcessingTime = processingTime.map(SparkUIUtils.formatDuration).getOrElse("-") val batchTimeId = s"batch-$batchTime" - + {formattedBatchTime} @@ -75,6 +73,19 @@ private[ui] abstract class BatchTableBase(tableId: String, batchInterval: Long) batchTable } + protected def createOutputOperationProgressBar(batch: BatchUIData): Seq[Node] = { + + { + SparkUIUtils.makeProgressBar( + started = batch.numActiveOutputOp, + completed = batch.numCompletedOutputOp, + failed = batch.numFailedOutputOp, + skipped = 0, + total = batch.outputOperations.size) + } + + } + /** * Return HTML for all rows of this table. */ @@ -86,7 +97,10 @@ private[ui] class ActiveBatchTable( waitingBatches: Seq[BatchUIData], batchInterval: Long) extends BatchTableBase("active-batches-table", batchInterval) { - override protected def columns: Seq[Node] = super.columns ++ Status + override protected def columns: Seq[Node] = super.columns ++ { + Output Ops: Succeeded/Total + Status + } override protected def renderRows: Seq[Node] = { // The "batchTime"s of "waitingBatches" must be greater than "runningBatches"'s, so display @@ -96,11 +110,11 @@ private[ui] class ActiveBatchTable( } private def runningBatchRow(batch: BatchUIData): Seq[Node] = { - baseRow(batch) ++ processing + baseRow(batch) ++ createOutputOperationProgressBar(batch) ++ processing } private def waitingBatchRow(batch: BatchUIData): Seq[Node] = { - baseRow(batch) ++ queued + baseRow(batch) ++ createOutputOperationProgressBar(batch) ++ queued } } @@ -119,17 +133,11 @@ private[ui] class CompletedBatchTable(batches: Seq[BatchUIData], batchInterval: private def completedBatchRow(batch: BatchUIData): Seq[Node] = { val totalDelay = batch.totalDelay val formattedTotalDelay = totalDelay.map(SparkUIUtils.formatDuration).getOrElse("-") - val numFailedOutputOp = batch.failureReason.size - val outputOpColumn = if (numFailedOutputOp > 0) { - s"${batch.numOutputOp - numFailedOutputOp}/${batch.numOutputOp}" + - s" (${numFailedOutputOp} failed)" - } else { - s"${batch.numOutputOp}/${batch.numOutputOp}" - } - baseRow(batch) ++ + + baseRow(batch) ++ { {formattedTotalDelay} - {outputOpColumn} + } ++ createOutputOperationProgressBar(batch) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala index a19b85a51d289..2ed925572826e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala @@ -47,32 +47,30 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { } private def generateJobRow( - outputOpId: OutputOpId, + outputOpData: OutputOperationUIData, outputOpDescription: Seq[Node], formattedOutputOpDuration: String, - outputOpStatus: String, numSparkJobRowsInOutputOp: Int, isFirstRow: Boolean, sparkJob: SparkJobIdWithUIData): Seq[Node] = { if (sparkJob.jobUIData.isDefined) { - generateNormalJobRow(outputOpId, outputOpDescription, formattedOutputOpDuration, - outputOpStatus, numSparkJobRowsInOutputOp, isFirstRow, sparkJob.jobUIData.get) + generateNormalJobRow(outputOpData, outputOpDescription, formattedOutputOpDuration, + numSparkJobRowsInOutputOp, isFirstRow, sparkJob.jobUIData.get) } else { - generateDroppedJobRow(outputOpId, outputOpDescription, formattedOutputOpDuration, - outputOpStatus, numSparkJobRowsInOutputOp, isFirstRow, sparkJob.sparkJobId) + generateDroppedJobRow(outputOpData, outputOpDescription, formattedOutputOpDuration, + numSparkJobRowsInOutputOp, isFirstRow, sparkJob.sparkJobId) } } private def generateOutputOpRowWithoutSparkJobs( - outputOpId: OutputOpId, + outputOpData: OutputOperationUIData, outputOpDescription: Seq[Node], - formattedOutputOpDuration: String, - outputOpStatus: String): Seq[Node] = { + formattedOutputOpDuration: String): Seq[Node] = { - {outputOpId.toString} + {outputOpData.id.toString} {outputOpDescription} {formattedOutputOpDuration} - {outputOpStatusCell(outputOpStatus, rowspan = 1)} + {outputOpStatusCell(outputOpData, rowspan = 1)} - @@ -91,10 +89,9 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { * one cell, we use "rowspan" for the first row of a output op. */ private def generateNormalJobRow( - outputOpId: OutputOpId, + outputOpData: OutputOperationUIData, outputOpDescription: Seq[Node], formattedOutputOpDuration: String, - outputOpStatus: String, numSparkJobRowsInOutputOp: Int, isFirstRow: Boolean, sparkJob: JobUIData): Seq[Node] = { @@ -116,12 +113,12 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { // scalastyle:off val prefixCells = if (isFirstRow) { - {outputOpId.toString} + {outputOpData.id.toString} {outputOpDescription} {formattedOutputOpDuration} ++ - {outputOpStatusCell(outputOpStatus, numSparkJobRowsInOutputOp)} + {outputOpStatusCell(outputOpData, numSparkJobRowsInOutputOp)} } else { Nil } @@ -161,10 +158,9 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { * with "-" cells. */ private def generateDroppedJobRow( - outputOpId: OutputOpId, + outputOpData: OutputOperationUIData, outputOpDescription: Seq[Node], formattedOutputOpDuration: String, - outputOpStatus: String, numSparkJobRowsInOutputOp: Int, isFirstRow: Boolean, jobId: Int): Seq[Node] = { @@ -173,10 +169,10 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { // scalastyle:off val prefixCells = if (isFirstRow) { - {outputOpId.toString} + {outputOpData.id.toString} {outputOpDescription} {formattedOutputOpDuration} ++ - {outputOpStatusCell(outputOpStatus, numSparkJobRowsInOutputOp)} + {outputOpStatusCell(outputOpData, numSparkJobRowsInOutputOp)} } else { Nil } @@ -199,47 +195,34 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { } private def generateOutputOpIdRow( - outputOpId: OutputOpId, - outputOpStatus: String, + outputOpData: OutputOperationUIData, sparkJobs: Seq[SparkJobIdWithUIData]): Seq[Node] = { - // We don't count the durations of dropped jobs - val sparkJobDurations = sparkJobs.filter(_.jobUIData.nonEmpty).map(_.jobUIData.get). - map(sparkJob => { - sparkJob.submissionTime.map { start => - val end = sparkJob.completionTime.getOrElse(System.currentTimeMillis()) - end - start - } - }) val formattedOutputOpDuration = - if (sparkJobDurations.isEmpty || sparkJobDurations.exists(_ == None)) { - // If no job or any job does not finish, set "formattedOutputOpDuration" to "-" + if (outputOpData.duration.isEmpty) { "-" } else { - SparkUIUtils.formatDuration(sparkJobDurations.flatMap(x => x).sum) + SparkUIUtils.formatDuration(outputOpData.duration.get) } - val description = generateOutputOpDescription(sparkJobs) + val description = generateOutputOpDescription(outputOpData) if (sparkJobs.isEmpty) { - generateOutputOpRowWithoutSparkJobs( - outputOpId, description, formattedOutputOpDuration, outputOpStatus) + generateOutputOpRowWithoutSparkJobs(outputOpData, description, formattedOutputOpDuration) } else { val firstRow = generateJobRow( - outputOpId, + outputOpData, description, formattedOutputOpDuration, - outputOpStatus, sparkJobs.size, true, sparkJobs.head) val tailRows = sparkJobs.tail.map { sparkJob => generateJobRow( - outputOpId, + outputOpData, description, formattedOutputOpDuration, - outputOpStatus, sparkJobs.size, false, sparkJob) @@ -248,35 +231,18 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { } } - private def generateOutputOpDescription(sparkJobs: Seq[SparkJobIdWithUIData]): Seq[Node] = { - val lastStageInfo = - sparkJobs.flatMap(_.jobUIData).headOption. // Get the first JobUIData - flatMap { sparkJob => // For the first job, get the latest Stage info - if (sparkJob.stageIds.isEmpty) { - None - } else { - sparkListener.stageIdToInfo.get(sparkJob.stageIds.max) - } - } - lastStageInfo match { - case Some(stageInfo) => - val details = if (stageInfo.details.nonEmpty) { - - +details - ++ - - } else { - NodeSeq.Empty - } - -
{stageInfo.name} {details}
- case None => - Text("(Unknown)") - } + private def generateOutputOpDescription(outputOp: OutputOperationUIData): Seq[Node] = { +
+ {outputOp.name} + + +details + + +
} private def failureReasonCell( @@ -329,6 +295,19 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { } } + private def generateOutputOperationStatusForUI(failure: String): String = { + if (failure.startsWith("org.apache.spark.SparkException")) { + "Failed due to Spark job error\n" + failure + } else { + var nextLineIndex = failure.indexOf("\n") + if (nextLineIndex < 0) { + nextLineIndex = failure.size + } + val firstLine = failure.substring(0, nextLineIndex) + s"Failed due to error: $firstLine\n$failure" + } + } + /** * Generate the job table for the batch. */ @@ -338,26 +317,15 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { // sort SparkJobIds for each OutputOpId (outputOpId, outputOpIdAndSparkJobIds.map(_.sparkJobId).sorted) } - val outputOps = (0 until batchUIData.numOutputOp).map { outputOpId => - val status = batchUIData.failureReason.get(outputOpId).map { failure => - if (failure.startsWith("org.apache.spark.SparkException")) { - "Failed due to Spark job error\n" + failure - } else { - var nextLineIndex = failure.indexOf("\n") - if (nextLineIndex < 0) { - nextLineIndex = failure.size - } - val firstLine = failure.substring(0, nextLineIndex) - s"Failed due to error: $firstLine\n$failure" - } - }.getOrElse("Succeeded") - val sparkJobIds = outputOpIdToSparkJobIds.getOrElse(outputOpId, Seq.empty) - (outputOpId, status, sparkJobIds) - } + + val outputOps: Seq[(OutputOperationUIData, Seq[SparkJobId])] = + batchUIData.outputOperations.map { case (outputOpId, outputOperation) => + val sparkJobIds = outputOpIdToSparkJobIds.getOrElse(outputOpId, Seq.empty) + (outputOperation, sparkJobIds) + }.toSeq.sortBy(_._1.id) sparkListener.synchronized { - val outputOpIdWithJobs: Seq[(OutputOpId, String, Seq[SparkJobIdWithUIData])] = - outputOps.map { case (outputOpId, status, sparkJobIds) => - (outputOpId, status, + val outputOpWithJobs = outputOps.map { case (outputOpData, sparkJobIds) => + (outputOpData, sparkJobIds.map(sparkJobId => SparkJobIdWithUIData(sparkJobId, getJobData(sparkJobId)))) } @@ -367,9 +335,8 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { { - outputOpIdWithJobs.map { - case (outputOpId, status, sparkJobIds) => - generateOutputOpIdRow(outputOpId, status, sparkJobIds) + outputOpWithJobs.map { case (outputOpData, sparkJobIds) => + generateOutputOpIdRow(outputOpData, sparkJobIds) } } @@ -377,7 +344,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { } } - def render(request: HttpServletRequest): Seq[Node] = { + def render(request: HttpServletRequest): Seq[Node] = streamingListener.synchronized { val batchTime = Option(request.getParameter("id")).map(id => Time(id.toLong)).getOrElse { throw new IllegalArgumentException(s"Missing id parameter") } @@ -430,14 +397,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { - val jobTable = - if (batchUIData.outputOpIdSparkJobIdPairs.isEmpty) { -
Cannot find any job for Batch {formattedBatchTime}.
- } else { - generateJobTable(batchUIData) - } - - val content = summary ++ jobTable + val content = summary ++ generateJobTable(batchUIData) SparkUIUtils.headerSparkPage(s"Details of batch at $formattedBatchTime", content, parent) } @@ -471,11 +431,17 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { replaceAllLiterally("\t", "    ").replaceAllLiterally("\n", "
")) } - private def outputOpStatusCell(status: String, rowspan: Int): Seq[Node] = { - if (status == "Succeeded") { - Succeeded - } else { - failureReasonCell(status, rowspan, includeFirstLineInExpandDetails = false) + private def outputOpStatusCell(outputOp: OutputOperationUIData, rowspan: Int): Seq[Node] = { + outputOp.failureReason match { + case Some(failureReason) => + val failureReasonForUI = generateOutputOperationStatusForUI(failureReason) + failureReasonCell(failureReasonForUI, rowspan, includeFirstLineInExpandDetails = false) + case None => + if (outputOp.endTime.isEmpty) { + - + } else { + Succeeded + } } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala index e6c2e2140c6c4..3ef3689de1c45 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala @@ -18,8 +18,10 @@ package org.apache.spark.streaming.ui +import scala.collection.mutable + import org.apache.spark.streaming.Time -import org.apache.spark.streaming.scheduler.{BatchInfo, StreamInputInfo} +import org.apache.spark.streaming.scheduler.{BatchInfo, OutputOperationInfo, StreamInputInfo} import org.apache.spark.streaming.ui.StreamingJobProgressListener._ private[ui] case class OutputOpIdAndSparkJobId(outputOpId: OutputOpId, sparkJobId: SparkJobId) @@ -30,8 +32,7 @@ private[ui] case class BatchUIData( val submissionTime: Long, val processingStartTime: Option[Long], val processingEndTime: Option[Long], - val numOutputOp: Int, - val failureReason: Map[Int, String], + val outputOperations: mutable.HashMap[OutputOpId, OutputOperationUIData] = mutable.HashMap(), var outputOpIdSparkJobIdPairs: Seq[OutputOpIdAndSparkJobId] = Seq.empty) { /** @@ -61,19 +62,75 @@ private[ui] case class BatchUIData( * The number of recorders received by the receivers in this batch. */ def numRecords: Long = streamIdToInputInfo.values.map(_.numRecords).sum + + /** + * Update an output operation information of this batch. + */ + def updateOutputOperationInfo(outputOperationInfo: OutputOperationInfo): Unit = { + assert(batchTime == outputOperationInfo.batchTime) + outputOperations(outputOperationInfo.id) = OutputOperationUIData(outputOperationInfo) + } + + /** + * Return the number of failed output operations. + */ + def numFailedOutputOp: Int = outputOperations.values.count(_.failureReason.nonEmpty) + + /** + * Return the number of running output operations. + */ + def numActiveOutputOp: Int = outputOperations.values.count(_.endTime.isEmpty) + + /** + * Return the number of completed output operations. + */ + def numCompletedOutputOp: Int = outputOperations.values.count { + op => op.failureReason.isEmpty && op.endTime.nonEmpty + } + + /** + * Return if this batch has any output operations + */ + def isFailed: Boolean = numFailedOutputOp != 0 } private[ui] object BatchUIData { def apply(batchInfo: BatchInfo): BatchUIData = { + val outputOperations = mutable.HashMap[OutputOpId, OutputOperationUIData]() + outputOperations ++= batchInfo.outputOperationInfos.mapValues(OutputOperationUIData.apply) new BatchUIData( batchInfo.batchTime, batchInfo.streamIdToInputInfo, batchInfo.submissionTime, batchInfo.processingStartTime, batchInfo.processingEndTime, - batchInfo.numOutputOp, - batchInfo.failureReasons + outputOperations + ) + } +} + +private[ui] case class OutputOperationUIData( + id: OutputOpId, + name: String, + description: String, + startTime: Option[Long], + endTime: Option[Long], + failureReason: Option[String]) { + + def duration: Option[Long] = for (s <- startTime; e <- endTime) yield e - s +} + +private[ui] object OutputOperationUIData { + + def apply(outputOperationInfo: OutputOperationInfo): OutputOperationUIData = { + OutputOperationUIData( + outputOperationInfo.id, + outputOperationInfo.name, + outputOperationInfo.description, + outputOperationInfo.startTime, + outputOperationInfo.endTime, + outputOperationInfo.failureReason ) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala index 78aeb004e18b1..f6cc6edf2569a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala @@ -119,6 +119,20 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) } } + override def onOutputOperationStarted( + outputOperationStarted: StreamingListenerOutputOperationStarted): Unit = synchronized { + // This method is called after onBatchStarted + runningBatchUIData(outputOperationStarted.outputOperationInfo.batchTime). + updateOutputOperationInfo(outputOperationStarted.outputOperationInfo) + } + + override def onOutputOperationCompleted( + outputOperationCompleted: StreamingListenerOutputOperationCompleted): Unit = synchronized { + // This method is called before onBatchCompleted + runningBatchUIData(outputOperationCompleted.outputOperationInfo.batchTime). + updateOutputOperationInfo(outputOperationCompleted.outputOperationInfo) + } + override def onJobStart(jobStart: SparkListenerJobStart): Unit = synchronized { getBatchTimeAndOutputOpId(jobStart.properties).foreach { case (batchTime, outputOpId) => var outputOpIdToSparkJobIds = batchTimeToOutputOpIdSparkJobIdPair.get(batchTime) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index 2b43b7467042b..5dc0472c7770c 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming -import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} +import scala.collection.mutable.{ArrayBuffer, HashMap, SynchronizedBuffer, SynchronizedMap} import scala.concurrent.Future import scala.concurrent.ExecutionContext.Implicits.global @@ -221,7 +221,7 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { } } _ssc.stop() - failureReasonsCollector.failureReasons + failureReasonsCollector.failureReasons.toMap } /** Check if a sequence of numbers is in increasing order */ @@ -307,14 +307,16 @@ class StreamingListenerSuiteReceiver extends Receiver[Any](StorageLevel.MEMORY_O } /** - * A StreamingListener that saves the latest `failureReasons` in `BatchInfo` to the `failureReasons` - * field. + * A StreamingListener that saves all latest `failureReasons` in a batch. */ class FailureReasonsCollector extends StreamingListener { - @volatile var failureReasons: Map[Int, String] = null + val failureReasons = new HashMap[Int, String] with SynchronizedMap[Int, String] - override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit = { - failureReasons = batchCompleted.batchInfo.failureReasons + override def onOutputOperationCompleted( + outputOperationCompleted: StreamingListenerOutputOperationCompleted): Unit = { + outputOperationCompleted.outputOperationInfo.failureReason.foreach { f => + failureReasons(outputOperationCompleted.outputOperationInfo.id) = f + } } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala index d1df78871d3b8..a5744a9009c1c 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala @@ -117,7 +117,7 @@ class UISeleniumSuite findAll(cssSelector("""#active-batches-table th""")).map(_.text).toSeq should be { List("Batch Time", "Input Size", "Scheduling Delay (?)", "Processing Time (?)", - "Status") + "Output Ops: Succeeded/Total", "Status") } findAll(cssSelector("""#completed-batches-table th""")).map(_.text).toSeq should be { List("Batch Time", "Input Size", "Scheduling Delay (?)", "Processing Time (?)", diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala index 995f1197ccdfd..af4718b4eb705 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala @@ -63,7 +63,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { 1 -> StreamInputInfo(1, 300L, Map(StreamInputInfo.METADATA_KEY_DESCRIPTION -> "test"))) // onBatchSubmitted - val batchInfoSubmitted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, None, None) + val batchInfoSubmitted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, None, None, Map.empty) listener.onBatchSubmitted(StreamingListenerBatchSubmitted(batchInfoSubmitted)) listener.waitingBatches should be (List(BatchUIData(batchInfoSubmitted))) listener.runningBatches should be (Nil) @@ -75,7 +75,8 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { listener.numTotalReceivedRecords should be (0) // onBatchStarted - val batchInfoStarted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None) + val batchInfoStarted = + BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None, Map.empty) listener.onBatchStarted(StreamingListenerBatchStarted(batchInfoStarted)) listener.waitingBatches should be (Nil) listener.runningBatches should be (List(BatchUIData(batchInfoStarted))) @@ -116,7 +117,8 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { OutputOpIdAndSparkJobId(1, 1)) // onBatchCompleted - val batchInfoCompleted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None) + val batchInfoCompleted = + BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None, Map.empty) listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) listener.waitingBatches should be (Nil) listener.runningBatches should be (Nil) @@ -156,7 +158,8 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { val streamIdToInputInfo = Map(0 -> StreamInputInfo(0, 300L), 1 -> StreamInputInfo(1, 300L)) - val batchInfoCompleted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None) + val batchInfoCompleted = + BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None, Map.empty) for(_ <- 0 until (limit + 10)) { listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) @@ -173,8 +176,8 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { // fulfill completedBatchInfos for(i <- 0 until limit) { - val batchInfoCompleted = - BatchInfo(Time(1000 + i * 100), Map.empty, 1000 + i * 100, Some(2000 + i * 100), None) + val batchInfoCompleted = BatchInfo( + Time(1000 + i * 100), Map.empty, 1000 + i * 100, Some(2000 + i * 100), None, Map.empty) listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) val jobStart = createJobStart(Time(1000 + i * 100), outputOpId = 0, jobId = 1) listener.onJobStart(jobStart) @@ -185,7 +188,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { listener.onJobStart(jobStart) val batchInfoSubmitted = - BatchInfo(Time(1000 + limit * 100), Map.empty, (1000 + limit * 100), None, None) + BatchInfo(Time(1000 + limit * 100), Map.empty, (1000 + limit * 100), None, None, Map.empty) listener.onBatchSubmitted(StreamingListenerBatchSubmitted(batchInfoSubmitted)) // We still can see the info retrieved from onJobStart @@ -201,8 +204,8 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { // A lot of "onBatchCompleted"s happen before "onJobStart" for(i <- limit + 1 to limit * 2) { - val batchInfoCompleted = - BatchInfo(Time(1000 + i * 100), Map.empty, 1000 + i * 100, Some(2000 + i * 100), None) + val batchInfoCompleted = BatchInfo( + Time(1000 + i * 100), Map.empty, 1000 + i * 100, Some(2000 + i * 100), None, Map.empty) listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) } @@ -227,11 +230,13 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { val streamIdToInputInfo = Map(0 -> StreamInputInfo(0, 300L), 1 -> StreamInputInfo(1, 300L)) // onBatchSubmitted - val batchInfoSubmitted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, None, None) + val batchInfoSubmitted = + BatchInfo(Time(1000), streamIdToInputInfo, 1000, None, None, Map.empty) listener.onBatchSubmitted(StreamingListenerBatchSubmitted(batchInfoSubmitted)) // onBatchStarted - val batchInfoStarted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None) + val batchInfoStarted = + BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None, Map.empty) listener.onBatchStarted(StreamingListenerBatchStarted(batchInfoStarted)) // onJobStart @@ -248,7 +253,8 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { listener.onJobStart(jobStart4) // onBatchCompleted - val batchInfoCompleted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None) + val batchInfoCompleted = + BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None, Map.empty) listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) } From e1eef248f13f6c334fe4eea8a29a1de5470a2e62 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 16 Oct 2015 13:56:51 -0700 Subject: [PATCH 079/139] [SPARK-11104] [STREAMING] Fix a deadlock in StreamingContex.stop The following deadlock may happen if shutdownHook and StreamingContext.stop are running at the same time. ``` Java stack information for the threads listed above: =================================================== "Thread-2": at org.apache.spark.streaming.StreamingContext.stop(StreamingContext.scala:699) - waiting to lock <0x00000005405a1680> (a org.apache.spark.streaming.StreamingContext) at org.apache.spark.streaming.StreamingContext.org$apache$spark$streaming$StreamingContext$$stopOnShutdown(StreamingContext.scala:729) at org.apache.spark.streaming.StreamingContext$$anonfun$start$1.apply$mcV$sp(StreamingContext.scala:625) at org.apache.spark.util.SparkShutdownHook.run(ShutdownHookManager.scala:266) at org.apache.spark.util.SparkShutdownHookManager$$anonfun$runAll$1$$anonfun$apply$mcV$sp$1.apply$mcV$sp(ShutdownHookManager.scala:236) at org.apache.spark.util.SparkShutdownHookManager$$anonfun$runAll$1$$anonfun$apply$mcV$sp$1.apply(ShutdownHookManager.scala:236) at org.apache.spark.util.SparkShutdownHookManager$$anonfun$runAll$1$$anonfun$apply$mcV$sp$1.apply(ShutdownHookManager.scala:236) at org.apache.spark.util.Utils$.logUncaughtExceptions(Utils.scala:1697) at org.apache.spark.util.SparkShutdownHookManager$$anonfun$runAll$1.apply$mcV$sp(ShutdownHookManager.scala:236) at org.apache.spark.util.SparkShutdownHookManager$$anonfun$runAll$1.apply(ShutdownHookManager.scala:236) at org.apache.spark.util.SparkShutdownHookManager$$anonfun$runAll$1.apply(ShutdownHookManager.scala:236) at scala.util.Try$.apply(Try.scala:161) at org.apache.spark.util.SparkShutdownHookManager.runAll(ShutdownHookManager.scala:236) - locked <0x00000005405b6a00> (a org.apache.spark.util.SparkShutdownHookManager) at org.apache.spark.util.SparkShutdownHookManager$$anon$2.run(ShutdownHookManager.scala:216) at org.apache.hadoop.util.ShutdownHookManager$1.run(ShutdownHookManager.java:54) "main": at org.apache.spark.util.SparkShutdownHookManager.remove(ShutdownHookManager.scala:248) - waiting to lock <0x00000005405b6a00> (a org.apache.spark.util.SparkShutdownHookManager) at org.apache.spark.util.ShutdownHookManager$.removeShutdownHook(ShutdownHookManager.scala:199) at org.apache.spark.streaming.StreamingContext.stop(StreamingContext.scala:712) - locked <0x00000005405a1680> (a org.apache.spark.streaming.StreamingContext) at org.apache.spark.streaming.StreamingContext.stop(StreamingContext.scala:684) - locked <0x00000005405a1680> (a org.apache.spark.streaming.StreamingContext) at org.apache.spark.streaming.SessionByKeyBenchmark$.main(SessionByKeyBenchmark.scala:108) at org.apache.spark.streaming.SessionByKeyBenchmark.main(SessionByKeyBenchmark.scala) at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:497) at org.apache.spark.deploy.SparkSubmit$.org$apache$spark$deploy$SparkSubmit$$runMain(SparkSubmit.scala:680) at org.apache.spark.deploy.SparkSubmit$.doRunMain$1(SparkSubmit.scala:180) at org.apache.spark.deploy.SparkSubmit$.submit(SparkSubmit.scala:205) at org.apache.spark.deploy.SparkSubmit$.main(SparkSubmit.scala:120) at org.apache.spark.deploy.SparkSubmit.main(SparkSubmit.scala) ``` This PR just moved `ShutdownHookManager.removeShutdownHook` out of `synchronized` to avoid deadlock. Author: zsxwing Closes #9116 from zsxwing/stop-deadlock. --- .../spark/streaming/StreamingContext.scala | 55 +++++++++++-------- 1 file changed, 31 insertions(+), 24 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 9b2632c229548..051f53de64cd5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -694,32 +694,39 @@ class StreamingContext private[streaming] ( * @param stopGracefully if true, stops gracefully by waiting for the processing of all * received data to be completed */ - def stop(stopSparkContext: Boolean, stopGracefully: Boolean): Unit = synchronized { - try { - state match { - case INITIALIZED => - logWarning("StreamingContext has not been started yet") - case STOPPED => - logWarning("StreamingContext has already been stopped") - case ACTIVE => - scheduler.stop(stopGracefully) - // Removing the streamingSource to de-register the metrics on stop() - env.metricsSystem.removeSource(streamingSource) - uiTab.foreach(_.detach()) - StreamingContext.setActiveContext(null) - waiter.notifyStop() - if (shutdownHookRef != null) { - ShutdownHookManager.removeShutdownHook(shutdownHookRef) - } - logInfo("StreamingContext stopped successfully") + def stop(stopSparkContext: Boolean, stopGracefully: Boolean): Unit = { + var shutdownHookRefToRemove: AnyRef = null + synchronized { + try { + state match { + case INITIALIZED => + logWarning("StreamingContext has not been started yet") + case STOPPED => + logWarning("StreamingContext has already been stopped") + case ACTIVE => + scheduler.stop(stopGracefully) + // Removing the streamingSource to de-register the metrics on stop() + env.metricsSystem.removeSource(streamingSource) + uiTab.foreach(_.detach()) + StreamingContext.setActiveContext(null) + waiter.notifyStop() + if (shutdownHookRef != null) { + shutdownHookRefToRemove = shutdownHookRef + shutdownHookRef = null + } + logInfo("StreamingContext stopped successfully") + } + } finally { + // The state should always be Stopped after calling `stop()`, even if we haven't started yet + state = STOPPED } - // Even if we have already stopped, we still need to attempt to stop the SparkContext because - // a user might stop(stopSparkContext = false) and then call stop(stopSparkContext = true). - if (stopSparkContext) sc.stop() - } finally { - // The state should always be Stopped after calling `stop()`, even if we haven't started yet - state = STOPPED } + if (shutdownHookRefToRemove != null) { + ShutdownHookManager.removeShutdownHook(shutdownHookRefToRemove) + } + // Even if we have already stopped, we still need to attempt to stop the SparkContext because + // a user might stop(stopSparkContext = false) and then call stop(stopSparkContext = true). + if (stopSparkContext) sc.stop() } private def stopOnShutdown(): Unit = { From ac09a3a465f3b57f3964c5fd621ab0d2216e2354 Mon Sep 17 00:00:00 2001 From: gweidner Date: Fri, 16 Oct 2015 14:02:12 -0700 Subject: [PATCH 080/139] [SPARK-11109] [CORE] Move FsHistoryProvider off deprecated AccessControlException Switched from deprecated org.apache.hadoop.fs.permission.AccessControlException to org.apache.hadoop.security.AccessControlException. Author: gweidner Closes #9144 from gweidner/SPARK-11109. --- .../org/apache/spark/deploy/history/FsHistoryProvider.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 5eb8adf97d90b..80bfda9dddb39 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -27,7 +27,7 @@ import scala.collection.mutable import com.google.common.io.ByteStreams import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} -import org.apache.hadoop.fs.permission.AccessControlException +import org.apache.hadoop.security.AccessControlException import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.SparkHadoopUtil From 1ec0a0dc2819d3db3555799cb78c2946f652bff4 Mon Sep 17 00:00:00 2001 From: Bhargav Mangipudi Date: Fri, 16 Oct 2015 14:36:05 -0700 Subject: [PATCH 081/139] =?UTF-8?q?[SPARK-11050]=20[MLLIB]=20PySpark=20Spa?= =?UTF-8?q?rseVector=20can=20return=20wrong=20index=20in=20e=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …rror message For negative indices in the SparseVector, we update the index value. If we have an incorrect index at this point, the error message has the incorrect *updated* index instead of the original one. This change contains the fix for the same. Author: Bhargav Mangipudi Closes #9069 from bhargav/spark-10759. --- python/pyspark/mllib/linalg/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index d903b9030d8ce..5276eb41cf29e 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -764,10 +764,11 @@ def __getitem__(self, index): if not isinstance(index, int): raise TypeError( "Indices must be of type integer, got type %s" % type(index)) + + if index >= self.size or index < -self.size: + raise ValueError("Index %d out of bounds." % index) if index < 0: index += self.size - if index >= self.size or index < 0: - raise ValueError("Index %d out of bounds." % index) insert_index = np.searchsorted(inds, index) if insert_index >= inds.size: From 10046ea76cf8f0d08fe7ef548e4dbec69d9c73b8 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Fri, 16 Oct 2015 15:30:07 -0700 Subject: [PATCH 082/139] [SPARK-10599] [MLLIB] Lower communication for block matrix multiplication This PR aims to decrease communication costs in BlockMatrix multiplication in two ways: - Simulate the multiplication on the driver, and figure out which blocks actually need to be shuffled - Send the block once to a partition, and join inside the partition rather than sending multiple copies to the same partition **NOTE**: One important note is that right now, the old behavior of checking for multiple blocks with the same index is lost. This is not hard to add, but is a little more expensive than how it was. Initial benchmarking showed promising results (look below), however I did hit some `FileNotFound` exceptions with the new implementation after the shuffle. Size A: 1e5 x 1e5 Size B: 1e5 x 1e5 Block Sizes: 1024 x 1024 Sparsity: 0.01 Old implementation: 1m 13s New implementation: 9s cc avulanov Would you be interested in helping me benchmark this? I used your code from the mailing list (which you sent about 3 months ago?), and the old implementation didn't even run, but the new implementation completed in 268s in a 120 GB / 16 core cluster Author: Burak Yavuz Closes #8757 from brkyvz/opt-bmm. --- .../linalg/distributed/BlockMatrix.scala | 80 ++++++++++++++----- .../linalg/distributed/BlockMatrixSuite.scala | 18 +++++ 2 files changed, 76 insertions(+), 22 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala index a33b6137cf9cc..81a6c0550bda7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala @@ -54,12 +54,14 @@ private[mllib] class GridPartitioner( /** * Returns the index of the partition the input coordinate belongs to. * - * @param key The coordinate (i, j) or a tuple (i, j, k), where k is the inner index used in - * multiplication. k is ignored in computing partitions. + * @param key The partition id i (calculated through this method for coordinate (i, j) in + * `simulateMultiply`, the coordinate (i, j) or a tuple (i, j, k), where k is + * the inner index used in multiplication. k is ignored in computing partitions. * @return The index of the partition, which the coordinate belongs to. */ override def getPartition(key: Any): Int = { key match { + case i: Int => i case (i: Int, j: Int) => getPartitionId(i, j) case (i: Int, j: Int, _: Int) => @@ -352,12 +354,49 @@ class BlockMatrix @Since("1.3.0") ( } } + /** Block (i,j) --> Set of destination partitions */ + private type BlockDestinations = Map[(Int, Int), Set[Int]] + + /** + * Simulate the multiplication with just block indices in order to cut costs on communication, + * when we are actually shuffling the matrices. + * The `colsPerBlock` of this matrix must equal the `rowsPerBlock` of `other`. + * Exposed for tests. + * + * @param other The BlockMatrix to multiply + * @param partitioner The partitioner that will be used for the resulting matrix `C = A * B` + * @return A tuple of [[BlockDestinations]]. The first element is the Map of the set of partitions + * that we need to shuffle each blocks of `this`, and the second element is the Map for + * `other`. + */ + private[distributed] def simulateMultiply( + other: BlockMatrix, + partitioner: GridPartitioner): (BlockDestinations, BlockDestinations) = { + val leftMatrix = blockInfo.keys.collect() // blockInfo should already be cached + val rightMatrix = other.blocks.keys.collect() + val leftDestinations = leftMatrix.map { case (rowIndex, colIndex) => + val rightCounterparts = rightMatrix.filter(_._1 == colIndex) + val partitions = rightCounterparts.map(b => partitioner.getPartition((rowIndex, b._2))) + ((rowIndex, colIndex), partitions.toSet) + }.toMap + val rightDestinations = rightMatrix.map { case (rowIndex, colIndex) => + val leftCounterparts = leftMatrix.filter(_._2 == rowIndex) + val partitions = leftCounterparts.map(b => partitioner.getPartition((b._1, colIndex))) + ((rowIndex, colIndex), partitions.toSet) + }.toMap + (leftDestinations, rightDestinations) + } + /** * Left multiplies this [[BlockMatrix]] to `other`, another [[BlockMatrix]]. The `colsPerBlock` * of this matrix must equal the `rowsPerBlock` of `other`. If `other` contains * [[SparseMatrix]], they will have to be converted to a [[DenseMatrix]]. The output * [[BlockMatrix]] will only consist of blocks of [[DenseMatrix]]. This may cause * some performance issues until support for multiplying two sparse matrices is added. + * + * Note: The behavior of multiply has changed in 1.6.0. `multiply` used to throw an error when + * there were blocks with duplicate indices. Now, the blocks with duplicate indices will be added + * with each other. */ @Since("1.3.0") def multiply(other: BlockMatrix): BlockMatrix = { @@ -368,33 +407,30 @@ class BlockMatrix @Since("1.3.0") ( if (colsPerBlock == other.rowsPerBlock) { val resultPartitioner = GridPartitioner(numRowBlocks, other.numColBlocks, math.max(blocks.partitions.length, other.blocks.partitions.length)) - // Each block of A must be multiplied with the corresponding blocks in each column of B. - // TODO: Optimize to send block to a partition once, similar to ALS + val (leftDestinations, rightDestinations) = simulateMultiply(other, resultPartitioner) + // Each block of A must be multiplied with the corresponding blocks in the columns of B. val flatA = blocks.flatMap { case ((blockRowIndex, blockColIndex), block) => - Iterator.tabulate(other.numColBlocks)(j => ((blockRowIndex, j, blockColIndex), block)) + val destinations = leftDestinations.getOrElse((blockRowIndex, blockColIndex), Set.empty) + destinations.map(j => (j, (blockRowIndex, blockColIndex, block))) } // Each block of B must be multiplied with the corresponding blocks in each row of A. val flatB = other.blocks.flatMap { case ((blockRowIndex, blockColIndex), block) => - Iterator.tabulate(numRowBlocks)(i => ((i, blockColIndex, blockRowIndex), block)) + val destinations = rightDestinations.getOrElse((blockRowIndex, blockColIndex), Set.empty) + destinations.map(j => (j, (blockRowIndex, blockColIndex, block))) } - val newBlocks: RDD[MatrixBlock] = flatA.cogroup(flatB, resultPartitioner) - .flatMap { case ((blockRowIndex, blockColIndex, _), (a, b)) => - if (a.size > 1 || b.size > 1) { - throw new SparkException("There are multiple MatrixBlocks with indices: " + - s"($blockRowIndex, $blockColIndex). Please remove them.") - } - if (a.nonEmpty && b.nonEmpty) { - val C = b.head match { - case dense: DenseMatrix => a.head.multiply(dense) - case sparse: SparseMatrix => a.head.multiply(sparse.toDense) - case _ => throw new SparkException(s"Unrecognized matrix type ${b.head.getClass}.") + val newBlocks = flatA.cogroup(flatB, resultPartitioner).flatMap { case (pId, (a, b)) => + a.flatMap { case (leftRowIndex, leftColIndex, leftBlock) => + b.filter(_._1 == leftColIndex).map { case (rightRowIndex, rightColIndex, rightBlock) => + val C = rightBlock match { + case dense: DenseMatrix => leftBlock.multiply(dense) + case sparse: SparseMatrix => leftBlock.multiply(sparse.toDense) + case _ => + throw new SparkException(s"Unrecognized matrix type ${rightBlock.getClass}.") } - Iterator(((blockRowIndex, blockColIndex), C.toBreeze)) - } else { - Iterator() + ((leftRowIndex, rightColIndex), C.toBreeze) } - }.reduceByKey(resultPartitioner, (a, b) => a + b) - .mapValues(Matrices.fromBreeze) + } + }.reduceByKey(resultPartitioner, (a, b) => a + b).mapValues(Matrices.fromBreeze) // TODO: Try to use aggregateByKey instead of reduceByKey to get rid of intermediate matrices new BlockMatrix(newBlocks, rowsPerBlock, other.colsPerBlock, numRows(), other.numCols()) } else { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala index 93fe04c139b9a..b8eb10305801c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala @@ -235,6 +235,24 @@ class BlockMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { assert(localC ~== result absTol 1e-8) } + test("simulate multiply") { + val blocks: Seq[((Int, Int), Matrix)] = Seq( + ((0, 0), new DenseMatrix(2, 2, Array(1.0, 0.0, 0.0, 1.0))), + ((1, 1), new DenseMatrix(2, 2, Array(1.0, 0.0, 0.0, 1.0)))) + val rdd = sc.parallelize(blocks, 2) + val B = new BlockMatrix(rdd, colPerPart, rowPerPart) + val resultPartitioner = GridPartitioner(gridBasedMat.numRowBlocks, B.numColBlocks, + math.max(numPartitions, 2)) + val (destinationsA, destinationsB) = gridBasedMat.simulateMultiply(B, resultPartitioner) + assert(destinationsA((0, 0)) === Set(0)) + assert(destinationsA((0, 1)) === Set(2)) + assert(destinationsA((1, 0)) === Set(0)) + assert(destinationsA((1, 1)) === Set(2)) + assert(destinationsA((2, 1)) === Set(3)) + assert(destinationsB((0, 0)) === Set(0)) + assert(destinationsB((1, 1)) === Set(2, 3)) + } + test("validate") { // No error gridBasedMat.validate() From 8ac71d62d976bbfd0159cac6816dd8fa580ae1cb Mon Sep 17 00:00:00 2001 From: zero323 Date: Fri, 16 Oct 2015 15:53:26 -0700 Subject: [PATCH 083/139] [SPARK-11084] [ML] [PYTHON] Check if index can contain non-zero value before binary search At this moment `SparseVector.__getitem__` executes `np.searchsorted` first and checks if result is in an expected range after that. It is possible to check if index can contain non-zero value before executing `np.searchsorted`. Author: zero323 Closes #9098 from zero323/sparse_vector_getitem_improved. --- python/pyspark/mllib/linalg/__init__.py | 4 ++-- python/pyspark/mllib/tests.py | 10 ++++++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index 5276eb41cf29e..ae9ce58450905 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -770,10 +770,10 @@ def __getitem__(self, index): if index < 0: index += self.size - insert_index = np.searchsorted(inds, index) - if insert_index >= inds.size: + if (inds.size == 0) or (index > inds.item(-1)): return 0. + insert_index = np.searchsorted(inds, index) row_ind = inds[insert_index] if row_ind == index: return vals[insert_index] diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 2a6a5cd3fe40e..2ad69a0ab1d3d 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -252,6 +252,16 @@ def test_sparse_vector_indexing(self): for ind in [7.8, '1']: self.assertRaises(TypeError, sv.__getitem__, ind) + zeros = SparseVector(4, {}) + self.assertEqual(zeros[0], 0.0) + self.assertEqual(zeros[3], 0.0) + for ind in [4, -5]: + self.assertRaises(ValueError, zeros.__getitem__, ind) + + empty = SparseVector(0, {}) + for ind in [-1, 0, 1]: + self.assertRaises(ValueError, empty.__getitem__, ind) + def test_matrix_indexing(self): mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) expected = [[0, 6], [1, 8], [4, 10]] From e1e77b22b3b577909a12c3aa898eb53be02267fd Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Sat, 17 Oct 2015 10:04:19 -0700 Subject: [PATCH 084/139] [SPARK-11029] [ML] Add computeCost to KMeansModel in spark.ml jira: https://issues.apache.org/jira/browse/SPARK-11029 We should add a method analogous to spark.mllib.clustering.KMeansModel.computeCost to spark.ml.clustering.KMeansModel. This will be a temp fix until we have proper evaluators defined for clustering. Author: Yuhao Yang Author: yuhaoyang Closes #9073 from hhbyyh/computeCost. --- .../org/apache/spark/ml/clustering/KMeans.scala | 12 ++++++++++++ .../org/apache/spark/ml/clustering/KMeansSuite.scala | 1 + 2 files changed, 13 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index f40ab71fb22a6..509be63002396 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -117,6 +117,18 @@ class KMeansModel private[ml] ( @Since("1.5.0") def clusterCenters: Array[Vector] = parentModel.clusterCenters + + /** + * Return the K-means cost (sum of squared distances of points to their nearest center) for this + * model on the given data. + */ + // TODO: Replace the temp fix when we have proper evaluators defined for clustering. + @Since("1.6.0") + def computeCost(dataset: DataFrame): Double = { + SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT) + val data = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point } + parentModel.computeCost(data) + } } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 688b0e31f91dc..c05f90550d161 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -104,5 +104,6 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { val clusters = transformed.select(predictionColName).map(_.getInt(0)).distinct().collect().toSet assert(clusters.size === k) assert(clusters === Set(0, 1, 2, 3, 4)) + assert(model.computeCost(dataset) < 0.1) } } From cca2258685147be6c950c9f5c4e50eaa1e090714 Mon Sep 17 00:00:00 2001 From: Luvsandondov Lkhamsuren Date: Sat, 17 Oct 2015 10:07:42 -0700 Subject: [PATCH 085/139] [SPARK-9963] [ML] RandomForest cleanup: replace predictNodeIndex with predictImpl predictNodeIndex is moved to LearningNode and renamed predictImpl for consistency with Node.predictImpl Author: Luvsandondov Lkhamsuren Closes #8609 from lkhamsurenl/SPARK-9963. --- .../scala/org/apache/spark/ml/tree/Node.scala | 37 ++++++++++++++++ .../spark/ml/tree/impl/RandomForest.scala | 44 +------------------ 2 files changed, 38 insertions(+), 43 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala index cd24931293903..d89682611e3f5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -279,6 +279,43 @@ private[tree] class LearningNode( } } + /** + * Get the node index corresponding to this data point. + * This function mimics prediction, passing an example from the root node down to a leaf + * or unsplit node; that node's index is returned. + * + * @param binnedFeatures Binned feature vector for data point. + * @param splits possible splits for all features, indexed (numFeatures)(numSplits) + * @return Leaf index if the data point reaches a leaf. + * Otherwise, last node reachable in tree matching this example. + * Note: This is the global node index, i.e., the index used in the tree. + * This index is different from the index used during training a particular + * group of nodes on one call to [[findBestSplits()]]. + */ + def predictImpl(binnedFeatures: Array[Int], splits: Array[Array[Split]]): Int = { + if (this.isLeaf || this.split.isEmpty) { + this.id + } else { + val split = this.split.get + val featureIndex = split.featureIndex + val splitLeft = split.shouldGoLeft(binnedFeatures(featureIndex), splits(featureIndex)) + if (this.leftChild.isEmpty) { + // Not yet split. Return next layer of nodes to train + if (splitLeft) { + LearningNode.leftChildIndex(this.id) + } else { + LearningNode.rightChildIndex(this.id) + } + } else { + if (splitLeft) { + this.leftChild.get.predictImpl(binnedFeatures, splits) + } else { + this.rightChild.get.predictImpl(binnedFeatures, splits) + } + } + } + } + } private[tree] object LearningNode { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index c494556085e95..96d5652857e08 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -205,47 +205,6 @@ private[ml] object RandomForest extends Logging { } } - /** - * Get the node index corresponding to this data point. - * This function mimics prediction, passing an example from the root node down to a leaf - * or unsplit node; that node's index is returned. - * - * @param node Node in tree from which to classify the given data point. - * @param binnedFeatures Binned feature vector for data point. - * @param splits possible splits for all features, indexed (numFeatures)(numSplits) - * @return Leaf index if the data point reaches a leaf. - * Otherwise, last node reachable in tree matching this example. - * Note: This is the global node index, i.e., the index used in the tree. - * This index is different from the index used during training a particular - * group of nodes on one call to [[findBestSplits()]]. - */ - private def predictNodeIndex( - node: LearningNode, - binnedFeatures: Array[Int], - splits: Array[Array[Split]]): Int = { - if (node.isLeaf || node.split.isEmpty) { - node.id - } else { - val split = node.split.get - val featureIndex = split.featureIndex - val splitLeft = split.shouldGoLeft(binnedFeatures(featureIndex), splits(featureIndex)) - if (node.leftChild.isEmpty) { - // Not yet split. Return index from next layer of nodes to train - if (splitLeft) { - LearningNode.leftChildIndex(node.id) - } else { - LearningNode.rightChildIndex(node.id) - } - } else { - if (splitLeft) { - predictNodeIndex(node.leftChild.get, binnedFeatures, splits) - } else { - predictNodeIndex(node.rightChild.get, binnedFeatures, splits) - } - } - } - } - /** * Helper for binSeqOp, for data which can contain a mix of ordered and unordered features. * @@ -453,8 +412,7 @@ private[ml] object RandomForest extends Logging { agg: Array[DTStatsAggregator], baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = { treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) => - val nodeIndex = - predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures, splits) + val nodeIndex = topNodes(treeIndex).predictImpl(baggedPoint.datum.binnedFeatures, splits) nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint) } agg From 254937420678a299f06b6f4e2696c623da56cf3a Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 17 Oct 2015 12:41:42 -0700 Subject: [PATCH 086/139] [SPARK-11165] Logging trait should be private - not DeveloperApi. Its classdoc actually says; "NOTE: DO NOT USE this class outside of Spark. It is intended as an internal utility." Author: Reynold Xin Closes #9155 from rxin/private-logging-trait. --- core/src/main/scala/org/apache/spark/Logging.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala index f0598816d6c07..69f6e06ee0057 100644 --- a/core/src/main/scala/org/apache/spark/Logging.scala +++ b/core/src/main/scala/org/apache/spark/Logging.scala @@ -21,11 +21,10 @@ import org.apache.log4j.{LogManager, PropertyConfigurator} import org.slf4j.{Logger, LoggerFactory} import org.slf4j.impl.StaticLoggerBinder -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.Private import org.apache.spark.util.Utils /** - * :: DeveloperApi :: * Utility trait for classes that want to log data. Creates a SLF4J logger for the class and allows * logging messages at different levels using methods that only evaluate parameters lazily if the * log level is enabled. @@ -33,7 +32,7 @@ import org.apache.spark.util.Utils * NOTE: DO NOT USE this class outside of Spark. It is intended as an internal utility. * This will likely be changed or removed in future releases. */ -@DeveloperApi +@Private trait Logging { // Make the log field transient so that objects with Logging can // be serialized and used on another machine From 57f83e36d63bbd79663c49a6c1e8f6c3c8fe4789 Mon Sep 17 00:00:00 2001 From: Koert Kuipers Date: Sat, 17 Oct 2015 14:56:24 -0700 Subject: [PATCH 087/139] [SPARK-10185] [SQL] Feat sql comma separated paths Make sure comma-separated paths get processed correcly in ResolvedDataSource for a HadoopFsRelationProvider Author: Koert Kuipers Closes #8416 from koertkuipers/feat-sql-comma-separated-paths. --- python/pyspark/sql/readwriter.py | 14 +++++- python/test_support/sql/people1.json | 2 + .../apache/spark/sql/DataFrameReader.scala | 11 +++++ .../datasources/ResolvedDataSource.scala | 47 +++++++++++++++---- .../org/apache/spark/sql/DataFrameSuite.scala | 18 +++++++ 5 files changed, 81 insertions(+), 11 deletions(-) create mode 100644 python/test_support/sql/people1.json diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index f43d8bf646a9e..93832d4c713e5 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -116,6 +116,10 @@ def load(self, path=None, format=None, schema=None, **options): ... opt2=1, opt3='str') >>> df.dtypes [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] + >>> df = sqlContext.read.format('json').load(['python/test_support/sql/people.json', + ... 'python/test_support/sql/people1.json']) + >>> df.dtypes + [('age', 'bigint'), ('aka', 'string'), ('name', 'string')] """ if format is not None: self.format(format) @@ -123,7 +127,15 @@ def load(self, path=None, format=None, schema=None, **options): self.schema(schema) self.options(**options) if path is not None: - return self._df(self._jreader.load(path)) + if type(path) == list: + paths = path + gateway = self._sqlContext._sc._gateway + jpaths = gateway.new_array(gateway.jvm.java.lang.String, len(paths)) + for i in range(0, len(paths)): + jpaths[i] = paths[i] + return self._df(self._jreader.load(jpaths)) + else: + return self._df(self._jreader.load(path)) else: return self._df(self._jreader.load()) diff --git a/python/test_support/sql/people1.json b/python/test_support/sql/people1.json new file mode 100644 index 0000000000000..6d217da77d155 --- /dev/null +++ b/python/test_support/sql/people1.json @@ -0,0 +1,2 @@ +{"name":"Jonathan", "aka": "John"} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index eacdea2c1e5b3..e8651a3569d6f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -22,6 +22,7 @@ import java.util.Properties import scala.collection.JavaConverters._ import org.apache.hadoop.fs.Path +import org.apache.hadoop.util.StringUtils import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD @@ -123,6 +124,16 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { DataFrame(sqlContext, LogicalRelation(resolved.relation)) } + /** + * Loads input in as a [[DataFrame]], for data sources that support multiple paths. + * Only works if the source is a HadoopFsRelationProvider. + * + * @since 1.6.0 + */ + def load(paths: Array[String]): DataFrame = { + option("paths", paths.map(StringUtils.escapeString(_, '\\', ',')).mkString(",")).load() + } + /** * Construct a [[DataFrame]] representing the database table accessible via JDBC URL * url named table and connection properties. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala index 011724436621d..54beabbf63b5f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala @@ -24,6 +24,7 @@ import scala.language.{existentials, implicitConversions} import scala.util.{Success, Failure, Try} import org.apache.hadoop.fs.Path +import org.apache.hadoop.util.StringUtils import org.apache.spark.Logging import org.apache.spark.deploy.SparkHadoopUtil @@ -89,7 +90,11 @@ object ResolvedDataSource extends Logging { val relation = userSpecifiedSchema match { case Some(schema: StructType) => clazz.newInstance() match { case dataSource: SchemaRelationProvider => - dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options), schema) + val caseInsensitiveOptions = new CaseInsensitiveMap(options) + if (caseInsensitiveOptions.contains("paths")) { + throw new AnalysisException(s"$className does not support paths option.") + } + dataSource.createRelation(sqlContext, caseInsensitiveOptions, schema) case dataSource: HadoopFsRelationProvider => val maybePartitionsSchema = if (partitionColumns.isEmpty) { None @@ -99,10 +104,19 @@ object ResolvedDataSource extends Logging { val caseInsensitiveOptions = new CaseInsensitiveMap(options) val paths = { - val patternPath = new Path(caseInsensitiveOptions("path")) - val fs = patternPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) - val qualifiedPattern = patternPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - SparkHadoopUtil.get.globPathIfNecessary(qualifiedPattern).map(_.toString).toArray + if (caseInsensitiveOptions.contains("paths") && + caseInsensitiveOptions.contains("path")) { + throw new AnalysisException(s"Both path and paths options are present.") + } + caseInsensitiveOptions.get("paths") + .map(_.split("(? + val hdfsPath = new Path(pathString) + val fs = hdfsPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + SparkHadoopUtil.get.globPathIfNecessary(qualified).map(_.toString) + } } val dataSchema = @@ -122,14 +136,27 @@ object ResolvedDataSource extends Logging { case None => clazz.newInstance() match { case dataSource: RelationProvider => - dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options)) + val caseInsensitiveOptions = new CaseInsensitiveMap(options) + if (caseInsensitiveOptions.contains("paths")) { + throw new AnalysisException(s"$className does not support paths option.") + } + dataSource.createRelation(sqlContext, caseInsensitiveOptions) case dataSource: HadoopFsRelationProvider => val caseInsensitiveOptions = new CaseInsensitiveMap(options) val paths = { - val patternPath = new Path(caseInsensitiveOptions("path")) - val fs = patternPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) - val qualifiedPattern = patternPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - SparkHadoopUtil.get.globPathIfNecessary(qualifiedPattern).map(_.toString).toArray + if (caseInsensitiveOptions.contains("paths") && + caseInsensitiveOptions.contains("path")) { + throw new AnalysisException(s"Both path and paths options are present.") + } + caseInsensitiveOptions.get("paths") + .map(_.split("(? + val hdfsPath = new Path(pathString) + val fs = hdfsPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + SparkHadoopUtil.get.globPathIfNecessary(qualified).map(_.toString) + } } dataSource.createRelation(sqlContext, paths, None, None, caseInsensitiveOptions) case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index d919877746c72..832ea02cb6e77 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -890,6 +890,24 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { .collect() } + test("SPARK-10185: Read multiple Hadoop Filesystem paths and paths with a comma in it") { + withTempDir { dir => + val df1 = Seq((1, 22)).toDF("a", "b") + val dir1 = new File(dir, "dir,1").getCanonicalPath + df1.write.format("json").save(dir1) + + val df2 = Seq((2, 23)).toDF("a", "b") + val dir2 = new File(dir, "dir2").getCanonicalPath + df2.write.format("json").save(dir2) + + checkAnswer(sqlContext.read.format("json").load(Array(dir1, dir2)), + Row(1, 22) :: Row(2, 23) :: Nil) + + checkAnswer(sqlContext.read.format("json").load(dir1), + Row(1, 22) :: Nil) + } + } + test("SPARK-10034: Sort on Aggregate with aggregation expression named 'aggOrdering'") { val df = Seq(1 -> 2).toDF("i", "j") val query = df.groupBy('i) From 022a8f6a1f7cb477a65a65482982c021ce08a73c Mon Sep 17 00:00:00 2001 From: ph Date: Sat, 17 Oct 2015 15:37:51 -0700 Subject: [PATCH 088/139] [SPARK-11129] [MESOS] Link Spark WebUI from Mesos WebUI Mesos has a feature for linking to frameworks running on top of Mesos from the Mesos WebUI. This commit enables Spark to make use of this feature so one can directly visit the running Spark WebUIs from the Mesos WebUI. Author: ph Closes #9135 from philipphoffmann/SPARK-11129. --- .../cluster/mesos/CoarseMesosSchedulerBackend.scala | 7 ++++++- .../scheduler/cluster/mesos/MesosSchedulerBackend.scala | 7 ++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 65cb5016cfcc9..d10a77f8e5c78 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -127,7 +127,12 @@ private[spark] class CoarseMesosSchedulerBackend( override def start() { super.start() val driver = createSchedulerDriver( - master, CoarseMesosSchedulerBackend.this, sc.sparkUser, sc.appName, sc.conf) + master, + CoarseMesosSchedulerBackend.this, + sc.sparkUser, + sc.appName, + sc.conf, + sc.ui.map(_.appUIAddress)) startScheduler(driver) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 8edf7007a5daf..6196176c7cc33 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -68,7 +68,12 @@ private[spark] class MesosSchedulerBackend( override def start() { classLoader = Thread.currentThread.getContextClassLoader val driver = createSchedulerDriver( - master, MesosSchedulerBackend.this, sc.sparkUser, sc.appName, sc.conf) + master, + MesosSchedulerBackend.this, + sc.sparkUser, + sc.appName, + sc.conf, + sc.ui.map(_.appUIAddress)) startScheduler(driver) } From e2dfdbb2c0523517880138f214775f9a896f2271 Mon Sep 17 00:00:00 2001 From: huangzhaowei Date: Sat, 17 Oct 2015 16:41:49 -0700 Subject: [PATCH 089/139] [SPARK-11000] [YARN] Load `metadata.Hive` class only when `hive.metastore.uris` was set to avoid bootting the database twice Author: huangzhaowei Closes #9026 from SaintBacchus/SPARK-11000. --- .../main/scala/org/apache/spark/deploy/yarn/Client.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 9fcfe362a3ba2..08aecfa7f6fe0 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -1327,11 +1327,8 @@ object Client extends Logging { val mirror = universe.runtimeMirror(getClass.getClassLoader) try { - val hiveClass = mirror.classLoader.loadClass("org.apache.hadoop.hive.ql.metadata.Hive") - val hive = hiveClass.getMethod("get").invoke(null) - - val hiveConf = hiveClass.getMethod("getConf").invoke(hive) val hiveConfClass = mirror.classLoader.loadClass("org.apache.hadoop.hive.conf.HiveConf") + val hiveConf = hiveConfClass.newInstance() val hiveConfGet = (param: String) => Option(hiveConfClass .getMethod("get", classOf[java.lang.String]) @@ -1341,6 +1338,9 @@ object Client extends Logging { // Check for local metastore if (metastore_uri != None && metastore_uri.get.toString.size > 0) { + val hiveClass = mirror.classLoader.loadClass("org.apache.hadoop.hive.ql.metadata.Hive") + val hive = hiveClass.getMethod("get").invoke(null, hiveConf.asInstanceOf[Object]) + val metastore_kerberos_principal_conf_var = mirror.classLoader .loadClass("org.apache.hadoop.hive.conf.HiveConf$ConfVars") .getField("METASTORE_KERBEROS_PRINCIPAL").get("varname").toString From 3895b2113a726171b3c9c04fe41b3cc93d6d14b5 Mon Sep 17 00:00:00 2001 From: tedyu Date: Sun, 18 Oct 2015 02:12:56 -0700 Subject: [PATCH 090/139] [SPARK-11172] Close JsonParser/Generator in test Author: tedyu Closes #9157 from tedyu/master. --- .../sql/execution/datasources/json/JsonSuite.scala | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index b614e6c4148fd..7540223bf2771 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -47,13 +47,15 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val factory = new JsonFactory() def enforceCorrectType(value: Any, dataType: DataType): Any = { val writer = new StringWriter() - val generator = factory.createGenerator(writer) - generator.writeObject(value) - generator.flush() + Utils.tryWithResource(factory.createGenerator(writer)) { generator => + generator.writeObject(value) + generator.flush() + } - val parser = factory.createParser(writer.toString) - parser.nextToken() - JacksonParser.convertField(factory, parser, dataType) + Utils.tryWithResource(factory.createParser(writer.toString)) { parser => + parser.nextToken() + JacksonParser.convertField(factory, parser, dataType) + } } val intNumber: Int = 2147483647 From a112d69fdcd9f6d8805be6e0bc6d2211e26867c2 Mon Sep 17 00:00:00 2001 From: Lukasz Piepiora Date: Sun, 18 Oct 2015 14:25:57 +0100 Subject: [PATCH 091/139] [SPARK-11174] [DOCS] Fix typo in the GraphX programming guide This patch fixes a small typo in the GraphX programming guide Author: Lukasz Piepiora Closes #9160 from lpiepiora/11174-fix-typo-in-graphx-programming-guide. --- docs/graphx-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index c861a763d6222..6a512ab234bb2 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -944,7 +944,7 @@ The three additional functions exposed by the `EdgeRDD` are: {% highlight scala %} // Transform the edge attributes while preserving the structure def mapValues[ED2](f: Edge[ED] => ED2): EdgeRDD[ED2] -// Revere the edges reusing both attributes and structure +// Reverse the edges reusing both attributes and structure def reverse: EdgeRDD[ED] // Join two `EdgeRDD`s partitioned using the same partitioning strategy. def innerJoin[ED2, ED3](other: EdgeRDD[ED2])(f: (VertexId, VertexId, ED, ED2) => ED3): EdgeRDD[ED3] From 0480d6ca83d170618fa6a817ad64a2872438d47f Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 18 Oct 2015 09:54:38 -0700 Subject: [PATCH 092/139] [SPARK-11169] Remove the extra spaces in merge script Our merge script now turns ``` [SPARK-1234][SPARK-1235][SPARK-1236][SQL] description ``` into ``` [SPARK-1234] [SPARK-1235] [SPARK-1236] [SQL] description ``` The extra spaces are more annoying in git since the first line of a git commit is supposed to be very short. Doctest passes with the following command: ``` python -m doctest merge_spark_pr.py ``` Author: Reynold Xin Closes #9156 from rxin/SPARK-11169. --- dev/merge_spark_pr.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index b9bdec3d70864..bf1a000f46791 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -300,24 +300,24 @@ def resolve_jira_issues(title, merge_branches, comment): def standardize_jira_ref(text): """ Standardize the [SPARK-XXXXX] [MODULE] prefix - Converts "[SPARK-XXX][mllib] Issue", "[MLLib] SPARK-XXX. Issue" or "SPARK XXX [MLLIB]: Issue" to "[SPARK-XXX] [MLLIB] Issue" + Converts "[SPARK-XXX][mllib] Issue", "[MLLib] SPARK-XXX. Issue" or "SPARK XXX [MLLIB]: Issue" to "[SPARK-XXX][MLLIB] Issue" >>> standardize_jira_ref("[SPARK-5821] [SQL] ParquetRelation2 CTAS should check if delete is successful") - '[SPARK-5821] [SQL] ParquetRelation2 CTAS should check if delete is successful' + '[SPARK-5821][SQL] ParquetRelation2 CTAS should check if delete is successful' >>> standardize_jira_ref("[SPARK-4123][Project Infra][WIP]: Show new dependencies added in pull requests") - '[SPARK-4123] [PROJECT INFRA] [WIP] Show new dependencies added in pull requests' + '[SPARK-4123][PROJECT INFRA][WIP] Show new dependencies added in pull requests' >>> standardize_jira_ref("[MLlib] Spark 5954: Top by key") - '[SPARK-5954] [MLLIB] Top by key' + '[SPARK-5954][MLLIB] Top by key' >>> standardize_jira_ref("[SPARK-979] a LRU scheduler for load balancing in TaskSchedulerImpl") '[SPARK-979] a LRU scheduler for load balancing in TaskSchedulerImpl' >>> standardize_jira_ref("SPARK-1094 Support MiMa for reporting binary compatibility accross versions.") '[SPARK-1094] Support MiMa for reporting binary compatibility accross versions.' >>> standardize_jira_ref("[WIP] [SPARK-1146] Vagrant support for Spark") - '[SPARK-1146] [WIP] Vagrant support for Spark' + '[SPARK-1146][WIP] Vagrant support for Spark' >>> standardize_jira_ref("SPARK-1032. If Yarn app fails before registering, app master stays aroun...") '[SPARK-1032] If Yarn app fails before registering, app master stays aroun...' >>> standardize_jira_ref("[SPARK-6250][SPARK-6146][SPARK-5911][SQL] Types are now reserved words in DDL parser.") - '[SPARK-6250] [SPARK-6146] [SPARK-5911] [SQL] Types are now reserved words in DDL parser.' + '[SPARK-6250][SPARK-6146][SPARK-5911][SQL] Types are now reserved words in DDL parser.' >>> standardize_jira_ref("Additional information for users building from source code") 'Additional information for users building from source code' """ @@ -325,7 +325,7 @@ def standardize_jira_ref(text): components = [] # If the string is compliant, no need to process any further - if (re.search(r'^\[SPARK-[0-9]{3,6}\] (\[[A-Z0-9_\s,]+\] )+\S+', text)): + if (re.search(r'^\[SPARK-[0-9]{3,6}\](\[[A-Z0-9_\s,]+\] )+\S+', text)): return text # Extract JIRA ref(s): @@ -348,7 +348,7 @@ def standardize_jira_ref(text): text = pattern.search(text).groups()[0] # Assemble full text (JIRA ref(s), module(s), remaining text) - clean_text = ' '.join(jira_refs).strip() + " " + ' '.join(components).strip() + " " + text.strip() + clean_text = ''.join(jira_refs).strip() + ''.join(components).strip() + " " + text.strip() # Replace multiple spaces with a single space, e.g. if no jira refs and/or components were included clean_text = re.sub(r'\s+', ' ', clean_text.strip()) From 8d4449c7f5d528410306c288a042c4594b81a881 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Sun, 18 Oct 2015 10:36:50 -0700 Subject: [PATCH 093/139] MAINTENANCE: Automated closing of pull requests. This commit exists to close the following pull requests on Github: Closes #8737 (close requested by 'srowen') Closes #5323 (close requested by 'JoshRosen') Closes #6148 (close requested by 'JoshRosen') Closes #7557 (close requested by 'JoshRosen') Closes #7047 (close requested by 'srowen') Closes #8713 (close requested by 'marmbrus') Closes #5834 (close requested by 'srowen') Closes #7467 (close requested by 'tdas') Closes #8943 (close requested by 'xiaowen147') Closes #4434 (close requested by 'JoshRosen') Closes #8949 (close requested by 'srowen') Closes #5368 (close requested by 'JoshRosen') Closes #8186 (close requested by 'marmbrus') Closes #5147 (close requested by 'JoshRosen') From a337c235a12d4ea6a7d6db457acc6b32f1915241 Mon Sep 17 00:00:00 2001 From: Mahmoud Lababidi Date: Sun, 18 Oct 2015 11:39:19 -0700 Subject: [PATCH 094/139] [SPARK-11158][SQL] Modified _verify_type() to be more informative on Errors by presenting the Object The _verify_type() function had Errors that were raised when there were Type conversion issues but left out the Object in question. The Object is now added in the Error to reduce the strain on the user to debug through to figure out the Object that failed the Type conversion. The use case for me was a Pandas DataFrame that contained 'nan' as values for columns of Strings. Author: Mahmoud Lababidi Author: Mahmoud Lababidi Closes #9149 from lababidi/master. --- python/pyspark/sql/types.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 1f86894855cbe..5bc0773fa8660 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1127,15 +1127,15 @@ def _verify_type(obj, dataType): return _type = type(dataType) - assert _type in _acceptable_types, "unknown datatype: %s" % dataType + assert _type in _acceptable_types, "unknown datatype: %s for object %r" % (dataType, obj) if _type is StructType: if not isinstance(obj, (tuple, list)): - raise TypeError("StructType can not accept object in type %s" % type(obj)) + raise TypeError("StructType can not accept object %r in type %s" % (obj, type(obj))) else: # subclass of them can not be fromInternald in JVM if type(obj) not in _acceptable_types[_type]: - raise TypeError("%s can not accept object in type %s" % (dataType, type(obj))) + raise TypeError("%s can not accept object %r in type %s" % (dataType, obj, type(obj))) if isinstance(dataType, ArrayType): for i in obj: From 94c8fef296e5cdac9a93ed34acc079e51839caa7 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Sun, 18 Oct 2015 13:51:45 -0700 Subject: [PATCH 095/139] [SPARK-11126][SQL] Fix a memory leak in SQLListener._stageIdToStageMetrics SQLListener adds all stage infos to `_stageIdToStageMetrics`, but only removes stage infos belonging to SQL executions. This PR fixed it by ignoring stages that don't belong to SQL executions. Reported by Terry Hoo in https://www.mail-archive.com/userspark.apache.org/msg38810.html Author: zsxwing Closes #9132 from zsxwing/SPARK-11126. --- .../spark/sql/execution/ui/SQLListener.scala | 8 +++++++- .../sql/execution/ui/SQLListenerSuite.scala | 18 ++++++++++++++++-- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index b302b519998ac..5a072de400b6a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -126,7 +126,13 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi val stageId = stageSubmitted.stageInfo.stageId val stageAttemptId = stageSubmitted.stageInfo.attemptId // Always override metrics for old stage attempt - _stageIdToStageMetrics(stageId) = new SQLStageMetrics(stageAttemptId) + if (_stageIdToStageMetrics.contains(stageId)) { + _stageIdToStageMetrics(stageId) = new SQLStageMetrics(stageAttemptId) + } else { + // If a stage belongs to some SQL execution, its stageId will be put in "onJobStart". + // Since "_stageIdToStageMetrics" doesn't contain it, it must not belong to any SQL execution. + // So we can ignore it. Otherwise, this may lead to memory leaks (SPARK-11126). + } } override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index cc1c1e10e98c4..03bcee94a2b91 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -313,7 +313,22 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { assert(executionUIData.failedJobs === Seq(0)) } - ignore("no memory leak") { + test("SPARK-11126: no memory leak when running non SQL jobs") { + val previousStageNumber = sqlContext.listener.stageIdToStageMetrics.size + sqlContext.sparkContext.parallelize(1 to 10).foreach(i => ()) + // listener should ignore the non SQL stage + assert(sqlContext.listener.stageIdToStageMetrics.size == previousStageNumber) + + sqlContext.sparkContext.parallelize(1 to 10).toDF().foreach(i => ()) + // listener should save the SQL stage + assert(sqlContext.listener.stageIdToStageMetrics.size == previousStageNumber + 1) + } + +} + +class SQLListenerMemoryLeakSuite extends SparkFunSuite { + + test("no memory leak") { val conf = new SparkConf() .setMaster("local") .setAppName("test") @@ -348,5 +363,4 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { sc.stop() } } - } From d3180c25d8cf0899a7238e7d24b35c5ae918cc1d Mon Sep 17 00:00:00 2001 From: Brennon York Date: Sun, 18 Oct 2015 22:45:14 -0700 Subject: [PATCH 096/139] [SPARK-7018][BUILD] Refactor dev/run-tests-jenkins into Python This commit refactors the `run-tests-jenkins` script into Python. This refactoring was done by brennonyork in #7401; this PR contains a few minor edits from joshrosen in order to bring it up to date with other recent changes. From the original PR description (by brennonyork): Currently a few things are left out that, could and I think should, be smaller JIRA's after this. 1. There are still a few areas where we use environment variables where we don't need to (like `CURRENT_BLOCK`). I might get around to fixing this one in lieu of everything else, but wanted to point that out. 2. The PR tests are still written in bash. I opted to not change those and just rewrite the runner into Python. This is a great follow-on JIRA IMO. 3. All of the linting scripts are still in bash as well and would likely do to just add those in as follow-on JIRA's as well. Closes #7401. Author: Brennon York Closes #9161 from JoshRosen/run-tests-jenkins-refactoring. --- dev/lint-python | 2 +- dev/run-tests-codes.sh | 30 ---- dev/run-tests-jenkins | 204 +------------------------- dev/run-tests-jenkins.py | 228 +++++++++++++++++++++++++++++ dev/run-tests.py | 20 +-- dev/sparktestsupport/__init__.py | 14 ++ dev/sparktestsupport/shellutils.py | 37 ++++- python/run-tests.py | 19 +-- 8 files changed, 285 insertions(+), 269 deletions(-) delete mode 100644 dev/run-tests-codes.sh create mode 100755 dev/run-tests-jenkins.py diff --git a/dev/lint-python b/dev/lint-python index 575dbb0ae321b..0b97213ae3dff 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -20,7 +20,7 @@ SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" SPARK_ROOT_DIR="$(dirname "$SCRIPT_DIR")" PATHS_TO_CHECK="./python/pyspark/ ./ec2/spark_ec2.py ./examples/src/main/python/ ./dev/sparktestsupport" -PATHS_TO_CHECK="$PATHS_TO_CHECK ./dev/run-tests.py ./python/run-tests.py" +PATHS_TO_CHECK="$PATHS_TO_CHECK ./dev/run-tests.py ./python/run-tests.py ./dev/run-tests-jenkins.py" PEP8_REPORT_PATH="$SPARK_ROOT_DIR/dev/pep8-report.txt" PYLINT_REPORT_PATH="$SPARK_ROOT_DIR/dev/pylint-report.txt" PYLINT_INSTALL_INFO="$SPARK_ROOT_DIR/dev/pylint-info.txt" diff --git a/dev/run-tests-codes.sh b/dev/run-tests-codes.sh deleted file mode 100644 index 1f16790522e76..0000000000000 --- a/dev/run-tests-codes.sh +++ /dev/null @@ -1,30 +0,0 @@ -#!/usr/bin/env bash - -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -readonly BLOCK_GENERAL=10 -readonly BLOCK_RAT=11 -readonly BLOCK_SCALA_STYLE=12 -readonly BLOCK_PYTHON_STYLE=13 -readonly BLOCK_R_STYLE=14 -readonly BLOCK_DOCUMENTATION=15 -readonly BLOCK_BUILD=16 -readonly BLOCK_MIMA=17 -readonly BLOCK_SPARK_UNIT_TESTS=18 -readonly BLOCK_PYSPARK_UNIT_TESTS=19 -readonly BLOCK_SPARKR_UNIT_TESTS=20 diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index d3b05fa6df0ce..e79accf9e987a 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -22,207 +22,7 @@ # Environment variables are populated by the code here: #+ https://github.com/jenkinsci/ghprb-plugin/blob/master/src/main/java/org/jenkinsci/plugins/ghprb/GhprbTrigger.java#L139 -# Go to the Spark project root directory -FWDIR="$(cd `dirname $0`/..; pwd)" +FWDIR="$(cd "`dirname $0`"/..; pwd)" cd "$FWDIR" -source "$FWDIR/dev/run-tests-codes.sh" - -COMMENTS_URL="https://api.github.com/repos/apache/spark/issues/$ghprbPullId/comments" -PULL_REQUEST_URL="https://github.com/apache/spark/pull/$ghprbPullId" - -# Important Environment Variables -# --- -# $ghprbActualCommit -#+ This is the hash of the most recent commit in the PR. -#+ The merge-base of this and master is the commit from which the PR was branched. -# $sha1 -#+ If the patch merges cleanly, this is a reference to the merge commit hash -#+ (e.g. "origin/pr/2606/merge"). -#+ If the patch does not merge cleanly, it is equal to $ghprbActualCommit. -#+ The merge-base of this and master in the case of a clean merge is the most recent commit -#+ against master. - -COMMIT_URL="https://github.com/apache/spark/commit/${ghprbActualCommit}" -# GitHub doesn't auto-link short hashes when submitted via the API, unfortunately. :( -SHORT_COMMIT_HASH="${ghprbActualCommit:0:7}" - -# format: http://linux.die.net/man/1/timeout -# must be less than the timeout configured on Jenkins (currently 300m) -TESTS_TIMEOUT="250m" - -# Array to capture all tests to run on the pull request. These tests are held under the -#+ dev/tests/ directory. -# -# To write a PR test: -#+ * the file must reside within the dev/tests directory -#+ * be an executable bash script -#+ * accept three arguments on the command line, the first being the Github PR long commit -#+ hash, the second the Github SHA1 hash, and the final the current PR hash -#+ * and, lastly, return string output to be included in the pr message output that will -#+ be posted to Github -PR_TESTS=( - "pr_merge_ability" - "pr_public_classes" -# DISABLED (pwendell) "pr_new_dependencies" -) - -function post_message () { - local message=$1 - local data="{\"body\": \"$message\"}" - local HTTP_CODE_HEADER="HTTP Response Code: " - - echo "Attempting to post to Github..." - - local curl_output=$( - curl `#--dump-header -` \ - --silent \ - --user x-oauth-basic:$GITHUB_OAUTH_KEY \ - --request POST \ - --data "$data" \ - --write-out "${HTTP_CODE_HEADER}%{http_code}\n" \ - --header "Content-Type: application/json" \ - "$COMMENTS_URL" #> /dev/null #| "$FWDIR/dev/jq" .id #| head -n 8 - ) - local curl_status=${PIPESTATUS[0]} - - if [ "$curl_status" -ne 0 ]; then - echo "Failed to post message to GitHub." >&2 - echo " > curl_status: ${curl_status}" >&2 - echo " > curl_output: ${curl_output}" >&2 - echo " > data: ${data}" >&2 - # exit $curl_status - fi - - local api_response=$( - echo "${curl_output}" \ - | grep -v -e "^${HTTP_CODE_HEADER}" - ) - - local http_code=$( - echo "${curl_output}" \ - | grep -e "^${HTTP_CODE_HEADER}" \ - | sed -r -e "s/^${HTTP_CODE_HEADER}//g" - ) - - if [ -n "$http_code" ] && [ "$http_code" -ne "201" ]; then - echo " > http_code: ${http_code}." >&2 - echo " > api_response: ${api_response}" >&2 - echo " > data: ${data}" >&2 - fi - - if [ "$curl_status" -eq 0 ] && [ "$http_code" -eq "201" ]; then - echo " > Post successful." - fi -} - -# post start message -{ - start_message="\ - [Test build ${BUILD_DISPLAY_NAME} has started](${BUILD_URL}consoleFull) for \ - PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL})." - - post_message "$start_message" -} - -# Environment variable to capture PR test output -pr_message="" -# Ensure we save off the current HEAD to revert to -current_pr_head="`git rev-parse HEAD`" - -echo "HEAD: `git rev-parse HEAD`" -echo "\$ghprbActualCommit: $ghprbActualCommit" -echo "\$sha1: $sha1" -echo "\$ghprbPullTitle: $ghprbPullTitle" - -# Run pull request tests -for t in "${PR_TESTS[@]}"; do - this_test="${FWDIR}/dev/tests/${t}.sh" - # Ensure the test can be found and is a file - if [ -f "${this_test}" ]; then - echo "Running test: $t" - this_mssg="$(bash "${this_test}" "${ghprbActualCommit}" "${sha1}" "${current_pr_head}")" - # Check if this is the merge test as we submit that note *before* and *after* - # the tests run - [ "$t" == "pr_merge_ability" ] && merge_note="${this_mssg}" - pr_message="${pr_message}\n${this_mssg}" - # Ensure, after each test, that we're back on the current PR - git checkout -f "${current_pr_head}" &>/dev/null - else - echo "Cannot find test ${this_test}." - fi -done - -# run tests -{ - # Marks this build is a pull request build. - export AMP_JENKINS_PRB=true - if [[ $ghprbPullTitle == *"test-maven"* ]]; then - export AMPLAB_JENKINS_BUILD_TOOL="maven" - fi - if [[ $ghprbPullTitle == *"test-hadoop1.0"* ]]; then - export AMPLAB_JENKINS_BUILD_PROFILE="hadoop1.0" - elif [[ $ghprbPullTitle == *"test-hadoop2.0"* ]]; then - export AMPLAB_JENKINS_BUILD_PROFILE="hadoop2.0" - elif [[ $ghprbPullTitle == *"test-hadoop2.2"* ]]; then - export AMPLAB_JENKINS_BUILD_PROFILE="hadoop2.2" - elif [[ $ghprbPullTitle == *"test-hadoop2.3"* ]]; then - export AMPLAB_JENKINS_BUILD_PROFILE="hadoop2.3" - fi - - timeout "${TESTS_TIMEOUT}" ./dev/run-tests - test_result="$?" - - if [ "$test_result" -eq "124" ]; then - fail_message="**[Test build ${BUILD_DISPLAY_NAME} timed out](${BUILD_URL}console)** \ - for PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL}) \ - after a configured wait of \`${TESTS_TIMEOUT}\`." - - post_message "$fail_message" - exit $test_result - elif [ "$test_result" -eq "0" ]; then - test_result_note=" * This patch **passes all tests**." - else - if [ "$test_result" -eq "$BLOCK_GENERAL" ]; then - failing_test="some tests" - elif [ "$test_result" -eq "$BLOCK_RAT" ]; then - failing_test="RAT tests" - elif [ "$test_result" -eq "$BLOCK_SCALA_STYLE" ]; then - failing_test="Scala style tests" - elif [ "$test_result" -eq "$BLOCK_PYTHON_STYLE" ]; then - failing_test="Python style tests" - elif [ "$test_result" -eq "$BLOCK_R_STYLE" ]; then - failing_test="R style tests" - elif [ "$test_result" -eq "$BLOCK_DOCUMENTATION" ]; then - failing_test="to generate documentation" - elif [ "$test_result" -eq "$BLOCK_BUILD" ]; then - failing_test="to build" - elif [ "$test_result" -eq "$BLOCK_MIMA" ]; then - failing_test="MiMa tests" - elif [ "$test_result" -eq "$BLOCK_SPARK_UNIT_TESTS" ]; then - failing_test="Spark unit tests" - elif [ "$test_result" -eq "$BLOCK_PYSPARK_UNIT_TESTS" ]; then - failing_test="PySpark unit tests" - elif [ "$test_result" -eq "$BLOCK_SPARKR_UNIT_TESTS" ]; then - failing_test="SparkR unit tests" - else - failing_test="some tests" - fi - - test_result_note=" * This patch **fails $failing_test**." - fi -} - -# post end message -{ - result_message="\ - [Test build ${BUILD_DISPLAY_NAME} has finished](${BUILD_URL}console) for \ - PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL})." - - result_message="${result_message}\n${test_result_note}" - result_message="${result_message}${pr_message}" - - post_message "$result_message" -} - -exit $test_result +exec python -u ./dev/run-tests-jenkins.py "$@" diff --git a/dev/run-tests-jenkins.py b/dev/run-tests-jenkins.py new file mode 100755 index 0000000000000..623004310e189 --- /dev/null +++ b/dev/run-tests-jenkins.py @@ -0,0 +1,228 @@ +#!/usr/bin/env python2 + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function +import os +import sys +import json +import urllib2 +import functools +import subprocess + +from sparktestsupport import SPARK_HOME, ERROR_CODES +from sparktestsupport.shellutils import run_cmd + + +def print_err(msg): + """ + Given a set of arguments, will print them to the STDERR stream + """ + print(msg, file=sys.stderr) + + +def post_message_to_github(msg, ghprb_pull_id): + print("Attempting to post to Github...") + + url = "https://api.github.com/repos/apache/spark/issues/" + ghprb_pull_id + "/comments" + github_oauth_key = os.environ["GITHUB_OAUTH_KEY"] + + posted_message = json.dumps({"body": msg}) + request = urllib2.Request(url, + headers={ + "Authorization": "token %s" % github_oauth_key, + "Content-Type": "application/json" + }, + data=posted_message) + try: + response = urllib2.urlopen(request) + + if response.getcode() == 201: + print(" > Post successful.") + except urllib2.HTTPError as http_e: + print_err("Failed to post message to Github.") + print_err(" > http_code: %s" % http_e.code) + print_err(" > api_response: %s" % http_e.read()) + print_err(" > data: %s" % posted_message) + except urllib2.URLError as url_e: + print_err("Failed to post message to Github.") + print_err(" > urllib2_status: %s" % url_e.reason[1]) + print_err(" > data: %s" % posted_message) + + +def pr_message(build_display_name, + build_url, + ghprb_pull_id, + short_commit_hash, + commit_url, + msg, + post_msg=''): + # align the arguments properly for string formatting + str_args = (build_display_name, + msg, + build_url, + ghprb_pull_id, + short_commit_hash, + commit_url, + str(' ' + post_msg + '.') if post_msg else '.') + return '**[Test build %s %s](%sconsoleFull)** for PR %s at commit [`%s`](%s)%s' % str_args + + +def run_pr_checks(pr_tests, ghprb_actual_commit, sha1): + """ + Executes a set of pull request checks to ease development and report issues with various + components such as style, linting, dependencies, compatibilities, etc. + @return a list of messages to post back to Github + """ + # Ensure we save off the current HEAD to revert to + current_pr_head = run_cmd(['git', 'rev-parse', 'HEAD'], return_output=True).strip() + pr_results = list() + + for pr_test in pr_tests: + test_name = pr_test + '.sh' + pr_results.append(run_cmd(['bash', os.path.join(SPARK_HOME, 'dev', 'tests', test_name), + ghprb_actual_commit, sha1], + return_output=True).rstrip()) + # Ensure, after each test, that we're back on the current PR + run_cmd(['git', 'checkout', '-f', current_pr_head]) + return pr_results + + +def run_tests(tests_timeout): + """ + Runs the `dev/run-tests` script and responds with the correct error message + under the various failure scenarios. + @return a tuple containing the test result code and the result note to post to Github + """ + + test_result_code = subprocess.Popen(['timeout', + tests_timeout, + os.path.join(SPARK_HOME, 'dev', 'run-tests')]).wait() + + failure_note_by_errcode = { + 1: 'executing the `dev/run-tests` script', # error to denote run-tests script failures + ERROR_CODES["BLOCK_GENERAL"]: 'some tests', + ERROR_CODES["BLOCK_RAT"]: 'RAT tests', + ERROR_CODES["BLOCK_SCALA_STYLE"]: 'Scala style tests', + ERROR_CODES["BLOCK_PYTHON_STYLE"]: 'Python style tests', + ERROR_CODES["BLOCK_R_STYLE"]: 'R style tests', + ERROR_CODES["BLOCK_DOCUMENTATION"]: 'to generate documentation', + ERROR_CODES["BLOCK_BUILD"]: 'to build', + ERROR_CODES["BLOCK_MIMA"]: 'MiMa tests', + ERROR_CODES["BLOCK_SPARK_UNIT_TESTS"]: 'Spark unit tests', + ERROR_CODES["BLOCK_PYSPARK_UNIT_TESTS"]: 'PySpark unit tests', + ERROR_CODES["BLOCK_SPARKR_UNIT_TESTS"]: 'SparkR unit tests', + ERROR_CODES["BLOCK_TIMEOUT"]: 'from timeout after a configured wait of \`%s\`' % ( + tests_timeout) + } + + if test_result_code == 0: + test_result_note = ' * This patch passes all tests.' + else: + test_result_note = ' * This patch **fails %s**.' % failure_note_by_errcode[test_result_code] + + return [test_result_code, test_result_note] + + +def main(): + # Important Environment Variables + # --- + # $ghprbActualCommit + # This is the hash of the most recent commit in the PR. + # The merge-base of this and master is the commit from which the PR was branched. + # $sha1 + # If the patch merges cleanly, this is a reference to the merge commit hash + # (e.g. "origin/pr/2606/merge"). + # If the patch does not merge cleanly, it is equal to $ghprbActualCommit. + # The merge-base of this and master in the case of a clean merge is the most recent commit + # against master. + ghprb_pull_id = os.environ["ghprbPullId"] + ghprb_actual_commit = os.environ["ghprbActualCommit"] + ghprb_pull_title = os.environ["ghprbPullTitle"] + sha1 = os.environ["sha1"] + + # Marks this build as a pull request build. + os.environ["AMP_JENKINS_PRB"] = "true" + # Switch to a Maven-based build if the PR title contains "test-maven": + if "test-maven" in ghprb_pull_title: + os.environ["AMPLAB_JENKINS_BUILD_TOOL"] = "maven" + # Switch the Hadoop profile based on the PR title: + if "test-hadoop1.0" in ghprb_pull_title: + os.environ["AMPLAB_JENKINS_BUILD_PROFILE"] = "hadoop1.0" + if "test-hadoop2.2" in ghprb_pull_title: + os.environ["AMPLAB_JENKINS_BUILD_PROFILE"] = "hadoop2.0" + if "test-hadoop2.2" in ghprb_pull_title: + os.environ["AMPLAB_JENKINS_BUILD_PROFILE"] = "hadoop2.2" + if "test-hadoop2.3" in ghprb_pull_title: + os.environ["AMPLAB_JENKINS_BUILD_PROFILE"] = "hadoop2.3" + + build_display_name = os.environ["BUILD_DISPLAY_NAME"] + build_url = os.environ["BUILD_URL"] + + commit_url = "https://github.com/apache/spark/commit/" + ghprb_actual_commit + + # GitHub doesn't auto-link short hashes when submitted via the API, unfortunately. :( + short_commit_hash = ghprb_actual_commit[0:7] + + # format: http://linux.die.net/man/1/timeout + # must be less than the timeout configured on Jenkins (currently 300m) + tests_timeout = "250m" + + # Array to capture all test names to run on the pull request. These tests are represented + # by their file equivalents in the dev/tests/ directory. + # + # To write a PR test: + # * the file must reside within the dev/tests directory + # * be an executable bash script + # * accept three arguments on the command line, the first being the Github PR long commit + # hash, the second the Github SHA1 hash, and the final the current PR hash + # * and, lastly, return string output to be included in the pr message output that will + # be posted to Github + pr_tests = [ + "pr_merge_ability", + "pr_public_classes" + # DISABLED (pwendell) "pr_new_dependencies" + ] + + # `bind_message_base` returns a function to generate messages for Github posting + github_message = functools.partial(pr_message, + build_display_name, + build_url, + ghprb_pull_id, + short_commit_hash, + commit_url) + + # post start message + post_message_to_github(github_message('has started'), ghprb_pull_id) + + pr_check_results = run_pr_checks(pr_tests, ghprb_actual_commit, sha1) + + test_result_code, test_result_note = run_tests(tests_timeout) + + # post end message + result_message = github_message('has finished') + result_message += '\n' + test_result_note + '\n' + result_message += '\n'.join(pr_check_results) + + post_message_to_github(result_message, ghprb_pull_id) + + sys.exit(test_result_code) + + +if __name__ == "__main__": + main() diff --git a/dev/run-tests.py b/dev/run-tests.py index d4d6880491bc8..6b4b71073453d 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -27,10 +27,11 @@ import subprocess from collections import namedtuple -from sparktestsupport import SPARK_HOME, USER_HOME +from sparktestsupport import SPARK_HOME, USER_HOME, ERROR_CODES from sparktestsupport.shellutils import exit_from_command_with_retcode, run_cmd, rm_r, which import sparktestsupport.modules as modules + # ------------------------------------------------------------------------------------------------- # Functions for traversing module dependency graph # ------------------------------------------------------------------------------------------------- @@ -130,19 +131,6 @@ def determine_tags_to_exclude(changed_modules): # Functions for working with subprocesses and shell tools # ------------------------------------------------------------------------------------------------- -def get_error_codes(err_code_file): - """Function to retrieve all block numbers from the `run-tests-codes.sh` - file to maintain backwards compatibility with the `run-tests-jenkins` - script""" - - with open(err_code_file, 'r') as f: - err_codes = [e.split()[1].strip().split('=') - for e in f if e.startswith("readonly")] - return dict(err_codes) - - -ERROR_CODES = get_error_codes(os.path.join(SPARK_HOME, "dev/run-tests-codes.sh")) - def determine_java_executable(): """Will return the path of the java executable that will be used by Spark's @@ -191,7 +179,7 @@ def determine_java_version(java_exe): def set_title_and_block(title, err_block): - os.environ["CURRENT_BLOCK"] = ERROR_CODES[err_block] + os.environ["CURRENT_BLOCK"] = str(ERROR_CODES[err_block]) line_str = '=' * 72 print('') @@ -467,7 +455,7 @@ def main(): rm_r(os.path.join(USER_HOME, ".ivy2", "local", "org.apache.spark")) rm_r(os.path.join(USER_HOME, ".ivy2", "cache", "org.apache.spark")) - os.environ["CURRENT_BLOCK"] = ERROR_CODES["BLOCK_GENERAL"] + os.environ["CURRENT_BLOCK"] = str(ERROR_CODES["BLOCK_GENERAL"]) java_exe = determine_java_executable() diff --git a/dev/sparktestsupport/__init__.py b/dev/sparktestsupport/__init__.py index 12696d98fb988..8ab6d9e37ca2f 100644 --- a/dev/sparktestsupport/__init__.py +++ b/dev/sparktestsupport/__init__.py @@ -19,3 +19,17 @@ SPARK_HOME = os.path.abspath(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../")) USER_HOME = os.environ.get("HOME") +ERROR_CODES = { + "BLOCK_GENERAL": 10, + "BLOCK_RAT": 11, + "BLOCK_SCALA_STYLE": 12, + "BLOCK_PYTHON_STYLE": 13, + "BLOCK_R_STYLE": 14, + "BLOCK_DOCUMENTATION": 15, + "BLOCK_BUILD": 16, + "BLOCK_MIMA": 17, + "BLOCK_SPARK_UNIT_TESTS": 18, + "BLOCK_PYSPARK_UNIT_TESTS": 19, + "BLOCK_SPARKR_UNIT_TESTS": 20, + "BLOCK_TIMEOUT": 124 +} diff --git a/dev/sparktestsupport/shellutils.py b/dev/sparktestsupport/shellutils.py index 12bd0bf3a4fe9..d280e797077d1 100644 --- a/dev/sparktestsupport/shellutils.py +++ b/dev/sparktestsupport/shellutils.py @@ -22,6 +22,36 @@ import sys +if sys.version_info >= (2, 7): + subprocess_check_output = subprocess.check_output + subprocess_check_call = subprocess.check_call +else: + # SPARK-8763 + # backported from subprocess module in Python 2.7 + def subprocess_check_output(*popenargs, **kwargs): + if 'stdout' in kwargs: + raise ValueError('stdout argument not allowed, it will be overridden.') + process = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs) + output, unused_err = process.communicate() + retcode = process.poll() + if retcode: + cmd = kwargs.get("args") + if cmd is None: + cmd = popenargs[0] + raise subprocess.CalledProcessError(retcode, cmd, output=output) + return output + + # backported from subprocess module in Python 2.7 + def subprocess_check_call(*popenargs, **kwargs): + retcode = call(*popenargs, **kwargs) + if retcode: + cmd = kwargs.get("args") + if cmd is None: + cmd = popenargs[0] + raise CalledProcessError(retcode, cmd) + return 0 + + def exit_from_command_with_retcode(cmd, retcode): print("[error] running", ' '.join(cmd), "; received return code", retcode) sys.exit(int(os.environ.get("CURRENT_BLOCK", 255))) @@ -39,7 +69,7 @@ def rm_r(path): os.remove(path) -def run_cmd(cmd): +def run_cmd(cmd, return_output=False): """ Given a command as a list of arguments will attempt to execute the command and, on failure, print an error message and exit. @@ -48,7 +78,10 @@ def run_cmd(cmd): if not isinstance(cmd, list): cmd = cmd.split() try: - subprocess.check_call(cmd) + if return_output: + return subprocess_check_output(cmd) + else: + return subprocess_check_call(cmd) except subprocess.CalledProcessError as e: exit_from_command_with_retcode(e.cmd, e.returncode) diff --git a/python/run-tests.py b/python/run-tests.py index 152f5cc98d0fd..f5857f8c62214 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -31,23 +31,6 @@ import Queue else: import queue as Queue -if sys.version_info >= (2, 7): - subprocess_check_output = subprocess.check_output -else: - # SPARK-8763 - # backported from subprocess module in Python 2.7 - def subprocess_check_output(*popenargs, **kwargs): - if 'stdout' in kwargs: - raise ValueError('stdout argument not allowed, it will be overridden.') - process = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs) - output, unused_err = process.communicate() - retcode = process.poll() - if retcode: - cmd = kwargs.get("args") - if cmd is None: - cmd = popenargs[0] - raise subprocess.CalledProcessError(retcode, cmd, output=output) - return output # Append `SPARK_HOME/dev` to the Python path so that we can import the sparktestsupport module @@ -55,7 +38,7 @@ def subprocess_check_output(*popenargs, **kwargs): from sparktestsupport import SPARK_HOME # noqa (suppress pep8 warnings) -from sparktestsupport.shellutils import which # noqa +from sparktestsupport.shellutils import which, subprocess_check_output # noqa from sparktestsupport.modules import all_modules # noqa From beb8bc1ea588b7f9ab7effff707c0f784421364d Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 19 Oct 2015 00:06:51 -0700 Subject: [PATCH 097/139] [SPARK-11126][SQL] Fix the potential flaky test The unit test added in #9132 is flaky. This is a follow up PR to add `listenerBus.waitUntilEmpty` to fix it. Author: zsxwing Closes #9163 from zsxwing/SPARK-11126-follow-up. --- .../org/apache/spark/sql/execution/ui/SQLListenerSuite.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index 03bcee94a2b91..c15aac775096c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -316,10 +316,12 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { test("SPARK-11126: no memory leak when running non SQL jobs") { val previousStageNumber = sqlContext.listener.stageIdToStageMetrics.size sqlContext.sparkContext.parallelize(1 to 10).foreach(i => ()) + sqlContext.sparkContext.listenerBus.waitUntilEmpty(10000) // listener should ignore the non SQL stage assert(sqlContext.listener.stageIdToStageMetrics.size == previousStageNumber) sqlContext.sparkContext.parallelize(1 to 10).toDF().foreach(i => ()) + sqlContext.sparkContext.listenerBus.waitUntilEmpty(10000) // listener should save the SQL stage assert(sqlContext.listener.stageIdToStageMetrics.size == previousStageNumber + 1) } From bd64c2d550c36405f9be25a5c6a8eaa54bf4e7e7 Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Mon, 19 Oct 2015 09:59:18 +0100 Subject: [PATCH 098/139] =?UTF-8?q?[SPARK-10921][YARN]=20Completely=20remo?= =?UTF-8?q?ve=20the=20use=20of=20SparkContext.prefer=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …redNodeLocationData Author: Jacek Laskowski Closes #8976 from jaceklaskowski/SPARK-10921. --- .../scala/org/apache/spark/SparkContext.scala | 22 +++++-------------- project/MimaExcludes.scala | 3 +++ .../spark/deploy/yarn/ApplicationMaster.scala | 1 - .../spark/deploy/yarn/YarnRMClient.scala | 2 -- 4 files changed, 9 insertions(+), 19 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 0c72adfb9505b..ccba3ed9e643c 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -90,11 +90,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // NOTE: this must be placed at the beginning of the SparkContext constructor. SparkContext.markPartiallyConstructed(this, allowMultipleContexts) - // This is used only by YARN for now, but should be relevant to other cluster types (Mesos, - // etc) too. This is typically generated from InputFormatInfo.computePreferredLocations. It - // contains a map from hostname to a list of input format splits on the host. - private[spark] var preferredNodeLocationData: Map[String, Set[SplitInfo]] = Map() - val startTime = System.currentTimeMillis() private[spark] val stopped: AtomicBoolean = new AtomicBoolean(false) @@ -116,16 +111,13 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Alternative constructor for setting preferred locations where Spark will create executors. * * @param config a [[org.apache.spark.SparkConf]] object specifying other Spark parameters - * @param preferredNodeLocationData used in YARN mode to select nodes to launch containers on. - * Can be generated using [[org.apache.spark.scheduler.InputFormatInfo.computePreferredLocations]] - * from a list of input files or InputFormats for the application. + * @param preferredNodeLocationData not used. Left for backward compatibility. */ @deprecated("Passing in preferred locations has no effect at all, see SPARK-8949", "1.5.0") @DeveloperApi def this(config: SparkConf, preferredNodeLocationData: Map[String, Set[SplitInfo]]) = { this(config) logWarning("Passing in preferred locations has no effect at all, see SPARK-8949") - this.preferredNodeLocationData = preferredNodeLocationData } /** @@ -147,10 +139,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * @param jars Collection of JARs to send to the cluster. These can be paths on the local file * system or HDFS, HTTP, HTTPS, or FTP URLs. * @param environment Environment variables to set on worker nodes. - * @param preferredNodeLocationData used in YARN mode to select nodes to launch containers on. - * Can be generated using [[org.apache.spark.scheduler.InputFormatInfo.computePreferredLocations]] - * from a list of input files or InputFormats for the application. + * @param preferredNodeLocationData not used. Left for backward compatibility. */ + @deprecated("Passing in preferred locations has no effect at all, see SPARK-10921", "1.6.0") def this( master: String, appName: String, @@ -163,7 +154,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli if (preferredNodeLocationData.nonEmpty) { logWarning("Passing in preferred locations has no effect at all, see SPARK-8949") } - this.preferredNodeLocationData = preferredNodeLocationData } // NOTE: The below constructors could be consolidated using default arguments. Due to @@ -177,7 +167,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * @param appName A name for your application, to display on the cluster web UI. */ private[spark] def this(master: String, appName: String) = - this(master, appName, null, Nil, Map(), Map()) + this(master, appName, null, Nil, Map()) /** * Alternative constructor that allows setting common Spark properties directly @@ -187,7 +177,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * @param sparkHome Location where Spark is installed on cluster nodes. */ private[spark] def this(master: String, appName: String, sparkHome: String) = - this(master, appName, sparkHome, Nil, Map(), Map()) + this(master, appName, sparkHome, Nil, Map()) /** * Alternative constructor that allows setting common Spark properties directly @@ -199,7 +189,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * system or HDFS, HTTP, HTTPS, or FTP URLs. */ private[spark] def this(master: String, appName: String, sparkHome: String, jars: Seq[String]) = - this(master, appName, sparkHome, jars, Map(), Map()) + this(master, appName, sparkHome, jars, Map()) // log out Spark Version in Spark driver log logInfo(s"Running Spark version $SPARK_VERSION") diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 08e4a449cf762..0872d3f3e7093 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -100,6 +100,9 @@ object MimaExcludes { "org.apache.spark.sql.SQLContext.setSession"), ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.sql.SQLContext.createSession") + ) ++ Seq( + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.SparkContext.preferredNodeLocationData_=") ) case v if v.startsWith("1.5") => Seq( diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 3791eea5bf178..d1d248bf79beb 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -255,7 +255,6 @@ private[spark] class ApplicationMaster( driverRef, yarnConf, _sparkConf, - if (sc != null) sc.preferredNodeLocationData else Map(), uiAddress, historyAddress, securityMgr) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index df042bf291de7..d2a211f6711ff 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -49,7 +49,6 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg * * @param conf The Yarn configuration. * @param sparkConf The Spark configuration. - * @param preferredNodeLocations Map with hints about where to allocate containers. * @param uiAddress Address of the SparkUI. * @param uiHistoryAddress Address of the application on the History Server. */ @@ -58,7 +57,6 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg driverRef: RpcEndpointRef, conf: YarnConfiguration, sparkConf: SparkConf, - preferredNodeLocations: Map[String, Set[SplitInfo]], uiAddress: String, uiHistoryAddress: String, securityMgr: SecurityManager From dfa41e63b98c28b087c56f94658b5e99e8a7758c Mon Sep 17 00:00:00 2001 From: Alex Angelini Date: Mon, 19 Oct 2015 10:07:39 -0700 Subject: [PATCH 099/139] [SPARK-9643] Upgrade pyrolite to 4.9 Includes: https://github.com/irmen/Pyrolite/pull/23 which fixes datetimes with timezones. JoshRosen https://issues.apache.org/jira/browse/SPARK-9643 Author: Alex Angelini Closes #7950 from angelini/upgrade_pyrolite_up. --- core/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/pom.xml b/core/pom.xml index c0af98a04fb1d..fdcb6a7902bbd 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -339,7 +339,7 @@ net.razorvine pyrolite - 4.4 + 4.9 net.razorvine From 4c33a34ba3167ae67fdb4978ea2166ce65638fb9 Mon Sep 17 00:00:00 2001 From: lewuathe Date: Mon, 19 Oct 2015 10:46:10 -0700 Subject: [PATCH 100/139] =?UTF-8?q?[SPARK-10668]=20[ML]=20Use=20WeightedLe?= =?UTF-8?q?astSquares=20in=20LinearRegression=20with=20L=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …2 regularization if the number of features is small Author: lewuathe Author: Lewuathe Author: Kai Sasaki Author: Lewuathe Closes #8884 from Lewuathe/SPARK-10668. --- R/pkg/R/mllib.R | 5 +- R/pkg/inst/tests/test_mllib.R | 2 +- .../ml/param/shared/SharedParamsCodeGen.scala | 4 +- .../spark/ml/param/shared/sharedParams.scala | 17 + .../apache/spark/ml/r/SparkRWrappers.scala | 4 +- .../ml/regression/LinearRegression.scala | 50 +- .../regression/JavaLinearRegressionSuite.java | 3 +- .../ml/regression/LinearRegressionSuite.scala | 1045 +++++++++-------- .../spark/ml/tuning/CrossValidatorSuite.scala | 2 +- .../ml/tuning/TrainValidationSplitSuite.scala | 2 +- 10 files changed, 640 insertions(+), 494 deletions(-) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index cd00bbbeec698..25615e805e03c 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -45,11 +45,12 @@ setClass("PipelineModel", representation(model = "jobj")) #' summary(model) #'} setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFrame"), - function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0) { + function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0, + solver = "auto") { family <- match.arg(family) model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", "fitRModelFormula", deparse(formula), data@sdf, family, lambda, - alpha) + alpha, solver) return(new("PipelineModel", model = model)) }) diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R index 032f8ec68b9d0..3331ce738358c 100644 --- a/R/pkg/inst/tests/test_mllib.R +++ b/R/pkg/inst/tests/test_mllib.R @@ -59,7 +59,7 @@ test_that("feature interaction vs native glm", { test_that("summary coefficients match with native glm", { training <- createDataFrame(sqlContext, iris) - stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training)) + stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training, solver = "l-bfgs")) coefs <- as.vector(stats$coefficients) rCoefs <- as.vector(coef(glm(Sepal.Width ~ Sepal.Length + Species, data = iris))) expect_true(all(abs(rCoefs - coefs) < 1e-6)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 8cb6b5493c61c..c7bca1243092c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -73,7 +73,9 @@ private[shared] object SharedParamsCodeGen { ParamDesc[Double]("tol", "the convergence tolerance for iterative algorithms"), ParamDesc[Double]("stepSize", "Step size to be used for each iteration of optimization."), ParamDesc[String]("weightCol", "weight column name. If this is not set or empty, we treat " + - "all instance weights as 1.0.")) + "all instance weights as 1.0."), + ParamDesc[String]("solver", "the solver algorithm for optimization. If this is not set or " + + "empty, default value is 'auto'.", Some("\"auto\""))) val code = genSharedParams(params) val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala" diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index e3625212e5251..cb2a060a34dd6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -357,4 +357,21 @@ private[ml] trait HasWeightCol extends Params { /** @group getParam */ final def getWeightCol: String = $(weightCol) } + +/** + * Trait for shared param solver (default: "auto"). + */ +private[ml] trait HasSolver extends Params { + + /** + * Param for the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.. + * @group param + */ + final val solver: Param[String] = new Param[String](this, "solver", "the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.") + + setDefault(solver, "auto") + + /** @group getParam */ + final def getSolver: String = $(solver) +} // scalastyle:on diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala index f5a022c31ed90..fec61fed3cb9c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala @@ -30,13 +30,15 @@ private[r] object SparkRWrappers { df: DataFrame, family: String, lambda: Double, - alpha: Double): PipelineModel = { + alpha: Double, + solver: String): PipelineModel = { val formula = new RFormula().setFormula(value) val estimator = family match { case "gaussian" => new LinearRegression() .setRegParam(lambda) .setElasticNetParam(alpha) .setFitIntercept(formula.hasIntercept) + .setSolver(solver) case "binomial" => new LogisticRegression() .setRegParam(lambda) .setElasticNetParam(alpha) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index dd09667ef5a0f..573a61a6eabdf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -25,6 +25,7 @@ import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, import org.apache.spark.{Logging, SparkException} import org.apache.spark.annotation.Experimental import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.optim.WeightedLeastSquares import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ @@ -43,7 +44,7 @@ import org.apache.spark.storage.StorageLevel */ private[regression] trait LinearRegressionParams extends PredictorParams with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol - with HasFitIntercept with HasStandardization with HasWeightCol + with HasFitIntercept with HasStandardization with HasWeightCol with HasSolver /** * :: Experimental :: @@ -130,9 +131,53 @@ class LinearRegression(override val uid: String) def setWeightCol(value: String): this.type = set(weightCol, value) setDefault(weightCol -> "") + /** + * Set the solver algorithm used for optimization. + * In case of linear regression, this can be "l-bfgs", "normal" and "auto". + * The default value is "auto" which means that the solver algorithm is + * selected automatically. + * @group setParam + */ + def setSolver(value: String): this.type = set(solver, value) + setDefault(solver -> "auto") + override protected def train(dataset: DataFrame): LinearRegressionModel = { - // Extract columns from data. If dataset is persisted, do not persist instances. + // Extract the number of features before deciding optimization solver. + val numFeatures = dataset.select(col($(featuresCol))).limit(1).map { + case Row(features: Vector) => features.size + }.toArray()(0) val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol)) + + if (($(solver) == "auto" && $(elasticNetParam) == 0.0 && numFeatures <= 4096) || + $(solver) == "normal") { + require($(elasticNetParam) == 0.0, "Only L2 regularization can be used when normal " + + "solver is used.'") + // For low dimensional data, WeightedLeastSquares is more efficiently since the + // training algorithm only requires one pass through the data. (SPARK-10668) + val instances: RDD[WeightedLeastSquares.Instance] = dataset.select( + col($(labelCol)), w, col($(featuresCol))).map { + case Row(label: Double, weight: Double, features: Vector) => + WeightedLeastSquares.Instance(weight, features, label) + } + + val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam), + $(standardization), true) + val model = optimizer.fit(instances) + // When it is trained by WeightedLeastSquares, training summary does not + // attached returned model. + val lrModel = copyValues(new LinearRegressionModel(uid, model.coefficients, model.intercept)) + // WeightedLeastSquares does not run through iterations. So it does not generate + // an objective history. + val (summaryModel, predictionColName) = lrModel.findSummaryModelAndPredictionCol() + val trainingSummary = new LinearRegressionTrainingSummary( + summaryModel.transform(dataset), + predictionColName, + $(labelCol), + $(featuresCol), + Array(0D)) + return lrModel.setSummary(trainingSummary) + } + val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).map { case Row(label: Double, weight: Double, features: Vector) => Instance(label, weight, features) @@ -155,7 +200,6 @@ class LinearRegression(override val uid: String) new MultivariateOnlineSummarizer, new MultivariateOnlineSummarizer)(seqOp, combOp) } - val numFeatures = featuresSummarizer.mean.size val yMean = ySummarizer.mean(0) val yStd = math.sqrt(ySummarizer.variance(0)) diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java index 91c589d00abd5..4fb0b0d1092b6 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java @@ -61,6 +61,7 @@ public void tearDown() { public void linearRegressionDefaultParams() { LinearRegression lr = new LinearRegression(); assertEquals("label", lr.getLabelCol()); + assertEquals("auto", lr.getSolver()); LinearRegressionModel model = lr.fit(dataset); model.transform(dataset).registerTempTable("prediction"); DataFrame predictions = jsql.sql("SELECT label, prediction FROM prediction"); @@ -75,7 +76,7 @@ public void linearRegressionWithSetters() { // Set params, train, and check as many params as we can. LinearRegression lr = new LinearRegression() .setMaxIter(10) - .setRegParam(1.0); + .setRegParam(1.0).setSolver("l-bfgs"); LinearRegressionModel model = lr.fit(dataset); LinearRegression parent = (LinearRegression) model.parent(); assertEquals(10, parent.getMaxIter()); diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 73a0a5caf8640..a6e0c72ba9030 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.{DataFrame, Row} class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { + private val seed: Int = 42 @transient var dataset: DataFrame = _ @transient var datasetWithoutIntercept: DataFrame = _ @@ -50,15 +51,14 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { super.beforeAll() dataset = sqlContext.createDataFrame( sc.parallelize(LinearDataGenerator.generateLinearInput( - 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2)) + 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, seed, 0.1), 2)) /* datasetWithoutIntercept is not needed for correctness testing but is useful for illustrating training model without intercept */ datasetWithoutIntercept = sqlContext.createDataFrame( sc.parallelize(LinearDataGenerator.generateLinearInput( - 0.0, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2)) - + 0.0, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, seed, 0.1), 2)) } test("params") { @@ -76,6 +76,7 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { assert(lir.getElasticNetParam === 0.0) assert(lir.getFitIntercept) assert(lir.getStandardization) + assert(lir.getSolver == "auto") val model = lir.fit(dataset) // copied model must have the same parent. @@ -93,525 +94,603 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { } test("linear regression with intercept without regularization") { - val trainer1 = new LinearRegression - // The result should be the same regardless of standardization without regularization - val trainer2 = (new LinearRegression).setStandardization(false) - val model1 = trainer1.fit(dataset) - val model2 = trainer2.fit(dataset) - - /* - Using the following R code to load the data and train the model using glmnet package. - - library("glmnet") - data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE) - features <- as.matrix(data.frame(as.numeric(data$V2), as.numeric(data$V3))) - label <- as.numeric(data$V1) - weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0)) - > weights - 3 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) 6.298698 - as.numeric.data.V2. 4.700706 - as.numeric.data.V3. 7.199082 - */ - val interceptR = 6.298698 - val weightsR = Vectors.dense(4.700706, 7.199082) - - assert(model1.intercept ~== interceptR relTol 1E-3) - assert(model1.weights ~= weightsR relTol 1E-3) - assert(model2.intercept ~== interceptR relTol 1E-3) - assert(model2.weights ~= weightsR relTol 1E-3) - - - model1.transform(dataset).select("features", "prediction").collect().foreach { - case Row(features: DenseVector, prediction1: Double) => - val prediction2 = - features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept - assert(prediction1 ~== prediction2 relTol 1E-5) + Seq("auto", "l-bfgs", "normal").foreach { solver => + val trainer1 = new LinearRegression().setSolver(solver) + // The result should be the same regardless of standardization without regularization + val trainer2 = (new LinearRegression).setStandardization(false).setSolver(solver) + val model1 = trainer1.fit(dataset) + val model2 = trainer2.fit(dataset) + + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE) + features <- as.matrix(data.frame(as.numeric(data$V2), as.numeric(data$V3))) + label <- as.numeric(data$V1) + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 6.298698 + as.numeric.data.V2. 4.700706 + as.numeric.data.V3. 7.199082 + */ + val interceptR = 6.298698 + val weightsR = Vectors.dense(4.700706, 7.199082) + + assert(model1.intercept ~== interceptR relTol 1E-3) + assert(model1.weights ~= weightsR relTol 1E-3) + assert(model2.intercept ~== interceptR relTol 1E-3) + assert(model2.weights ~= weightsR relTol 1E-3) + + model1.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } } } test("linear regression without intercept without regularization") { - val trainer1 = (new LinearRegression).setFitIntercept(false) - // Without regularization the results should be the same - val trainer2 = (new LinearRegression).setFitIntercept(false).setStandardization(false) - val model1 = trainer1.fit(dataset) - val modelWithoutIntercept1 = trainer1.fit(datasetWithoutIntercept) - val model2 = trainer2.fit(dataset) - val modelWithoutIntercept2 = trainer2.fit(datasetWithoutIntercept) - - - /* - weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0, - intercept = FALSE)) - > weights - 3 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) . - as.numeric.data.V2. 6.995908 - as.numeric.data.V3. 5.275131 - */ - val weightsR = Vectors.dense(6.995908, 5.275131) - - assert(model1.intercept ~== 0 absTol 1E-3) - assert(model1.weights ~= weightsR relTol 1E-3) - assert(model2.intercept ~== 0 absTol 1E-3) - assert(model2.weights ~= weightsR relTol 1E-3) - - /* - Then again with the data with no intercept: - > weightsWithoutIntercept - 3 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) . - as.numeric.data3.V2. 4.70011 - as.numeric.data3.V3. 7.19943 - */ - val weightsWithoutInterceptR = Vectors.dense(4.70011, 7.19943) - - assert(modelWithoutIntercept1.intercept ~== 0 absTol 1E-3) - assert(modelWithoutIntercept1.weights ~= weightsWithoutInterceptR relTol 1E-3) - assert(modelWithoutIntercept2.intercept ~== 0 absTol 1E-3) - assert(modelWithoutIntercept2.weights ~= weightsWithoutInterceptR relTol 1E-3) + Seq("auto", "l-bfgs", "normal").foreach { solver => + val trainer1 = (new LinearRegression).setFitIntercept(false).setSolver(solver) + // Without regularization the results should be the same + val trainer2 = (new LinearRegression).setFitIntercept(false).setStandardization(false) + .setSolver(solver) + val model1 = trainer1.fit(dataset) + val modelWithoutIntercept1 = trainer1.fit(datasetWithoutIntercept) + val model2 = trainer2.fit(dataset) + val modelWithoutIntercept2 = trainer2.fit(datasetWithoutIntercept) + + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0, + intercept = FALSE)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.data.V2. 6.995908 + as.numeric.data.V3. 5.275131 + */ + val weightsR = Vectors.dense(6.995908, 5.275131) + + assert(model1.intercept ~== 0 absTol 1E-3) + assert(model1.weights ~= weightsR relTol 1E-3) + assert(model2.intercept ~== 0 absTol 1E-3) + assert(model2.weights ~= weightsR relTol 1E-3) + + /* + Then again with the data with no intercept: + > weightsWithoutIntercept + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.data3.V2. 4.70011 + as.numeric.data3.V3. 7.19943 + */ + val weightsWithoutInterceptR = Vectors.dense(4.70011, 7.19943) + + assert(modelWithoutIntercept1.intercept ~== 0 absTol 1E-3) + assert(modelWithoutIntercept1.weights ~= weightsWithoutInterceptR relTol 1E-3) + assert(modelWithoutIntercept2.intercept ~== 0 absTol 1E-3) + assert(modelWithoutIntercept2.weights ~= weightsWithoutInterceptR relTol 1E-3) + } } test("linear regression with intercept with L1 regularization") { - val trainer1 = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57) - val trainer2 = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57) - .setStandardization(false) - val model1 = trainer1.fit(dataset) - val model2 = trainer2.fit(dataset) - - /* - weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57)) - > weights - 3 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) 6.24300 - as.numeric.data.V2. 4.024821 - as.numeric.data.V3. 6.679841 - */ - val interceptR1 = 6.24300 - val weightsR1 = Vectors.dense(4.024821, 6.679841) - - assert(model1.intercept ~== interceptR1 relTol 1E-3) - assert(model1.weights ~= weightsR1 relTol 1E-3) - - /* - weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57, - standardize=FALSE)) - > weights - 3 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) 6.416948 - as.numeric.data.V2. 3.893869 - as.numeric.data.V3. 6.724286 - */ - val interceptR2 = 6.416948 - val weightsR2 = Vectors.dense(3.893869, 6.724286) - - assert(model2.intercept ~== interceptR2 relTol 1E-3) - assert(model2.weights ~= weightsR2 relTol 1E-3) - - - model1.transform(dataset).select("features", "prediction").collect().foreach { - case Row(features: DenseVector, prediction1: Double) => - val prediction2 = - features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept - assert(prediction1 ~== prediction2 relTol 1E-5) + Seq("auto", "l-bfgs", "normal").foreach { solver => + val trainer1 = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57) + .setSolver(solver) + val trainer2 = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57) + .setSolver(solver).setStandardization(false) + + var model1: LinearRegressionModel = null + var model2: LinearRegressionModel = null + + // Normal optimizer is not supported with only L1 regularization case. + if (solver == "normal") { + intercept[IllegalArgumentException] { + trainer1.fit(dataset) + trainer2.fit(dataset) + } + } else { + model1 = trainer1.fit(dataset) + model2 = trainer2.fit(dataset) + + + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 6.24300 + as.numeric.data.V2. 4.024821 + as.numeric.data.V3. 6.679841 + */ + val interceptR1 = 6.24300 + val weightsR1 = Vectors.dense(4.024821, 6.679841) + assert(model1.intercept ~== interceptR1 relTol 1E-3) + assert(model1.weights ~= weightsR1 relTol 1E-3) + + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57, + standardize=FALSE)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 6.416948 + as.numeric.data.V2. 3.893869 + as.numeric.data.V3. 6.724286 + */ + val interceptR2 = 6.416948 + val weightsR2 = Vectors.dense(3.893869, 6.724286) + + assert(model2.intercept ~== interceptR2 relTol 1E-3) + assert(model2.weights ~= weightsR2 relTol 1E-3) + + model1.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } } } test("linear regression without intercept with L1 regularization") { - val trainer1 = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57) - .setFitIntercept(false) - val trainer2 = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57) - .setFitIntercept(false).setStandardization(false) - val model1 = trainer1.fit(dataset) - val model2 = trainer2.fit(dataset) - - /* - weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57, - intercept=FALSE)) - > weights - 3 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) . - as.numeric.data.V2. 6.299752 - as.numeric.data.V3. 4.772913 - */ - val interceptR1 = 0.0 - val weightsR1 = Vectors.dense(6.299752, 4.772913) - - assert(model1.intercept ~== interceptR1 absTol 1E-3) - assert(model1.weights ~= weightsR1 relTol 1E-3) - - /* - weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57, - intercept=FALSE, standardize=FALSE)) - > weights - 3 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) . - as.numeric.data.V2. 6.232193 - as.numeric.data.V3. 4.764229 - */ - val interceptR2 = 0.0 - val weightsR2 = Vectors.dense(6.232193, 4.764229) - - assert(model2.intercept ~== interceptR2 absTol 1E-3) - assert(model2.weights ~= weightsR2 relTol 1E-3) - - - model1.transform(dataset).select("features", "prediction").collect().foreach { - case Row(features: DenseVector, prediction1: Double) => - val prediction2 = - features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept - assert(prediction1 ~== prediction2 relTol 1E-5) + Seq("auto", "l-bfgs", "normal").foreach { solver => + val trainer1 = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57) + .setFitIntercept(false).setSolver(solver) + val trainer2 = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57) + .setFitIntercept(false).setStandardization(false).setSolver(solver) + + var model1: LinearRegressionModel = null + var model2: LinearRegressionModel = null + + // Normal optimizer is not supported with only L1 regularization case. + if (solver == "normal") { + intercept[IllegalArgumentException] { + trainer1.fit(dataset) + trainer2.fit(dataset) + } + } else { + model1 = trainer1.fit(dataset) + model2 = trainer2.fit(dataset) + + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57, + intercept=FALSE)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.data.V2. 6.299752 + as.numeric.data.V3. 4.772913 + */ + val interceptR1 = 0.0 + val weightsR1 = Vectors.dense(6.299752, 4.772913) + + assert(model1.intercept ~== interceptR1 absTol 1E-3) + assert(model1.weights ~= weightsR1 relTol 1E-3) + + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57, + intercept=FALSE, standardize=FALSE)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.data.V2. 6.232193 + as.numeric.data.V3. 4.764229 + */ + val interceptR2 = 0.0 + val weightsR2 = Vectors.dense(6.232193, 4.764229) + + assert(model2.intercept ~== interceptR2 absTol 1E-3) + assert(model2.weights ~= weightsR2 relTol 1E-3) + + model1.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } } } test("linear regression with intercept with L2 regularization") { - val trainer1 = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3) - val trainer2 = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3) - .setStandardization(false) - val model1 = trainer1.fit(dataset) - val model2 = trainer2.fit(dataset) - - /* - weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3)) - > weights - 3 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) 5.269376 - as.numeric.data.V2. 3.736216 - as.numeric.data.V3. 5.712356) - */ - val interceptR1 = 5.269376 - val weightsR1 = Vectors.dense(3.736216, 5.712356) - - assert(model1.intercept ~== interceptR1 relTol 1E-3) - assert(model1.weights ~= weightsR1 relTol 1E-3) - - /* - weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3, - standardize=FALSE)) - > weights - 3 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) 5.791109 - as.numeric.data.V2. 3.435466 - as.numeric.data.V3. 5.910406 - */ - val interceptR2 = 5.791109 - val weightsR2 = Vectors.dense(3.435466, 5.910406) - - assert(model2.intercept ~== interceptR2 relTol 1E-3) - assert(model2.weights ~= weightsR2 relTol 1E-3) - - model1.transform(dataset).select("features", "prediction").collect().foreach { - case Row(features: DenseVector, prediction1: Double) => - val prediction2 = - features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept - assert(prediction1 ~== prediction2 relTol 1E-5) + Seq("auto", "l-bfgs", "normal").foreach { solver => + val trainer1 = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3) + .setSolver(solver) + val trainer2 = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3) + .setStandardization(false).setSolver(solver) + val model1 = trainer1.fit(dataset) + val model2 = trainer2.fit(dataset) + + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 5.269376 + as.numeric.data.V2. 3.736216 + as.numeric.data.V3. 5.712356) + */ + val interceptR1 = 5.269376 + val weightsR1 = Vectors.dense(3.736216, 5.712356) + + assert(model1.intercept ~== interceptR1 relTol 1E-3) + assert(model1.weights ~= weightsR1 relTol 1E-3) + + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3, + standardize=FALSE)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 5.791109 + as.numeric.data.V2. 3.435466 + as.numeric.data.V3. 5.910406 + */ + val interceptR2 = 5.791109 + val weightsR2 = Vectors.dense(3.435466, 5.910406) + + assert(model2.intercept ~== interceptR2 relTol 1E-3) + assert(model2.weights ~= weightsR2 relTol 1E-3) + + model1.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } } } test("linear regression without intercept with L2 regularization") { - val trainer1 = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3) - .setFitIntercept(false) - val trainer2 = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3) - .setFitIntercept(false).setStandardization(false) - val model1 = trainer1.fit(dataset) - val model2 = trainer2.fit(dataset) - - /* - weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3, - intercept = FALSE)) - > weights - 3 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) . - as.numeric.data.V2. 5.522875 - as.numeric.data.V3. 4.214502 - */ - val interceptR1 = 0.0 - val weightsR1 = Vectors.dense(5.522875, 4.214502) - - assert(model1.intercept ~== interceptR1 absTol 1E-3) - assert(model1.weights ~= weightsR1 relTol 1E-3) - - /* - weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3, - intercept = FALSE, standardize=FALSE)) - > weights - 3 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) . - as.numeric.data.V2. 5.263704 - as.numeric.data.V3. 4.187419 - */ - val interceptR2 = 0.0 - val weightsR2 = Vectors.dense(5.263704, 4.187419) - - assert(model2.intercept ~== interceptR2 absTol 1E-3) - assert(model2.weights ~= weightsR2 relTol 1E-3) - - model1.transform(dataset).select("features", "prediction").collect().foreach { - case Row(features: DenseVector, prediction1: Double) => - val prediction2 = - features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept - assert(prediction1 ~== prediction2 relTol 1E-5) + Seq("auto", "l-bfgs", "normal").foreach { solver => + val trainer1 = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3) + .setFitIntercept(false).setSolver(solver) + val trainer2 = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3) + .setFitIntercept(false).setStandardization(false).setSolver(solver) + val model1 = trainer1.fit(dataset) + val model2 = trainer2.fit(dataset) + + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3, + intercept = FALSE)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.data.V2. 5.522875 + as.numeric.data.V3. 4.214502 + */ + val interceptR1 = 0.0 + val weightsR1 = Vectors.dense(5.522875, 4.214502) + + assert(model1.intercept ~== interceptR1 absTol 1E-3) + assert(model1.weights ~= weightsR1 relTol 1E-3) + + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3, + intercept = FALSE, standardize=FALSE)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.data.V2. 5.263704 + as.numeric.data.V3. 4.187419 + */ + val interceptR2 = 0.0 + val weightsR2 = Vectors.dense(5.263704, 4.187419) + + assert(model2.intercept ~== interceptR2 absTol 1E-3) + assert(model2.weights ~= weightsR2 relTol 1E-3) + + model1.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } } } test("linear regression with intercept with ElasticNet regularization") { - val trainer1 = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6) - val trainer2 = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6) - .setStandardization(false) - val model1 = trainer1.fit(dataset) - val model2 = trainer2.fit(dataset) - - /* - weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6)) - > weights - 3 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) 6.324108 - as.numeric.data.V2. 3.168435 - as.numeric.data.V3. 5.200403 - */ - val interceptR1 = 5.696056 - val weightsR1 = Vectors.dense(3.670489, 6.001122) - - assert(model1.intercept ~== interceptR1 relTol 1E-3) - assert(model1.weights ~= weightsR1 relTol 1E-3) - - /* - weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6 - standardize=FALSE)) - > weights - 3 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) 6.114723 - as.numeric.data.V2. 3.409937 - as.numeric.data.V3. 6.146531 - */ - val interceptR2 = 6.114723 - val weightsR2 = Vectors.dense(3.409937, 6.146531) - - assert(model2.intercept ~== interceptR2 relTol 1E-3) - assert(model2.weights ~= weightsR2 relTol 1E-3) - - model1.transform(dataset).select("features", "prediction").collect().foreach { - case Row(features: DenseVector, prediction1: Double) => - val prediction2 = - features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept - assert(prediction1 ~== prediction2 relTol 1E-5) + Seq("auto", "l-bfgs", "normal").foreach { solver => + val trainer1 = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6) + .setSolver(solver) + val trainer2 = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6) + .setStandardization(false).setSolver(solver) + + var model1: LinearRegressionModel = null + var model2: LinearRegressionModel = null + + // Normal optimizer is not supported with non-zero elasticnet parameter. + if (solver == "normal") { + intercept[IllegalArgumentException] { + trainer1.fit(dataset) + trainer2.fit(dataset) + } + } else { + model1 = trainer1.fit(dataset) + model2 = trainer2.fit(dataset) + + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 6.324108 + as.numeric.data.V2. 3.168435 + as.numeric.data.V3. 5.200403 + */ + val interceptR1 = 5.696056 + val weightsR1 = Vectors.dense(3.670489, 6.001122) + + assert(model1.intercept ~== interceptR1 relTol 1E-3) + assert(model1.weights ~= weightsR1 relTol 1E-3) + + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6 + standardize=FALSE)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 6.114723 + as.numeric.data.V2. 3.409937 + as.numeric.data.V3. 6.146531 + */ + val interceptR2 = 6.114723 + val weightsR2 = Vectors.dense(3.409937, 6.146531) + + assert(model2.intercept ~== interceptR2 relTol 1E-3) + assert(model2.weights ~= weightsR2 relTol 1E-3) + + model1.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } } } test("linear regression without intercept with ElasticNet regularization") { - val trainer1 = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6) - .setFitIntercept(false) - val trainer2 = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6) - .setFitIntercept(false).setStandardization(false) - val model1 = trainer1.fit(dataset) - val model2 = trainer2.fit(dataset) - - /* - weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6, - intercept=FALSE)) - > weights - 3 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) . - as.numeric.dataM.V2. 5.673348 - as.numeric.dataM.V3. 4.322251 - */ - val interceptR1 = 0.0 - val weightsR1 = Vectors.dense(5.673348, 4.322251) - - assert(model1.intercept ~== interceptR1 absTol 1E-3) - assert(model1.weights ~= weightsR1 relTol 1E-3) - - /* - weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6, - intercept=FALSE, standardize=FALSE)) - > weights - 3 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) . - as.numeric.data.V2. 5.477988 - as.numeric.data.V3. 4.297622 - */ - val interceptR2 = 0.0 - val weightsR2 = Vectors.dense(5.477988, 4.297622) - - assert(model2.intercept ~== interceptR2 absTol 1E-3) - assert(model2.weights ~= weightsR2 relTol 1E-3) - - model1.transform(dataset).select("features", "prediction").collect().foreach { - case Row(features: DenseVector, prediction1: Double) => - val prediction2 = - features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept - assert(prediction1 ~== prediction2 relTol 1E-5) + Seq("auto", "l-bfgs", "normal").foreach { solver => + val trainer1 = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6) + .setFitIntercept(false).setSolver(solver) + val trainer2 = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6) + .setFitIntercept(false).setStandardization(false).setSolver(solver) + + var model1: LinearRegressionModel = null + var model2: LinearRegressionModel = null + + // Normal optimizer is not supported with non-zero elasticnet parameter. + if (solver == "normal") { + intercept[IllegalArgumentException] { + trainer1.fit(dataset) + trainer2.fit(dataset) + } + } else { + model1 = trainer1.fit(dataset) + model2 = trainer2.fit(dataset) + + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6, + intercept=FALSE)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.dataM.V2. 5.673348 + as.numeric.dataM.V3. 4.322251 + */ + val interceptR1 = 0.0 + val weightsR1 = Vectors.dense(5.673348, 4.322251) + + assert(model1.intercept ~== interceptR1 absTol 1E-3) + assert(model1.weights ~= weightsR1 relTol 1E-3) + + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6, + intercept=FALSE, standardize=FALSE)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.data.V2. 5.477988 + as.numeric.data.V3. 4.297622 + */ + val interceptR2 = 0.0 + val weightsR2 = Vectors.dense(5.477988, 4.297622) + + assert(model2.intercept ~== interceptR2 absTol 1E-3) + assert(model2.weights ~= weightsR2 relTol 1E-3) + + model1.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } } } test("linear regression model training summary") { - val trainer = new LinearRegression - val model = trainer.fit(dataset) - val trainerNoPredictionCol = trainer.setPredictionCol("") - val modelNoPredictionCol = trainerNoPredictionCol.fit(dataset) - - - // Training results for the model should be available - assert(model.hasSummary) - assert(modelNoPredictionCol.hasSummary) - - // Schema should be a superset of the input dataset - assert((dataset.schema.fieldNames.toSet + "prediction").subsetOf( - model.summary.predictions.schema.fieldNames.toSet)) - // Validate that we re-insert a prediction column for evaluation - val modelNoPredictionColFieldNames = modelNoPredictionCol.summary.predictions.schema.fieldNames - assert((dataset.schema.fieldNames.toSet).subsetOf( - modelNoPredictionColFieldNames.toSet)) - assert(modelNoPredictionColFieldNames.exists(s => s.startsWith("prediction_"))) - - // Residuals in [[LinearRegressionResults]] should equal those manually computed - val expectedResiduals = dataset.select("features", "label") - .map { case Row(features: DenseVector, label: Double) => - val prediction = - features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept - label - prediction - } - .zip(model.summary.residuals.map(_.getDouble(0))) - .collect() - .foreach { case (manualResidual: Double, resultResidual: Double) => - assert(manualResidual ~== resultResidual relTol 1E-5) - } + Seq("auto", "l-bfgs", "normal").foreach { solver => + val trainer = new LinearRegression().setSolver(solver) + val model = trainer.fit(dataset) + val trainerNoPredictionCol = trainer.setPredictionCol("") + val modelNoPredictionCol = trainerNoPredictionCol.fit(dataset) + + + // Training results for the model should be available + assert(model.hasSummary) + assert(modelNoPredictionCol.hasSummary) + + // Schema should be a superset of the input dataset + assert((dataset.schema.fieldNames.toSet + "prediction").subsetOf( + model.summary.predictions.schema.fieldNames.toSet)) + // Validate that we re-insert a prediction column for evaluation + val modelNoPredictionColFieldNames + = modelNoPredictionCol.summary.predictions.schema.fieldNames + assert((dataset.schema.fieldNames.toSet).subsetOf( + modelNoPredictionColFieldNames.toSet)) + assert(modelNoPredictionColFieldNames.exists(s => s.startsWith("prediction_"))) + + // Residuals in [[LinearRegressionResults]] should equal those manually computed + val expectedResiduals = dataset.select("features", "label") + .map { case Row(features: DenseVector, label: Double) => + val prediction = + features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept + label - prediction + } + .zip(model.summary.residuals.map(_.getDouble(0))) + .collect() + .foreach { case (manualResidual: Double, resultResidual: Double) => + assert(manualResidual ~== resultResidual relTol 1E-5) + } - /* - Use the following R code to generate model training results. - - predictions <- predict(fit, newx=features) - residuals <- label - predictions - > mean(residuals^2) # MSE - [1] 0.009720325 - > mean(abs(residuals)) # MAD - [1] 0.07863206 - > cor(predictions, label)^2# r^2 - [,1] - s0 0.9998749 - */ - assert(model.summary.meanSquaredError ~== 0.00972035 relTol 1E-5) - assert(model.summary.meanAbsoluteError ~== 0.07863206 relTol 1E-5) - assert(model.summary.r2 ~== 0.9998749 relTol 1E-5) - - // Objective function should be monotonically decreasing for linear regression - assert( - model.summary - .objectiveHistory - .sliding(2) - .forall(x => x(0) >= x(1))) + /* + Use the following R code to generate model training results. + + predictions <- predict(fit, newx=features) + residuals <- label - predictions + > mean(residuals^2) # MSE + [1] 0.009720325 + > mean(abs(residuals)) # MAD + [1] 0.07863206 + > cor(predictions, label)^2# r^2 + [,1] + s0 0.9998749 + */ + assert(model.summary.meanSquaredError ~== 0.00972035 relTol 1E-5) + assert(model.summary.meanAbsoluteError ~== 0.07863206 relTol 1E-5) + assert(model.summary.r2 ~== 0.9998749 relTol 1E-5) + + // Normal solver uses "WeightedLeastSquares". This algorithm does not generate + // objective history because it does not run through iterations. + if (solver == "l-bfgs") { + // Objective function should be monotonically decreasing for linear regression + assert( + model.summary + .objectiveHistory + .sliding(2) + .forall(x => x(0) >= x(1))) + } + } } test("linear regression model testset evaluation summary") { - val trainer = new LinearRegression - val model = trainer.fit(dataset) - - // Evaluating on training dataset should yield results summary equal to training summary - val testSummary = model.evaluate(dataset) - assert(model.summary.meanSquaredError ~== testSummary.meanSquaredError relTol 1E-5) - assert(model.summary.r2 ~== testSummary.r2 relTol 1E-5) - model.summary.residuals.select("residuals").collect() - .zip(testSummary.residuals.select("residuals").collect()) - .forall { case (Row(r1: Double), Row(r2: Double)) => r1 ~== r2 relTol 1E-5 } + Seq("auto", "l-bfgs", "normal").foreach { solver => + val trainer = new LinearRegression().setSolver(solver) + val model = trainer.fit(dataset) + + // Evaluating on training dataset should yield results summary equal to training summary + val testSummary = model.evaluate(dataset) + assert(model.summary.meanSquaredError ~== testSummary.meanSquaredError relTol 1E-5) + assert(model.summary.r2 ~== testSummary.r2 relTol 1E-5) + model.summary.residuals.select("residuals").collect() + .zip(testSummary.residuals.select("residuals").collect()) + .forall { case (Row(r1: Double), Row(r2: Double)) => r1 ~== r2 relTol 1E-5 } + } } - test("linear regression with weighted samples"){ - val (data, weightedData) = { - val activeData = LinearDataGenerator.generateLinearInput( - 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 500, 1, 0.1) - - val rnd = new Random(8392) - val signedData = activeData.map { case p: LabeledPoint => - (rnd.nextGaussian() > 0.0, p) - } - - val data1 = signedData.flatMap { - case (true, p) => Iterator(p, p) - case (false, p) => Iterator(p) - } - - val weightedSignedData = signedData.flatMap { - case (true, LabeledPoint(label, features)) => - Iterator( - Instance(label, weight = 1.2, features), - Instance(label, weight = 0.8, features) - ) - case (false, LabeledPoint(label, features)) => - Iterator( - Instance(label, weight = 0.3, features), - Instance(label, weight = 0.1, features), - Instance(label, weight = 0.6, features) - ) + test("linear regression with weighted samples") { + Seq("auto", "l-bfgs", "normal").foreach { solver => + val (data, weightedData) = { + val activeData = LinearDataGenerator.generateLinearInput( + 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 500, 1, 0.1) + + val rnd = new Random(8392) + val signedData = activeData.map { case p: LabeledPoint => + (rnd.nextGaussian() > 0.0, p) + } + + val data1 = signedData.flatMap { + case (true, p) => Iterator(p, p) + case (false, p) => Iterator(p) + } + + val weightedSignedData = signedData.flatMap { + case (true, LabeledPoint(label, features)) => + Iterator( + Instance(label, weight = 1.2, features), + Instance(label, weight = 0.8, features) + ) + case (false, LabeledPoint(label, features)) => + Iterator( + Instance(label, weight = 0.3, features), + Instance(label, weight = 0.1, features), + Instance(label, weight = 0.6, features) + ) + } + + val noiseData = LinearDataGenerator.generateLinearInput( + 2, Array(1, 3), Array(0.9, -1.3), Array(0.7, 1.2), 500, 1, 0.1) + val weightedNoiseData = noiseData.map { + case LabeledPoint(label, features) => Instance(label, weight = 0, features) + } + val data2 = weightedSignedData ++ weightedNoiseData + + (sqlContext.createDataFrame(sc.parallelize(data1, 4)), + sqlContext.createDataFrame(sc.parallelize(data2, 4))) } - val noiseData = LinearDataGenerator.generateLinearInput( - 2, Array(1, 3), Array(0.9, -1.3), Array(0.7, 1.2), 500, 1, 0.1) - val weightedNoiseData = noiseData.map { - case LabeledPoint(label, features) => Instance(label, weight = 0, features) - } - val data2 = weightedSignedData ++ weightedNoiseData - - (sqlContext.createDataFrame(sc.parallelize(data1, 4)), - sqlContext.createDataFrame(sc.parallelize(data2, 4))) + val trainer1a = (new LinearRegression).setFitIntercept(true) + .setElasticNetParam(0.0).setRegParam(0.21).setStandardization(true).setSolver(solver) + val trainer1b = (new LinearRegression).setFitIntercept(true).setWeightCol("weight") + .setElasticNetParam(0.0).setRegParam(0.21).setStandardization(true).setSolver(solver) + + // Normal optimizer is not supported with non-zero elasticnet parameter. + val model1a0 = trainer1a.fit(data) + val model1a1 = trainer1a.fit(weightedData) + val model1b = trainer1b.fit(weightedData) + + assert(model1a0.weights !~= model1a1.weights absTol 1E-3) + assert(model1a0.intercept !~= model1a1.intercept absTol 1E-3) + assert(model1a0.weights ~== model1b.weights absTol 1E-3) + assert(model1a0.intercept ~== model1b.intercept absTol 1E-3) + + val trainer2a = (new LinearRegression).setFitIntercept(true) + .setElasticNetParam(0.0).setRegParam(0.21).setStandardization(false).setSolver(solver) + val trainer2b = (new LinearRegression).setFitIntercept(true).setWeightCol("weight") + .setElasticNetParam(0.0).setRegParam(0.21).setStandardization(false).setSolver(solver) + val model2a0 = trainer2a.fit(data) + val model2a1 = trainer2a.fit(weightedData) + val model2b = trainer2b.fit(weightedData) + assert(model2a0.weights !~= model2a1.weights absTol 1E-3) + assert(model2a0.intercept !~= model2a1.intercept absTol 1E-3) + assert(model2a0.weights ~== model2b.weights absTol 1E-3) + assert(model2a0.intercept ~== model2b.intercept absTol 1E-3) + + val trainer3a = (new LinearRegression).setFitIntercept(false) + .setElasticNetParam(0.0).setRegParam(0.21).setStandardization(true).setSolver(solver) + val trainer3b = (new LinearRegression).setFitIntercept(false).setWeightCol("weight") + .setElasticNetParam(0.0).setRegParam(0.21).setStandardization(true).setSolver(solver) + val model3a0 = trainer3a.fit(data) + val model3a1 = trainer3a.fit(weightedData) + val model3b = trainer3b.fit(weightedData) + assert(model3a0.weights !~= model3a1.weights absTol 1E-3) + assert(model3a0.weights ~== model3b.weights absTol 1E-3) + + val trainer4a = (new LinearRegression).setFitIntercept(false) + .setElasticNetParam(0.0).setRegParam(0.21).setStandardization(false).setSolver(solver) + val trainer4b = (new LinearRegression).setFitIntercept(false).setWeightCol("weight") + .setElasticNetParam(0.0).setRegParam(0.21).setStandardization(false).setSolver(solver) + val model4a0 = trainer4a.fit(data) + val model4a1 = trainer4a.fit(weightedData) + val model4b = trainer4b.fit(weightedData) + assert(model4a0.weights !~= model4a1.weights absTol 1E-3) + assert(model4a0.weights ~== model4b.weights absTol 1E-3) } - - val trainer1a = (new LinearRegression).setFitIntercept(true) - .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true) - val trainer1b = (new LinearRegression).setFitIntercept(true).setWeightCol("weight") - .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true) - val model1a0 = trainer1a.fit(data) - val model1a1 = trainer1a.fit(weightedData) - val model1b = trainer1b.fit(weightedData) - assert(model1a0.weights !~= model1a1.weights absTol 1E-3) - assert(model1a0.intercept !~= model1a1.intercept absTol 1E-3) - assert(model1a0.weights ~== model1b.weights absTol 1E-3) - assert(model1a0.intercept ~== model1b.intercept absTol 1E-3) - - val trainer2a = (new LinearRegression).setFitIntercept(true) - .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false) - val trainer2b = (new LinearRegression).setFitIntercept(true).setWeightCol("weight") - .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false) - val model2a0 = trainer2a.fit(data) - val model2a1 = trainer2a.fit(weightedData) - val model2b = trainer2b.fit(weightedData) - assert(model2a0.weights !~= model2a1.weights absTol 1E-3) - assert(model2a0.intercept !~= model2a1.intercept absTol 1E-3) - assert(model2a0.weights ~== model2b.weights absTol 1E-3) - assert(model2a0.intercept ~== model2b.intercept absTol 1E-3) - - val trainer3a = (new LinearRegression).setFitIntercept(false) - .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true) - val trainer3b = (new LinearRegression).setFitIntercept(false).setWeightCol("weight") - .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true) - val model3a0 = trainer3a.fit(data) - val model3a1 = trainer3a.fit(weightedData) - val model3b = trainer3b.fit(weightedData) - assert(model3a0.weights !~= model3a1.weights absTol 1E-3) - assert(model3a0.weights ~== model3b.weights absTol 1E-3) - - val trainer4a = (new LinearRegression).setFitIntercept(false) - .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false) - val trainer4b = (new LinearRegression).setFitIntercept(false).setWeightCol("weight") - .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false) - val model4a0 = trainer4a.fit(data) - val model4a1 = trainer4a.fit(weightedData) - val model4b = trainer4b.fit(weightedData) - assert(model4a0.weights !~= model4a1.weights absTol 1E-3) - assert(model4a0.weights ~== model4b.weights absTol 1E-3) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index fde02e0c84bc0..cbe09292a0337 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -69,7 +69,7 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext { sc.parallelize(LinearDataGenerator.generateLinearInput( 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2)) - val trainer = new LinearRegression + val trainer = new LinearRegression().setSolver("l-bfgs") val lrParamMaps = new ParamGridBuilder() .addGrid(trainer.regParam, Array(1000.0, 0.001)) .addGrid(trainer.maxIter, Array(0, 10)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index ef24e6fb6b80f..5fb80091d0b4b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -58,7 +58,7 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext sc.parallelize(LinearDataGenerator.generateLinearInput( 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2)) - val trainer = new LinearRegression + val trainer = new LinearRegression().setSolver("l-bfgs") val lrParamMaps = new ParamGridBuilder() .addGrid(trainer.regParam, Array(1000.0, 0.001)) .addGrid(trainer.maxIter, Array(0, 10)) From 7893cd95db5f2caba59ff5c859d7e4964ad7938d Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 19 Oct 2015 11:02:26 -0700 Subject: [PATCH 101/139] [SPARK-11119] [SQL] cleanup for unsafe array and map The purpose of this PR is to keep the unsafe format detail only inside the unsafe class itself, so when we use them(like use unsafe array in unsafe map, use unsafe array and map in columnar cache), we don't need to understand the format before use them. change list: * unsafe array's 4-bytes numElements header is now required(was optional), and become a part of unsafe array format. * w.r.t the previous changing, the `sizeInBytes` of unsafe array now counts the 4-bytes header. * unsafe map's format was `[numElements] [key array numBytes] [key array content(without numElements header)] [value array content(without numElements header)]` before, which is a little hacky as it makes unsafe array's header optional. I think saving 4 bytes is not a big deal, so the format is now: `[key array numBytes] [unsafe key array] [unsafe value array]`. * w.r.t the previous changing, the `sizeInBytes` of unsafe map now counts both map's header and array's header. Author: Wenchen Fan Closes #9131 from cloud-fan/unsafe. --- .../catalyst/expressions/UnsafeArrayData.java | 43 +++++---- .../catalyst/expressions/UnsafeMapData.java | 88 +++++++++++++++---- .../catalyst/expressions/UnsafeReaders.java | 54 ------------ .../sql/catalyst/expressions/UnsafeRow.java | 8 +- .../codegen/UnsafeArrayWriter.java | 24 ++--- .../expressions/codegen/UnsafeRowWriter.java | 15 ---- .../codegen/GenerateUnsafeProjection.scala | 60 +++++++------ .../expressions/UnsafeRowConverterSuite.scala | 42 ++++----- .../spark/sql/columnar/ColumnType.scala | 30 ++++--- .../spark/sql/columnar/ColumnTypeSuite.scala | 2 +- 10 files changed, 174 insertions(+), 192 deletions(-) delete mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index 4c63abb071e3b..761f0447943e8 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -30,19 +30,18 @@ /** * An Unsafe implementation of Array which is backed by raw memory instead of Java objects. * - * Each tuple has two parts: [offsets] [values] + * Each tuple has three parts: [numElements] [offsets] [values] * - * In the `offsets` region, we store 4 bytes per element, represents the start address of this - * element in `values` region. We can get the length of this element by subtracting next offset. + * The `numElements` is 4 bytes storing the number of elements of this array. + * + * In the `offsets` region, we store 4 bytes per element, represents the relative offset (w.r.t. the + * base address of the array) of this element in `values` region. We can get the length of this + * element by subtracting next offset. * Note that offset can by negative which means this element is null. * * In the `values` region, we store the content of elements. As we can get length info, so elements * can be variable-length. * - * Note that when we write out this array, we should write out the `numElements` at first 4 bytes, - * then follows content. When we read in an array, we should read first 4 bytes as `numElements` - * and take the rest as content. - * * Instances of `UnsafeArrayData` act as pointers to row data stored in this format. */ // todo: there is a lof of duplicated code between UnsafeRow and UnsafeArrayData. @@ -54,11 +53,16 @@ public class UnsafeArrayData extends ArrayData { // The number of elements in this array private int numElements; - // The size of this array's backing data, in bytes + // The size of this array's backing data, in bytes. + // The 4-bytes header of `numElements` is also included. private int sizeInBytes; + public Object getBaseObject() { return baseObject; } + public long getBaseOffset() { return baseOffset; } + public int getSizeInBytes() { return sizeInBytes; } + private int getElementOffset(int ordinal) { - return Platform.getInt(baseObject, baseOffset + ordinal * 4L); + return Platform.getInt(baseObject, baseOffset + 4 + ordinal * 4L); } private int getElementSize(int offset, int ordinal) { @@ -85,10 +89,6 @@ public Object[] array() { */ public UnsafeArrayData() { } - public Object getBaseObject() { return baseObject; } - public long getBaseOffset() { return baseOffset; } - public int getSizeInBytes() { return sizeInBytes; } - @Override public int numElements() { return numElements; } @@ -97,10 +97,13 @@ public UnsafeArrayData() { } * * @param baseObject the base object * @param baseOffset the offset within the base object - * @param sizeInBytes the size of this row's backing data, in bytes + * @param sizeInBytes the size of this array's backing data, in bytes */ - public void pointTo(Object baseObject, long baseOffset, int numElements, int sizeInBytes) { + public void pointTo(Object baseObject, long baseOffset, int sizeInBytes) { + // Read the number of elements from the first 4 bytes. + final int numElements = Platform.getInt(baseObject, baseOffset); assert numElements >= 0 : "numElements (" + numElements + ") should >= 0"; + this.numElements = numElements; this.baseObject = baseObject; this.baseOffset = baseOffset; @@ -277,7 +280,9 @@ public UnsafeArrayData getArray(int ordinal) { final int offset = getElementOffset(ordinal); if (offset < 0) return null; final int size = getElementSize(offset, ordinal); - return UnsafeReaders.readArray(baseObject, baseOffset + offset, size); + final UnsafeArrayData array = new UnsafeArrayData(); + array.pointTo(baseObject, baseOffset + offset, size); + return array; } @Override @@ -286,7 +291,9 @@ public UnsafeMapData getMap(int ordinal) { final int offset = getElementOffset(ordinal); if (offset < 0) return null; final int size = getElementSize(offset, ordinal); - return UnsafeReaders.readMap(baseObject, baseOffset + offset, size); + final UnsafeMapData map = new UnsafeMapData(); + map.pointTo(baseObject, baseOffset + offset, size); + return map; } @Override @@ -328,7 +335,7 @@ public UnsafeArrayData copy() { final byte[] arrayDataCopy = new byte[sizeInBytes]; Platform.copyMemory( baseObject, baseOffset, arrayDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes); - arrayCopy.pointTo(arrayDataCopy, Platform.BYTE_ARRAY_OFFSET, numElements, sizeInBytes); + arrayCopy.pointTo(arrayDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes); return arrayCopy; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java index e9dab9edb6bd1..5bebe2a96e391 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java @@ -17,41 +17,73 @@ package org.apache.spark.sql.catalyst.expressions; +import java.nio.ByteBuffer; + import org.apache.spark.sql.types.MapData; +import org.apache.spark.unsafe.Platform; /** * An Unsafe implementation of Map which is backed by raw memory instead of Java objects. * - * Currently we just use 2 UnsafeArrayData to represent UnsafeMapData. - * - * Note that when we write out this map, we should write out the `numElements` at first 4 bytes, - * and numBytes of key array at second 4 bytes, then follows key array content and value array - * content without `numElements` header. - * When we read in a map, we should read first 4 bytes as `numElements` and second 4 bytes as - * numBytes of key array, and construct unsafe key array and value array with these 2 information. + * Currently we just use 2 UnsafeArrayData to represent UnsafeMapData, with extra 4 bytes at head + * to indicate the number of bytes of the unsafe key array. + * [unsafe key array numBytes] [unsafe key array] [unsafe value array] */ +// TODO: Use a more efficient format which doesn't depend on unsafe array. public class UnsafeMapData extends MapData { - private final UnsafeArrayData keys; - private final UnsafeArrayData values; - // The number of elements in this array - private int numElements; - // The size of this array's backing data, in bytes + private Object baseObject; + private long baseOffset; + + // The size of this map's backing data, in bytes. + // The 4-bytes header of key array `numBytes` is also included, so it's actually equal to + // 4 + key array numBytes + value array numBytes. private int sizeInBytes; + public Object getBaseObject() { return baseObject; } + public long getBaseOffset() { return baseOffset; } public int getSizeInBytes() { return sizeInBytes; } - public UnsafeMapData(UnsafeArrayData keys, UnsafeArrayData values) { + private final UnsafeArrayData keys; + private final UnsafeArrayData values; + + /** + * Construct a new UnsafeMapData. The resulting UnsafeMapData won't be usable until + * `pointTo()` has been called, since the value returned by this constructor is equivalent + * to a null pointer. + */ + public UnsafeMapData() { + keys = new UnsafeArrayData(); + values = new UnsafeArrayData(); + } + + /** + * Update this UnsafeMapData to point to different backing data. + * + * @param baseObject the base object + * @param baseOffset the offset within the base object + * @param sizeInBytes the size of this map's backing data, in bytes + */ + public void pointTo(Object baseObject, long baseOffset, int sizeInBytes) { + // Read the numBytes of key array from the first 4 bytes. + final int keyArraySize = Platform.getInt(baseObject, baseOffset); + final int valueArraySize = sizeInBytes - keyArraySize - 4; + assert keyArraySize >= 0 : "keyArraySize (" + keyArraySize + ") should >= 0"; + assert valueArraySize >= 0 : "valueArraySize (" + valueArraySize + ") should >= 0"; + + keys.pointTo(baseObject, baseOffset + 4, keyArraySize); + values.pointTo(baseObject, baseOffset + 4 + keyArraySize, valueArraySize); + assert keys.numElements() == values.numElements(); - this.sizeInBytes = keys.getSizeInBytes() + values.getSizeInBytes(); - this.numElements = keys.numElements(); - this.keys = keys; - this.values = values; + + this.baseObject = baseObject; + this.baseOffset = baseOffset; + this.sizeInBytes = sizeInBytes; } @Override public int numElements() { - return numElements; + return keys.numElements(); } @Override @@ -64,8 +96,26 @@ public UnsafeArrayData valueArray() { return values; } + public void writeToMemory(Object target, long targetOffset) { + Platform.copyMemory(baseObject, baseOffset, target, targetOffset, sizeInBytes); + } + + public void writeTo(ByteBuffer buffer) { + assert(buffer.hasArray()); + byte[] target = buffer.array(); + int offset = buffer.arrayOffset(); + int pos = buffer.position(); + writeToMemory(target, Platform.BYTE_ARRAY_OFFSET + offset + pos); + buffer.position(pos + sizeInBytes); + } + @Override public UnsafeMapData copy() { - return new UnsafeMapData(keys.copy(), values.copy()); + UnsafeMapData mapCopy = new UnsafeMapData(); + final byte[] mapDataCopy = new byte[sizeInBytes]; + Platform.copyMemory( + baseObject, baseOffset, mapDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes); + mapCopy.pointTo(mapDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes); + return mapCopy; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java deleted file mode 100644 index 6c5fcbca63fd7..0000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions; - -import org.apache.spark.unsafe.Platform; - -public class UnsafeReaders { - - /** - * Reads in unsafe array according to the format described in `UnsafeArrayData`. - */ - public static UnsafeArrayData readArray(Object baseObject, long baseOffset, int numBytes) { - // Read the number of elements from first 4 bytes. - final int numElements = Platform.getInt(baseObject, baseOffset); - final UnsafeArrayData array = new UnsafeArrayData(); - // Skip the first 4 bytes. - array.pointTo(baseObject, baseOffset + 4, numElements, numBytes - 4); - return array; - } - - /** - * Reads in unsafe map according to the format described in `UnsafeMapData`. - */ - public static UnsafeMapData readMap(Object baseObject, long baseOffset, int numBytes) { - // Read the number of elements from first 4 bytes. - final int numElements = Platform.getInt(baseObject, baseOffset); - // Read the numBytes of key array in second 4 bytes. - final int keyArraySize = Platform.getInt(baseObject, baseOffset + 4); - final int valueArraySize = numBytes - 8 - keyArraySize; - - final UnsafeArrayData keyArray = new UnsafeArrayData(); - keyArray.pointTo(baseObject, baseOffset + 8, numElements, keyArraySize); - - final UnsafeArrayData valueArray = new UnsafeArrayData(); - valueArray.pointTo(baseObject, baseOffset + 8 + keyArraySize, numElements, valueArraySize); - - return new UnsafeMapData(keyArray, valueArray); - } -} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 36859fbab9744..366615f6fe69f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -461,7 +461,9 @@ public UnsafeArrayData getArray(int ordinal) { final long offsetAndSize = getLong(ordinal); final int offset = (int) (offsetAndSize >> 32); final int size = (int) (offsetAndSize & ((1L << 32) - 1)); - return UnsafeReaders.readArray(baseObject, baseOffset + offset, size); + final UnsafeArrayData array = new UnsafeArrayData(); + array.pointTo(baseObject, baseOffset + offset, size); + return array; } } @@ -473,7 +475,9 @@ public UnsafeMapData getMap(int ordinal) { final long offsetAndSize = getLong(ordinal); final int offset = (int) (offsetAndSize >> 32); final int size = (int) (offsetAndSize & ((1L << 32) - 1)); - return UnsafeReaders.readMap(baseObject, baseOffset + offset, size); + final UnsafeMapData map = new UnsafeMapData(); + map.pointTo(baseObject, baseOffset + offset, size); + return map; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java index 138178ce99d85..7f2a1cb07af01 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java @@ -30,17 +30,19 @@ public class UnsafeArrayWriter { private BufferHolder holder; + // The offset of the global buffer where we start to write this array. private int startingOffset; public void initialize(BufferHolder holder, int numElements, int fixedElementSize) { - // We need 4 bytes each element to store offset. - final int fixedSize = 4 * numElements; + // We need 4 bytes to store numElements and 4 bytes each element to store offset. + final int fixedSize = 4 + 4 * numElements; this.holder = holder; this.startingOffset = holder.cursor; holder.grow(fixedSize); + Platform.putInt(holder.buffer, holder.cursor, numElements); holder.cursor += fixedSize; // Grows the global buffer ahead for fixed size data. @@ -48,7 +50,7 @@ public void initialize(BufferHolder holder, int numElements, int fixedElementSiz } private long getElementOffset(int ordinal) { - return startingOffset + 4 * ordinal; + return startingOffset + 4 + 4 * ordinal; } public void setNullAt(int ordinal) { @@ -132,20 +134,4 @@ public void write(int ordinal, CalendarInterval input) { // move the cursor forward. holder.cursor += 16; } - - - - // If this array is already an UnsafeArray, we don't need to go through all elements, we can - // directly write it. - public static void directWrite(BufferHolder holder, UnsafeArrayData input) { - final int numBytes = input.getSizeInBytes(); - - // grow the global buffer before writing data. - holder.grow(numBytes); - - // Writes the array content to the variable length portion. - input.writeToMemory(holder.buffer, holder.cursor); - - holder.cursor += numBytes; - } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java index 8b7debd440031..e1f5a05d1d446 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java @@ -181,19 +181,4 @@ public void write(int ordinal, CalendarInterval input) { // move the cursor forward. holder.cursor += 16; } - - - - // If this struct is already an UnsafeRow, we don't need to go through all fields, we can - // directly write it. - public static void directWrite(BufferHolder holder, UnsafeRow input) { - // No need to zero-out the bytes as UnsafeRow is word aligned for sure. - final int numBytes = input.getSizeInBytes(); - // grow the global buffer before writing data. - holder.grow(numBytes); - // Write the bytes to the variable length portion. - input.writeToMemory(holder.buffer, holder.cursor); - // move the cursor forward. - holder.cursor += numBytes; - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 1b957a508d10e..dbe92d6a83502 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -62,7 +62,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s""" if ($input instanceof UnsafeRow) { - $rowWriterClass.directWrite($bufferHolder, (UnsafeRow) $input); + ${writeUnsafeData(ctx, s"((UnsafeRow) $input)", bufferHolder)} } else { ${writeExpressionsToBuffer(ctx, input, fieldEvals, fieldTypes, bufferHolder)} } @@ -164,8 +164,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx: CodeGenContext, input: String, elementType: DataType, - bufferHolder: String, - needHeader: Boolean = true): String = { + bufferHolder: String): String = { val arrayWriter = ctx.freshName("arrayWriter") ctx.addMutableState(arrayWriterClass, arrayWriter, s"this.$arrayWriter = new $arrayWriterClass();") @@ -227,21 +226,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => s"$arrayWriter.write($index, $element);" } - val writeHeader = if (needHeader) { - // If header is required, we need to write the number of elements into first 4 bytes. - s""" - $bufferHolder.grow(4); - Platform.putInt($bufferHolder.buffer, $bufferHolder.cursor, $numElements); - $bufferHolder.cursor += 4; - """ - } else "" - s""" - final int $numElements = $input.numElements(); - $writeHeader if ($input instanceof UnsafeArrayData) { - $arrayWriterClass.directWrite($bufferHolder, (UnsafeArrayData) $input); + ${writeUnsafeData(ctx, s"((UnsafeArrayData) $input)", bufferHolder)} } else { + final int $numElements = $input.numElements(); $arrayWriter.initialize($bufferHolder, $numElements, $fixedElementSize); for (int $index = 0; $index < $numElements; $index++) { @@ -270,23 +259,40 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // Writes out unsafe map according to the format described in `UnsafeMapData`. s""" - final ArrayData $keys = $input.keyArray(); - final ArrayData $values = $input.valueArray(); + if ($input instanceof UnsafeMapData) { + ${writeUnsafeData(ctx, s"((UnsafeMapData) $input)", bufferHolder)} + } else { + final ArrayData $keys = $input.keyArray(); + final ArrayData $values = $input.valueArray(); - $bufferHolder.grow(8); + // preserve 4 bytes to write the key array numBytes later. + $bufferHolder.grow(4); + $bufferHolder.cursor += 4; - // Write the numElements into first 4 bytes. - Platform.putInt($bufferHolder.buffer, $bufferHolder.cursor, $keys.numElements()); + // Remember the current cursor so that we can write numBytes of key array later. + final int $tmpCursor = $bufferHolder.cursor; - $bufferHolder.cursor += 8; - // Remember the current cursor so that we can write numBytes of key array later. - final int $tmpCursor = $bufferHolder.cursor; + ${writeArrayToBuffer(ctx, keys, keyType, bufferHolder)} + // Write the numBytes of key array into the first 4 bytes. + Platform.putInt($bufferHolder.buffer, $tmpCursor - 4, $bufferHolder.cursor - $tmpCursor); - ${writeArrayToBuffer(ctx, keys, keyType, bufferHolder, needHeader = false)} - // Write the numBytes of key array into second 4 bytes. - Platform.putInt($bufferHolder.buffer, $tmpCursor - 4, $bufferHolder.cursor - $tmpCursor); + ${writeArrayToBuffer(ctx, values, valueType, bufferHolder)} + } + """ + } - ${writeArrayToBuffer(ctx, values, valueType, bufferHolder, needHeader = false)} + /** + * If the input is already in unsafe format, we don't need to go through all elements/fields, + * we can directly write it. + */ + private def writeUnsafeData(ctx: CodeGenContext, input: String, bufferHolder: String) = { + val sizeInBytes = ctx.freshName("sizeInBytes") + s""" + final int $sizeInBytes = $input.getSizeInBytes(); + // grow the global buffer before writing data. + $bufferHolder.grow($sizeInBytes); + $input.writeToMemory($bufferHolder.buffer, $bufferHolder.cursor); + $bufferHolder.cursor += $sizeInBytes; """ } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index c991cd86d28c8..c6aad34e972b5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -296,13 +296,9 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { new ArrayBasedMapData(createArray(keys: _*), createArray(values: _*)) } - private def arraySizeInRow(numBytes: Int): Int = roundedSize(4 + numBytes) - - private def mapSizeInRow(numBytes: Int): Int = roundedSize(8 + numBytes) - private def testArrayInt(array: UnsafeArrayData, values: Seq[Int]): Unit = { assert(array.numElements == values.length) - assert(array.getSizeInBytes == (4 + 4) * values.length) + assert(array.getSizeInBytes == 4 + (4 + 4) * values.length) values.zipWithIndex.foreach { case (value, index) => assert(array.getInt(index) == value) } @@ -315,7 +311,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { testArrayInt(map.keyArray, keys) testArrayInt(map.valueArray, values) - assert(map.getSizeInBytes == map.keyArray.getSizeInBytes + map.valueArray.getSizeInBytes) + assert(map.getSizeInBytes == 4 + map.keyArray.getSizeInBytes + map.valueArray.getSizeInBytes) } test("basic conversion with array type") { @@ -341,10 +337,10 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val nestedArray = unsafeArray2.getArray(0) testArrayInt(nestedArray, Seq(3, 4)) - assert(unsafeArray2.getSizeInBytes == 4 + (4 + nestedArray.getSizeInBytes)) + assert(unsafeArray2.getSizeInBytes == 4 + 4 + nestedArray.getSizeInBytes) - val array1Size = arraySizeInRow(unsafeArray1.getSizeInBytes) - val array2Size = arraySizeInRow(unsafeArray2.getSizeInBytes) + val array1Size = roundedSize(unsafeArray1.getSizeInBytes) + val array2Size = roundedSize(unsafeArray2.getSizeInBytes) assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + array1Size + array2Size) } @@ -384,13 +380,13 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val nestedMap = valueArray.getMap(0) testMapInt(nestedMap, Seq(5, 6), Seq(7, 8)) - assert(valueArray.getSizeInBytes == 4 + (8 + nestedMap.getSizeInBytes)) + assert(valueArray.getSizeInBytes == 4 + 4 + nestedMap.getSizeInBytes) } - assert(unsafeMap2.getSizeInBytes == keyArray.getSizeInBytes + valueArray.getSizeInBytes) + assert(unsafeMap2.getSizeInBytes == 4 + keyArray.getSizeInBytes + valueArray.getSizeInBytes) - val map1Size = mapSizeInRow(unsafeMap1.getSizeInBytes) - val map2Size = mapSizeInRow(unsafeMap2.getSizeInBytes) + val map1Size = roundedSize(unsafeMap1.getSizeInBytes) + val map2Size = roundedSize(unsafeMap2.getSizeInBytes) assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + map1Size + map2Size) } @@ -414,7 +410,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val innerArray = field1.getArray(0) testArrayInt(innerArray, Seq(1)) - assert(field1.getSizeInBytes == 8 + 8 + arraySizeInRow(innerArray.getSizeInBytes)) + assert(field1.getSizeInBytes == 8 + 8 + roundedSize(innerArray.getSizeInBytes)) val field2 = unsafeRow.getArray(1) assert(field2.numElements == 1) @@ -427,10 +423,10 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(innerStruct.getLong(0) == 2L) } - assert(field2.getSizeInBytes == 4 + innerStruct.getSizeInBytes) + assert(field2.getSizeInBytes == 4 + 4 + innerStruct.getSizeInBytes) assert(unsafeRow.getSizeInBytes == - 8 + 8 * 2 + field1.getSizeInBytes + arraySizeInRow(field2.getSizeInBytes)) + 8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes)) } test("basic conversion with struct and map") { @@ -453,7 +449,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val innerMap = field1.getMap(0) testMapInt(innerMap, Seq(1), Seq(2)) - assert(field1.getSizeInBytes == 8 + 8 + mapSizeInRow(innerMap.getSizeInBytes)) + assert(field1.getSizeInBytes == 8 + 8 + roundedSize(innerMap.getSizeInBytes)) val field2 = unsafeRow.getMap(1) @@ -470,13 +466,13 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(innerStruct.getSizeInBytes == 8 + 8) assert(innerStruct.getLong(0) == 4L) - assert(valueArray.getSizeInBytes == 4 + innerStruct.getSizeInBytes) + assert(valueArray.getSizeInBytes == 4 + 4 + innerStruct.getSizeInBytes) } - assert(field2.getSizeInBytes == keyArray.getSizeInBytes + valueArray.getSizeInBytes) + assert(field2.getSizeInBytes == 4 + keyArray.getSizeInBytes + valueArray.getSizeInBytes) assert(unsafeRow.getSizeInBytes == - 8 + 8 * 2 + field1.getSizeInBytes + mapSizeInRow(field2.getSizeInBytes)) + 8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes)) } test("basic conversion with array and map") { @@ -499,7 +495,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val innerMap = field1.getMap(0) testMapInt(innerMap, Seq(1), Seq(2)) - assert(field1.getSizeInBytes == 4 + (8 + innerMap.getSizeInBytes)) + assert(field1.getSizeInBytes == 4 + 4 + innerMap.getSizeInBytes) val field2 = unsafeRow.getMap(1) assert(field2.numElements == 1) @@ -518,9 +514,9 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(valueArray.getSizeInBytes == 4 + (4 + innerArray.getSizeInBytes)) } - assert(field2.getSizeInBytes == keyArray.getSizeInBytes + valueArray.getSizeInBytes) + assert(field2.getSizeInBytes == 4 + keyArray.getSizeInBytes + valueArray.getSizeInBytes) assert(unsafeRow.getSizeInBytes == - 8 + 8 * 2 + arraySizeInRow(field1.getSizeInBytes) + mapSizeInRow(field2.getSizeInBytes)) + 8 + 8 * 2 + roundedSize(field1.getSizeInBytes) + roundedSize(field2.getSizeInBytes)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index 2bc2c96b61634..a41f04dd3b59a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -482,12 +482,14 @@ private[sql] case class STRUCT(dataType: StructType) extends ColumnType[UnsafeRo override def extract(buffer: ByteBuffer): UnsafeRow = { val sizeInBytes = buffer.getInt() assert(buffer.hasArray) - val base = buffer.array() - val offset = buffer.arrayOffset() val cursor = buffer.position() buffer.position(cursor + sizeInBytes) val unsafeRow = new UnsafeRow - unsafeRow.pointTo(base, Platform.BYTE_ARRAY_OFFSET + offset + cursor, numOfFields, sizeInBytes) + unsafeRow.pointTo( + buffer.array(), + Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset() + cursor, + numOfFields, + sizeInBytes) unsafeRow } @@ -508,12 +510,11 @@ private[sql] case class ARRAY(dataType: ArrayType) extends ColumnType[UnsafeArra override def actualSize(row: InternalRow, ordinal: Int): Int = { val unsafeArray = getField(row, ordinal) - 4 + 4 + unsafeArray.getSizeInBytes + 4 + unsafeArray.getSizeInBytes } override def append(value: UnsafeArrayData, buffer: ByteBuffer): Unit = { - buffer.putInt(4 + value.getSizeInBytes) - buffer.putInt(value.numElements()) + buffer.putInt(value.getSizeInBytes) value.writeTo(buffer) } @@ -522,10 +523,12 @@ private[sql] case class ARRAY(dataType: ArrayType) extends ColumnType[UnsafeArra assert(buffer.hasArray) val cursor = buffer.position() buffer.position(cursor + numBytes) - UnsafeReaders.readArray( + val array = new UnsafeArrayData + array.pointTo( buffer.array(), Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset() + cursor, numBytes) + array } override def clone(v: UnsafeArrayData): UnsafeArrayData = v.copy() @@ -545,15 +548,12 @@ private[sql] case class MAP(dataType: MapType) extends ColumnType[UnsafeMapData] override def actualSize(row: InternalRow, ordinal: Int): Int = { val unsafeMap = getField(row, ordinal) - 12 + unsafeMap.keyArray().getSizeInBytes + unsafeMap.valueArray().getSizeInBytes + 4 + unsafeMap.getSizeInBytes } override def append(value: UnsafeMapData, buffer: ByteBuffer): Unit = { - buffer.putInt(8 + value.keyArray().getSizeInBytes + value.valueArray().getSizeInBytes) - buffer.putInt(value.numElements()) - buffer.putInt(value.keyArray().getSizeInBytes) - value.keyArray().writeTo(buffer) - value.valueArray().writeTo(buffer) + buffer.putInt(value.getSizeInBytes) + value.writeTo(buffer) } override def extract(buffer: ByteBuffer): UnsafeMapData = { @@ -561,10 +561,12 @@ private[sql] case class MAP(dataType: MapType) extends ColumnType[UnsafeMapData] assert(buffer.hasArray) val cursor = buffer.position() buffer.position(cursor + numBytes) - UnsafeReaders.readMap( + val map = new UnsafeMapData + map.pointTo( buffer.array(), Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset() + cursor, numBytes) + map } override def clone(v: UnsafeMapData): UnsafeMapData = v.copy() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index 0e6e1bcf72896..63bc39bfa0307 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -73,7 +73,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { checkActualSize(COMPACT_DECIMAL(15, 10), Decimal(0, 15, 10), 8) checkActualSize(LARGE_DECIMAL(20, 10), Decimal(0, 20, 10), 5) checkActualSize(ARRAY_TYPE, Array[Any](1), 16) - checkActualSize(MAP_TYPE, Map(1 -> "a"), 25) + checkActualSize(MAP_TYPE, Map(1 -> "a"), 29) checkActualSize(STRUCT_TYPE, Row("hello"), 28) } From 5966817941b57251fbd1cf8b9b458ec389c071a0 Mon Sep 17 00:00:00 2001 From: Rishabh Bhardwaj Date: Mon, 19 Oct 2015 14:38:49 -0700 Subject: [PATCH 102/139] [SPARK-11180][SQL] Support BooleanType in DataFrame.na.fill Added support for boolean types in fill and replace methods Author: Rishabh Bhardwaj Closes #9166 from rishabhbhardwaj/master. --- .../spark/sql/DataFrameNaFunctions.scala | 29 ++++++++++++------- .../spark/sql/DataFrameNaFunctionsSuite.scala | 14 +++++---- 2 files changed, 27 insertions(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 77a42c0873a6b..f7be5f6b370ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -198,7 +198,8 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * Returns a new [[DataFrame]] that replaces null values. * * The key of the map is the column name, and the value of the map is the replacement value. - * The value must be of the following type: `Integer`, `Long`, `Float`, `Double`, `String`. + * The value must be of the following type: + * `Integer`, `Long`, `Float`, `Double`, `String`, `Boolean`. * * For example, the following replaces null values in column "A" with string "unknown", and * null values in column "B" with numeric value 1.0. @@ -215,7 +216,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * (Scala-specific) Returns a new [[DataFrame]] that replaces null values. * * The key of the map is the column name, and the value of the map is the replacement value. - * The value must be of the following type: `Int`, `Long`, `Float`, `Double`, `String`. + * The value must be of the following type: `Int`, `Long`, `Float`, `Double`, `String`, `Boolean`. * * For example, the following replaces null values in column "A" with string "unknown", and * null values in column "B" with numeric value 1.0. @@ -232,7 +233,8 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { /** * Replaces values matching keys in `replacement` map with the corresponding values. - * Key and value of `replacement` map must have the same type, and can only be doubles or strings. + * Key and value of `replacement` map must have the same type, and + * can only be doubles, strings or booleans. * If `col` is "*", then the replacement is applied on all string columns or numeric columns. * * {{{ @@ -259,7 +261,8 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { /** * Replaces values matching keys in `replacement` map with the corresponding values. - * Key and value of `replacement` map must have the same type, and can only be doubles or strings. + * Key and value of `replacement` map must have the same type, and + * can only be doubles, strings or booleans. * * {{{ * import com.google.common.collect.ImmutableMap; @@ -282,8 +285,10 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { /** * (Scala-specific) Replaces values matching keys in `replacement` map. - * Key and value of `replacement` map must have the same type, and can only be doubles or strings. - * If `col` is "*", then the replacement is applied on all string columns or numeric columns. + * Key and value of `replacement` map must have the same type, and + * can only be doubles, strings or booleans. + * If `col` is "*", + * then the replacement is applied on all string columns , numeric columns or boolean columns. * * {{{ * // Replaces all occurrences of 1.0 with 2.0 in column "height". @@ -311,7 +316,8 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { /** * (Scala-specific) Replaces values matching keys in `replacement` map. - * Key and value of `replacement` map must have the same type, and can only be doubles or strings. + * Key and value of `replacement` map must have the same type, and + * can only be doubles , strings or booleans. * * {{{ * // Replaces all occurrences of 1.0 with 2.0 in column "height" and "weight". @@ -333,15 +339,17 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { return df } - // replacementMap is either Map[String, String] or Map[Double, Double] + // replacementMap is either Map[String, String] or Map[Double, Double] or Map[Boolean,Boolean] val replacementMap: Map[_, _] = replacement.head._2 match { case v: String => replacement + case v: Boolean => replacement case _ => replacement.map { case (k, v) => (convertToDouble(k), convertToDouble(v)) } } - // targetColumnType is either DoubleType or StringType + // targetColumnType is either DoubleType or StringType or BooleanType val targetColumnType = replacement.head._1 match { case _: jl.Double | _: jl.Float | _: jl.Integer | _: jl.Long => DoubleType + case _: jl.Boolean => BooleanType case _: String => StringType } @@ -367,7 +375,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { // Check data type replaceValue match { - case _: jl.Double | _: jl.Float | _: jl.Integer | _: jl.Long | _: String => + case _: jl.Double | _: jl.Float | _: jl.Integer | _: jl.Long | _: jl.Boolean | _: String => // This is good case _ => throw new IllegalArgumentException( s"Unsupported value type ${replaceValue.getClass.getName} ($replaceValue).") @@ -382,6 +390,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { case v: jl.Double => fillCol[Double](f, v) case v: jl.Long => fillCol[Double](f, v.toDouble) case v: jl.Integer => fillCol[Double](f, v.toDouble) + case v: jl.Boolean => fillCol[Boolean](f, v.booleanValue()) case v: String => fillCol[String](f, v) } }.getOrElse(df.col(f.name)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index 329ffb66083b1..e34875471f093 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -141,24 +141,26 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { } test("fill with map") { - val df = Seq[(String, String, java.lang.Long, java.lang.Double)]( - (null, null, null, null)).toDF("a", "b", "c", "d") + val df = Seq[(String, String, java.lang.Long, java.lang.Double, java.lang.Boolean)]( + (null, null, null, null, null)).toDF("a", "b", "c", "d", "e") checkAnswer( df.na.fill(Map( "a" -> "test", "c" -> 1, - "d" -> 2.2 + "d" -> 2.2, + "e" -> false )), - Row("test", null, 1, 2.2)) + Row("test", null, 1, 2.2, false)) // Test Java version checkAnswer( df.na.fill(Map( "a" -> "test", "c" -> 1, - "d" -> 2.2 + "d" -> 2.2, + "e" -> false ).asJava), - Row("test", null, 1, 2.2)) + Row("test", null, 1, 2.2, false)) } test("replace") { From 67582132bffbaaeaadc5cf8218f6239d03c39da0 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 19 Oct 2015 15:35:14 -0700 Subject: [PATCH 103/139] [SPARK-11063] [STREAMING] Change preferredLocations of Receiver's RDD to hosts rather than hostports The format of RDD's preferredLocations must be hostname but the format of Streaming Receiver's scheduling executors is hostport. So it doesn't work. This PR converts `schedulerExecutors` to `hosts` before creating Receiver's RDD. Author: zsxwing Closes #9075 from zsxwing/SPARK-11063. --- .../scheduler/ReceiverSchedulingPolicy.scala | 3 ++- .../streaming/scheduler/ReceiverTracker.scala | 4 +++- .../scheduler/ReceiverTrackerSuite.scala | 24 +++++++++++++++++++ 3 files changed, 29 insertions(+), 2 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala index 10b5a7f57a802..d2b0be7f4a9c5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala @@ -21,6 +21,7 @@ import scala.collection.Map import scala.collection.mutable import org.apache.spark.streaming.receiver.Receiver +import org.apache.spark.util.Utils /** * A class that tries to schedule receivers with evenly distributed. There are two phases for @@ -79,7 +80,7 @@ private[streaming] class ReceiverSchedulingPolicy { return receivers.map(_.streamId -> Seq.empty).toMap } - val hostToExecutors = executors.groupBy(_.split(":")(0)) + val hostToExecutors = executors.groupBy(executor => Utils.parseHostPort(executor)._1) val scheduledExecutors = Array.fill(receivers.length)(new mutable.ArrayBuffer[String]) val numReceiversOnExecutor = mutable.HashMap[String, Int]() // Set the initial value to 0 diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index d053e9e84910f..2ce80d618b0a3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -551,7 +551,9 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false if (scheduledExecutors.isEmpty) { ssc.sc.makeRDD(Seq(receiver), 1) } else { - ssc.sc.makeRDD(Seq(receiver -> scheduledExecutors)) + val preferredLocations = + scheduledExecutors.map(hostPort => Utils.parseHostPort(hostPort)._1).distinct + ssc.sc.makeRDD(Seq(receiver -> preferredLocations)) } receiverRDD.setName(s"Receiver $receiverId") ssc.sparkContext.setJobDescription(s"Streaming job running receiver $receiverId") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala index 45138b748ecab..fda86aef457d4 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala @@ -22,6 +22,8 @@ import scala.collection.mutable.ArrayBuffer import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart, TaskLocality} +import org.apache.spark.scheduler.TaskLocality.TaskLocality import org.apache.spark.storage.{StorageLevel, StreamBlockId} import org.apache.spark.streaming._ import org.apache.spark.streaming.dstream.ReceiverInputDStream @@ -80,6 +82,28 @@ class ReceiverTrackerSuite extends TestSuiteBase { } } } + + test("SPARK-11063: TaskSetManager should use Receiver RDD's preferredLocations") { + // Use ManualClock to prevent from starting batches so that we can make sure the only task is + // for starting the Receiver + val _conf = conf.clone.set("spark.streaming.clock", "org.apache.spark.util.ManualClock") + withStreamingContext(new StreamingContext(_conf, Milliseconds(100))) { ssc => + @volatile var receiverTaskLocality: TaskLocality = null + ssc.sparkContext.addSparkListener(new SparkListener { + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { + receiverTaskLocality = taskStart.taskInfo.taskLocality + } + }) + val input = ssc.receiverStream(new TestReceiver) + val output = new TestOutputStream(input) + output.register() + ssc.start() + eventually(timeout(10 seconds), interval(10 millis)) { + // If preferredLocations is set correctly, receiverTaskLocality should be NODE_LOCAL + assert(receiverTaskLocality === TaskLocality.NODE_LOCAL) + } + } + } } /** An input DStream with for testing rate controlling */ From 7ab0ce6501c37f0fc3a49e3332573ae4e4def3e8 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 19 Oct 2015 16:14:50 -0700 Subject: [PATCH 104/139] [SPARK-11131][CORE] Fix race in worker registration protocol. Because the registration RPC was not really an RPC, but a bunch of disconnected messages, it was possible for other messages to be sent before the reply to the registration arrived, and that would confuse the Worker. Especially in local-cluster mode, the worker was succeptible to receiving an executor request before it received a message from the master saying registration succeeded. On top of the above, the change also fixes a ClassCastException when the registration fails, which also affects the executor registration protocol. Because the `ask` is issued with a specific return type, if the error message (of a different type) was returned instead, the code would just die with an exception. This is fixed by having a common base trait for these reply messages. Author: Marcelo Vanzin Closes #9138 from vanzin/SPARK-11131. --- .../apache/spark/deploy/DeployMessage.scala | 7 +- .../apache/spark/deploy/master/Master.scala | 50 ++++++------- .../apache/spark/deploy/worker/Worker.scala | 73 ++++++++++++------- .../CoarseGrainedExecutorBackend.scala | 4 +- .../cluster/CoarseGrainedClusterMessage.scala | 4 + .../apache/spark/HeartbeatReceiverSuite.scala | 4 +- 6 files changed, 86 insertions(+), 56 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index d8084a57658ad..3feb7cea593e0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -69,9 +69,14 @@ private[deploy] object DeployMessages { // Master to Worker + sealed trait RegisterWorkerResponse + case class RegisteredWorker(master: RpcEndpointRef, masterWebUiUrl: String) extends DeployMessage + with RegisterWorkerResponse + + case class RegisterWorkerFailed(message: String) extends DeployMessage with RegisterWorkerResponse - case class RegisterWorkerFailed(message: String) extends DeployMessage + case object MasterInStandby extends DeployMessage with RegisterWorkerResponse case class ReconnectWorker(masterUrl: String) extends DeployMessage diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index d518e92133aad..6715d6c70f497 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -233,31 +233,6 @@ private[deploy] class Master( System.exit(0) } - case RegisterWorker( - id, workerHost, workerPort, workerRef, cores, memory, workerUiPort, publicAddress) => { - logInfo("Registering worker %s:%d with %d cores, %s RAM".format( - workerHost, workerPort, cores, Utils.megabytesToString(memory))) - if (state == RecoveryState.STANDBY) { - // ignore, don't send response - } else if (idToWorker.contains(id)) { - workerRef.send(RegisterWorkerFailed("Duplicate worker ID")) - } else { - val worker = new WorkerInfo(id, workerHost, workerPort, cores, memory, - workerRef, workerUiPort, publicAddress) - if (registerWorker(worker)) { - persistenceEngine.addWorker(worker) - workerRef.send(RegisteredWorker(self, masterWebUiUrl)) - schedule() - } else { - val workerAddress = worker.endpoint.address - logWarning("Worker registration failed. Attempted to re-register worker at same " + - "address: " + workerAddress) - workerRef.send(RegisterWorkerFailed("Attempted to re-register worker at same address: " - + workerAddress)) - } - } - } - case RegisterApplication(description, driver) => { // TODO Prevent repeated registrations from some driver if (state == RecoveryState.STANDBY) { @@ -387,6 +362,31 @@ private[deploy] class Master( } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RegisterWorker( + id, workerHost, workerPort, workerRef, cores, memory, workerUiPort, publicAddress) => { + logInfo("Registering worker %s:%d with %d cores, %s RAM".format( + workerHost, workerPort, cores, Utils.megabytesToString(memory))) + if (state == RecoveryState.STANDBY) { + context.reply(MasterInStandby) + } else if (idToWorker.contains(id)) { + context.reply(RegisterWorkerFailed("Duplicate worker ID")) + } else { + val worker = new WorkerInfo(id, workerHost, workerPort, cores, memory, + workerRef, workerUiPort, publicAddress) + if (registerWorker(worker)) { + persistenceEngine.addWorker(worker) + context.reply(RegisteredWorker(self, masterWebUiUrl)) + schedule() + } else { + val workerAddress = worker.endpoint.address + logWarning("Worker registration failed. Attempted to re-register worker at same " + + "address: " + workerAddress) + context.reply(RegisterWorkerFailed("Attempted to re-register worker at same address: " + + workerAddress)) + } + } + } + case RequestSubmitDriver(description) => { if (state != RecoveryState.ALIVE) { val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 93a1b3f310422..a45867e7680ec 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -26,7 +26,7 @@ import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFut import scala.collection.mutable.{HashMap, HashSet, LinkedHashMap} import scala.concurrent.ExecutionContext -import scala.util.Random +import scala.util.{Failure, Random, Success} import scala.util.control.NonFatal import org.apache.spark.{Logging, SecurityManager, SparkConf} @@ -213,8 +213,7 @@ private[deploy] class Worker( logInfo("Connecting to master " + masterAddress + "...") val masterEndpoint = rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME) - masterEndpoint.send(RegisterWorker( - workerId, host, port, self, cores, memory, webUi.boundPort, publicAddress)) + registerWithMaster(masterEndpoint) } catch { case ie: InterruptedException => // Cancelled case NonFatal(e) => logWarning(s"Failed to connect to master $masterAddress", e) @@ -271,8 +270,7 @@ private[deploy] class Worker( logInfo("Connecting to master " + masterAddress + "...") val masterEndpoint = rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME) - masterEndpoint.send(RegisterWorker( - workerId, host, port, self, cores, memory, webUi.boundPort, publicAddress)) + registerWithMaster(masterEndpoint) } catch { case ie: InterruptedException => // Cancelled case NonFatal(e) => logWarning(s"Failed to connect to master $masterAddress", e) @@ -341,25 +339,54 @@ private[deploy] class Worker( } } - override def receive: PartialFunction[Any, Unit] = { - case RegisteredWorker(masterRef, masterWebUiUrl) => - logInfo("Successfully registered with master " + masterRef.address.toSparkURL) - registered = true - changeMaster(masterRef, masterWebUiUrl) - forwordMessageScheduler.scheduleAtFixedRate(new Runnable { - override def run(): Unit = Utils.tryLogNonFatalError { - self.send(SendHeartbeat) - } - }, 0, HEARTBEAT_MILLIS, TimeUnit.MILLISECONDS) - if (CLEANUP_ENABLED) { - logInfo(s"Worker cleanup enabled; old application directories will be deleted in: $workDir") + private def registerWithMaster(masterEndpoint: RpcEndpointRef): Unit = { + masterEndpoint.ask[RegisterWorkerResponse](RegisterWorker( + workerId, host, port, self, cores, memory, webUi.boundPort, publicAddress)) + .onComplete { + // This is a very fast action so we can use "ThreadUtils.sameThread" + case Success(msg) => + Utils.tryLogNonFatalError { + handleRegisterResponse(msg) + } + case Failure(e) => + logError(s"Cannot register with master: ${masterEndpoint.address}", e) + System.exit(1) + }(ThreadUtils.sameThread) + } + + private def handleRegisterResponse(msg: RegisterWorkerResponse): Unit = synchronized { + msg match { + case RegisteredWorker(masterRef, masterWebUiUrl) => + logInfo("Successfully registered with master " + masterRef.address.toSparkURL) + registered = true + changeMaster(masterRef, masterWebUiUrl) forwordMessageScheduler.scheduleAtFixedRate(new Runnable { override def run(): Unit = Utils.tryLogNonFatalError { - self.send(WorkDirCleanup) + self.send(SendHeartbeat) } - }, CLEANUP_INTERVAL_MILLIS, CLEANUP_INTERVAL_MILLIS, TimeUnit.MILLISECONDS) - } + }, 0, HEARTBEAT_MILLIS, TimeUnit.MILLISECONDS) + if (CLEANUP_ENABLED) { + logInfo( + s"Worker cleanup enabled; old application directories will be deleted in: $workDir") + forwordMessageScheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + self.send(WorkDirCleanup) + } + }, CLEANUP_INTERVAL_MILLIS, CLEANUP_INTERVAL_MILLIS, TimeUnit.MILLISECONDS) + } + case RegisterWorkerFailed(message) => + if (!registered) { + logError("Worker registration failed: " + message) + System.exit(1) + } + + case MasterInStandby => + // Ignore. Master not yet ready. + } + } + + override def receive: PartialFunction[Any, Unit] = synchronized { case SendHeartbeat => if (connected) { sendToMaster(Heartbeat(workerId, self)) } @@ -399,12 +426,6 @@ private[deploy] class Worker( map(e => new ExecutorDescription(e.appId, e.execId, e.cores, e.state)) masterRef.send(WorkerSchedulerStateResponse(workerId, execs.toList, drivers.keys.toSeq)) - case RegisterWorkerFailed(message) => - if (!registered) { - logError("Worker registration failed: " + message) - System.exit(1) - } - case ReconnectWorker(masterUrl) => logInfo(s"Master with url $masterUrl requested this worker to reconnect.") registerWithMaster() diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 49059de50b42b..a9c6a05ecd434 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -59,12 +59,12 @@ private[spark] class CoarseGrainedExecutorBackend( rpcEnv.asyncSetupEndpointRefByURI(driverUrl).flatMap { ref => // This is a very fast action so we can use "ThreadUtils.sameThread" driver = Some(ref) - ref.ask[RegisteredExecutor.type]( + ref.ask[RegisterExecutorResponse]( RegisterExecutor(executorId, self, hostPort, cores, extractLogUrls)) }(ThreadUtils.sameThread).onComplete { // This is a very fast action so we can use "ThreadUtils.sameThread" case Success(msg) => Utils.tryLogNonFatalError { - Option(self).foreach(_.send(msg)) // msg must be RegisteredExecutor + Option(self).foreach(_.send(msg)) // msg must be RegisterExecutorResponse } case Failure(e) => { logError(s"Cannot register with driver: $driverUrl", e) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index e0d25dc50c988..4652df32efa74 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -36,9 +36,13 @@ private[spark] object CoarseGrainedClusterMessages { case class KillTask(taskId: Long, executor: String, interruptThread: Boolean) extends CoarseGrainedClusterMessage + sealed trait RegisterExecutorResponse + case object RegisteredExecutor extends CoarseGrainedClusterMessage + with RegisterExecutorResponse case class RegisterExecutorFailed(message: String) extends CoarseGrainedClusterMessage + with RegisterExecutorResponse // Executors to driver case class RegisterExecutor( diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala index 18f2229fea39b..3cd80c0f7d171 100644 --- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -173,9 +173,9 @@ class HeartbeatReceiverSuite val dummyExecutorEndpoint2 = new FakeExecutorEndpoint(rpcEnv) val dummyExecutorEndpointRef1 = rpcEnv.setupEndpoint("fake-executor-1", dummyExecutorEndpoint1) val dummyExecutorEndpointRef2 = rpcEnv.setupEndpoint("fake-executor-2", dummyExecutorEndpoint2) - fakeSchedulerBackend.driverEndpoint.askWithRetry[RegisteredExecutor.type]( + fakeSchedulerBackend.driverEndpoint.askWithRetry[RegisterExecutorResponse]( RegisterExecutor(executorId1, dummyExecutorEndpointRef1, "dummy:4040", 0, Map.empty)) - fakeSchedulerBackend.driverEndpoint.askWithRetry[RegisteredExecutor.type]( + fakeSchedulerBackend.driverEndpoint.askWithRetry[RegisterExecutorResponse]( RegisterExecutor(executorId2, dummyExecutorEndpointRef2, "dummy:4040", 0, Map.empty)) heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) addExecutorAndVerify(executorId1) From a1413b3662250dd5e980e8b1f7c3dc4585ab4766 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 19 Oct 2015 16:16:31 -0700 Subject: [PATCH 105/139] [SPARK-11051][CORE] Do not allow local checkpointing after the RDD is materialized and checkpointed JIRA: https://issues.apache.org/jira/browse/SPARK-11051 When a `RDD` is materialized and checkpointed, its partitions and dependencies are cleared. If we allow local checkpointing on it and assign `LocalRDDCheckpointData` to its `checkpointData`. Next time when the RDD is materialized again, the error will be thrown. Author: Liang-Chi Hsieh Closes #9072 from viirya/no-localcheckpoint-after-checkpoint. --- .../main/scala/org/apache/spark/rdd/RDD.scala | 35 +++++++++++++++---- .../org/apache/spark/CheckpointSuite.scala | 4 +++ 2 files changed, 32 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index a56e542242d5f..a97bb174438a5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -294,7 +294,11 @@ abstract class RDD[T: ClassTag]( */ private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] = { - if (isCheckpointed) firstParent[T].iterator(split, context) else compute(split, context) + if (isCheckpointedAndMaterialized) { + firstParent[T].iterator(split, context) + } else { + compute(split, context) + } } /** @@ -1520,20 +1524,37 @@ abstract class RDD[T: ClassTag]( persist(LocalRDDCheckpointData.transformStorageLevel(storageLevel), allowOverride = true) } - checkpointData match { - case Some(reliable: ReliableRDDCheckpointData[_]) => logWarning( - "RDD was already marked for reliable checkpointing: overriding with local checkpoint.") - case _ => + // If this RDD is already checkpointed and materialized, its lineage is already truncated. + // We must not override our `checkpointData` in this case because it is needed to recover + // the checkpointed data. If it is overridden, next time materializing on this RDD will + // cause error. + if (isCheckpointedAndMaterialized) { + logWarning("Not marking RDD for local checkpoint because it was already " + + "checkpointed and materialized") + } else { + // Lineage is not truncated yet, so just override any existing checkpoint data with ours + checkpointData match { + case Some(_: ReliableRDDCheckpointData[_]) => logWarning( + "RDD was already marked for reliable checkpointing: overriding with local checkpoint.") + case _ => + } + checkpointData = Some(new LocalRDDCheckpointData(this)) } - checkpointData = Some(new LocalRDDCheckpointData(this)) this } /** - * Return whether this RDD is marked for checkpointing, either reliably or locally. + * Return whether this RDD is checkpointed and materialized, either reliably or locally. */ def isCheckpointed: Boolean = checkpointData.exists(_.isCheckpointed) + /** + * Return whether this RDD is checkpointed and materialized, either reliably or locally. + * This is introduced as an alias for `isCheckpointed` to clarify the semantics of the + * return value. Exposed for testing. + */ + private[spark] def isCheckpointedAndMaterialized: Boolean = isCheckpointed + /** * Return whether this RDD is marked for local checkpointing. * Exposed for testing. diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index 4d70bfed909b6..119e5fc28e412 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -241,9 +241,13 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging val rdd = new BlockRDD[Int](sc, Array[BlockId]()) assert(rdd.partitions.size === 0) assert(rdd.isCheckpointed === false) + assert(rdd.isCheckpointedAndMaterialized === false) checkpoint(rdd, reliableCheckpoint) + assert(rdd.isCheckpointed === false) + assert(rdd.isCheckpointedAndMaterialized === false) assert(rdd.count() === 0) assert(rdd.isCheckpointed === true) + assert(rdd.isCheckpointedAndMaterialized === true) assert(rdd.partitions.size === 0) } From 232d7f8d42950431f1d9be2a6bb3591fb6ea20d6 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 19 Oct 2015 16:18:20 -0700 Subject: [PATCH 106/139] [SPARK-11114][PYSPARK] add getOrCreate for SparkContext/SQLContext in Python Also added SQLContext.newSession() Author: Davies Liu Closes #9122 from davies/py_create. --- python/pyspark/context.py | 16 ++++++++++++++-- python/pyspark/sql/context.py | 27 +++++++++++++++++++++++++++ python/pyspark/sql/tests.py | 14 ++++++++++++++ python/pyspark/tests.py | 4 ++++ 4 files changed, 59 insertions(+), 2 deletions(-) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 4969d85f52b23..afd74d937a413 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -21,7 +21,7 @@ import shutil import signal import sys -from threading import Lock +from threading import RLock from tempfile import NamedTemporaryFile from pyspark import accumulators @@ -65,7 +65,7 @@ class SparkContext(object): _jvm = None _next_accum_id = 0 _active_spark_context = None - _lock = Lock() + _lock = RLock() _python_includes = None # zip and egg files that need to be added to PYTHONPATH PACKAGE_EXTENSIONS = ('.zip', '.egg', '.jar') @@ -280,6 +280,18 @@ def __exit__(self, type, value, trace): """ self.stop() + @classmethod + def getOrCreate(cls, conf=None): + """ + Get or instantiate a SparkContext and register it as a singleton object. + + :param conf: SparkConf (optional) + """ + with SparkContext._lock: + if SparkContext._active_spark_context is None: + SparkContext(conf=conf or SparkConf()) + return SparkContext._active_spark_context + def setLogLevel(self, logLevel): """ Control our logLevel. This overrides any user-defined log settings. diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 89c8c6e0d94f1..79453658a167a 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -75,6 +75,8 @@ class SQLContext(object): SQLContext in the JVM, instead we make all calls to this object. """ + _instantiatedContext = None + @ignore_unicode_prefix def __init__(self, sparkContext, sqlContext=None): """Creates a new SQLContext. @@ -99,6 +101,8 @@ def __init__(self, sparkContext, sqlContext=None): self._scala_SQLContext = sqlContext _monkey_patch_RDD(self) install_exception_handler() + if SQLContext._instantiatedContext is None: + SQLContext._instantiatedContext = self @property def _ssql_ctx(self): @@ -111,6 +115,29 @@ def _ssql_ctx(self): self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc()) return self._scala_SQLContext + @classmethod + @since(1.6) + def getOrCreate(cls, sc): + """ + Get the existing SQLContext or create a new one with given SparkContext. + + :param sc: SparkContext + """ + if cls._instantiatedContext is None: + jsqlContext = sc._jvm.SQLContext.getOrCreate(sc._jsc.sc()) + cls(sc, jsqlContext) + return cls._instantiatedContext + + @since(1.6) + def newSession(self): + """ + Returns a new SQLContext as new session, that has separate SQLConf, + registered temporary tables and UDFs, but shared SparkContext and + table cache. + """ + jsqlContext = self._ssql_ctx.newSession() + return self.__class__(self._sc, jsqlContext) + @since(1.3) def setConf(self, key, value): """Sets the given Spark SQL configuration property. diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 645133b2b2d84..f465e1fa20941 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -174,6 +174,20 @@ def test_datetype_equal_zero(self): self.assertEqual(dt.fromInternal(0), datetime.date(1970, 1, 1)) +class SQLContextTests(ReusedPySparkTestCase): + def test_get_or_create(self): + sqlCtx = SQLContext.getOrCreate(self.sc) + self.assertTrue(SQLContext.getOrCreate(self.sc) is sqlCtx) + + def test_new_session(self): + sqlCtx = SQLContext.getOrCreate(self.sc) + sqlCtx.setConf("test_key", "a") + sqlCtx2 = sqlCtx.newSession() + sqlCtx2.setConf("test_key", "b") + self.assertEqual(sqlCtx.getConf("test_key", ""), "a") + self.assertEqual(sqlCtx2.getConf("test_key", ""), "b") + + class SQLTests(ReusedPySparkTestCase): @classmethod diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 63cc87e0c4b2c..3c51809444401 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -1883,6 +1883,10 @@ def test_failed_sparkcontext_creation(self): # Regression test for SPARK-1550 self.assertRaises(Exception, lambda: SparkContext("an-invalid-master-name")) + def test_get_or_create(self): + with SparkContext.getOrCreate() as sc: + self.assertTrue(SparkContext.getOrCreate() is sc) + def test_stop(self): sc = SparkContext() self.assertNotEqual(SparkContext._active_spark_context, None) From fc26f32cf1bede8b9a1343dca0c0182107c9985e Mon Sep 17 00:00:00 2001 From: Chris Bannister Date: Mon, 19 Oct 2015 16:24:40 -0700 Subject: [PATCH 107/139] [SPARK-9708][MESOS] Spark should create local temporary directories in Mesos sandbox when launched with Mesos This is my own original work and I license this to the project under the project's open source license Author: Chris Bannister Author: Chris Bannister Closes #8358 from Zariel/mesos-local-dir. --- .../scala/org/apache/spark/util/Utils.scala | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 22c05a2479422..55950405f0488 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -649,6 +649,7 @@ private[spark] object Utils extends Logging { * logic of locating the local directories according to deployment mode. */ def getConfiguredLocalDirs(conf: SparkConf): Array[String] = { + val shuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) if (isRunningInYarnContainer(conf)) { // If we are in yarn mode, systems can have different disk layouts so we must set it // to what Yarn on this system said was available. Note this assumes that Yarn has @@ -657,13 +658,23 @@ private[spark] object Utils extends Logging { getYarnLocalDirs(conf).split(",") } else if (conf.getenv("SPARK_EXECUTOR_DIRS") != null) { conf.getenv("SPARK_EXECUTOR_DIRS").split(File.pathSeparator) + } else if (conf.getenv("SPARK_LOCAL_DIRS") != null) { + conf.getenv("SPARK_LOCAL_DIRS").split(",") + } else if (conf.getenv("MESOS_DIRECTORY") != null && !shuffleServiceEnabled) { + // Mesos already creates a directory per Mesos task. Spark should use that directory + // instead so all temporary files are automatically cleaned up when the Mesos task ends. + // Note that we don't want this if the shuffle service is enabled because we want to + // continue to serve shuffle files after the executors that wrote them have already exited. + Array(conf.getenv("MESOS_DIRECTORY")) } else { + if (conf.getenv("MESOS_DIRECTORY") != null && shuffleServiceEnabled) { + logInfo("MESOS_DIRECTORY available but not using provided Mesos sandbox because " + + "spark.shuffle.service.enabled is enabled.") + } // In non-Yarn mode (or for the driver in yarn-client mode), we cannot trust the user // configuration to point to a secure directory. So create a subdirectory with restricted // permissions under each listed directory. - Option(conf.getenv("SPARK_LOCAL_DIRS")) - .getOrElse(conf.get("spark.local.dir", System.getProperty("java.io.tmpdir"))) - .split(",") + conf.get("spark.local.dir", System.getProperty("java.io.tmpdir")).split(",") } } From 16906ef23a7aa2854c8cdcaa3bb3808ab39e0eec Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Mon, 19 Oct 2015 16:34:15 -0700 Subject: [PATCH 108/139] [SPARK-11120] Allow sane default number of executor failures when dynamically allocating in YARN I also added some information to container-failure error msgs about what host they failed on, which would have helped me identify the problem that lead me to this JIRA and PR sooner. Author: Ryan Williams Closes #9147 from ryan-williams/dyn-exec-failures. --- .../scala/org/apache/spark/SparkConf.scala | 4 +++- .../spark/deploy/yarn/ApplicationMaster.scala | 19 +++++++++++++++---- .../spark/deploy/yarn/YarnAllocator.scala | 19 +++++++++++-------- 3 files changed, 29 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 1a0ac3d01759c..58d3b846fd80d 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -595,7 +595,9 @@ private[spark] object SparkConf extends Logging { "spark.rpc.lookupTimeout" -> Seq( AlternateConfig("spark.akka.lookupTimeout", "1.4")), "spark.streaming.fileStream.minRememberDuration" -> Seq( - AlternateConfig("spark.streaming.minRememberDuration", "1.5")) + AlternateConfig("spark.streaming.minRememberDuration", "1.5")), + "spark.yarn.max.executor.failures" -> Seq( + AlternateConfig("spark.yarn.max.worker.failures", "1.5")) ) /** diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index d1d248bf79beb..4b4d9990ce9f9 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -62,10 +62,21 @@ private[spark] class ApplicationMaster( .asInstanceOf[YarnConfiguration] private val isClusterMode = args.userClass != null - // Default to numExecutors * 2, with minimum of 3 - private val maxNumExecutorFailures = sparkConf.getInt("spark.yarn.max.executor.failures", - sparkConf.getInt("spark.yarn.max.worker.failures", - math.max(sparkConf.getInt("spark.executor.instances", 0) * 2, 3))) + // Default to twice the number of executors (twice the maximum number of executors if dynamic + // allocation is enabled), with a minimum of 3. + + private val maxNumExecutorFailures = { + val defaultKey = + if (Utils.isDynamicAllocationEnabled(sparkConf)) { + "spark.dynamicAllocation.maxExecutors" + } else { + "spark.executor.instances" + } + val effectiveNumExecutors = sparkConf.getInt(defaultKey, 0) + val defaultMaxNumExecutorFailures = math.max(3, 2 * effectiveNumExecutors) + + sparkConf.getInt("spark.yarn.max.executor.failures", defaultMaxNumExecutorFailures) + } @volatile private var exitCode = 0 @volatile private var unregistered = false diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 9e1ef1b3b4229..1deaa3743ddfa 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -430,17 +430,20 @@ private[yarn] class YarnAllocator( for (completedContainer <- completedContainers) { val containerId = completedContainer.getContainerId val alreadyReleased = releasedContainers.remove(containerId) + val hostOpt = allocatedContainerToHostMap.get(containerId) + val onHostStr = hostOpt.map(host => s" on host: $host").getOrElse("") val exitReason = if (!alreadyReleased) { // Decrement the number of executors running. The next iteration of // the ApplicationMaster's reporting thread will take care of allocating. numExecutorsRunning -= 1 - logInfo("Completed container %s (state: %s, exit status: %s)".format( + logInfo("Completed container %s%s (state: %s, exit status: %s)".format( containerId, + onHostStr, completedContainer.getState, completedContainer.getExitStatus)) // Hadoop 2.2.X added a ContainerExitStatus we should switch to use // there are some exit status' we shouldn't necessarily count against us, but for - // now I think its ok as none of the containers are expected to exit + // now I think its ok as none of the containers are expected to exit. val exitStatus = completedContainer.getExitStatus val (isNormalExit, containerExitReason) = exitStatus match { case ContainerExitStatus.SUCCESS => @@ -449,7 +452,7 @@ private[yarn] class YarnAllocator( // Preemption should count as a normal exit, since YARN preempts containers merely // to do resource sharing, and tasks that fail due to preempted executors could // just as easily finish on any other executor. See SPARK-8167. - (true, s"Container $containerId was preempted.") + (true, s"Container ${containerId}${onHostStr} was preempted.") // Should probably still count memory exceeded exit codes towards task failures case VMEM_EXCEEDED_EXIT_CODE => (false, memLimitExceededLogMessage( @@ -461,7 +464,7 @@ private[yarn] class YarnAllocator( PMEM_EXCEEDED_PATTERN)) case unknown => numExecutorsFailed += 1 - (false, "Container marked as failed: " + containerId + + (false, "Container marked as failed: " + containerId + onHostStr + ". Exit status: " + completedContainer.getExitStatus + ". Diagnostics: " + completedContainer.getDiagnostics) @@ -479,10 +482,10 @@ private[yarn] class YarnAllocator( s"Container $containerId exited from explicit termination request.") } - if (allocatedContainerToHostMap.contains(containerId)) { - val host = allocatedContainerToHostMap.get(containerId).get - val containerSet = allocatedHostToContainersMap.get(host).get - + for { + host <- hostOpt + containerSet <- allocatedHostToContainersMap.get(host) + } { containerSet.remove(containerId) if (containerSet.isEmpty) { allocatedHostToContainersMap.remove(host) From 8b877cc4ee46cad9d1f7cac451801c1410f6c1fe Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Mon, 19 Oct 2015 16:57:20 -0700 Subject: [PATCH 109/139] [SPARK-11088][SQL] Merges partition values using UnsafeProjection `DataSourceStrategy.mergeWithPartitionValues` is essentially a projection implemented in a quite inefficient way. This PR optimizes this method with `UnsafeProjection` to avoid unnecessary boxing costs. Author: Cheng Lian Closes #9104 from liancheng/spark-11088.faster-partition-values-merging. --- .../datasources/DataSourceStrategy.scala | 73 ++++++------------- 1 file changed, 24 insertions(+), 49 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 33181fa6c065f..ffb4645b89321 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -140,29 +140,30 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { val sharedHadoopConf = SparkHadoopUtil.get.conf val confBroadcast = relation.sqlContext.sparkContext.broadcast(new SerializableConfiguration(sharedHadoopConf)) + val partitionColumnNames = partitionColumns.fieldNames.toSet // Now, we create a scan builder, which will be used by pruneFilterProject. This scan builder // will union all partitions and attach partition values if needed. val scanBuilder = { - (columns: Seq[Attribute], filters: Array[Filter]) => { + (requiredColumns: Seq[Attribute], filters: Array[Filter]) => { + val requiredDataColumns = + requiredColumns.filterNot(c => partitionColumnNames.contains(c.name)) + // Builds RDD[Row]s for each selected partition. val perPartitionRows = partitions.map { case Partition(partitionValues, dir) => - val partitionColNames = partitionColumns.fieldNames - // Don't scan any partition columns to save I/O. Here we are being optimistic and // assuming partition columns data stored in data files are always consistent with those // partition values encoded in partition directory paths. - val needed = columns.filterNot(a => partitionColNames.contains(a.name)) - val dataRows = - relation.buildScan(needed.map(_.name).toArray, filters, Array(dir), confBroadcast) + val dataRows = relation.buildScan( + requiredDataColumns.map(_.name).toArray, filters, Array(dir), confBroadcast) // Merges data values with partition values. mergeWithPartitionValues( - relation.schema, - columns.map(_.name).toArray, - partitionColNames, + requiredColumns, + requiredDataColumns, + partitionColumns, partitionValues, - toCatalystRDD(logicalRelation, needed, dataRows)) + toCatalystRDD(logicalRelation, requiredDataColumns, dataRows)) } val unionedRows = @@ -188,52 +189,27 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { sparkPlan } - // TODO: refactor this thing. It is very complicated because it does projection internally. - // We should just put a project on top of this. private def mergeWithPartitionValues( - schema: StructType, - requiredColumns: Array[String], - partitionColumns: Array[String], + requiredColumns: Seq[Attribute], + dataColumns: Seq[Attribute], + partitionColumnSchema: StructType, partitionValues: InternalRow, dataRows: RDD[InternalRow]): RDD[InternalRow] = { - val nonPartitionColumns = requiredColumns.filterNot(partitionColumns.contains) - // If output columns contain any partition column(s), we need to merge scanned data // columns and requested partition columns to form the final result. - if (!requiredColumns.sameElements(nonPartitionColumns)) { - val mergers = requiredColumns.zipWithIndex.map { case (name, index) => - // To see whether the `index`-th column is a partition column... - val i = partitionColumns.indexOf(name) - if (i != -1) { - val dt = schema(partitionColumns(i)).dataType - // If yes, gets column value from partition values. - (mutableRow: MutableRow, dataRow: InternalRow, ordinal: Int) => { - mutableRow(ordinal) = partitionValues.get(i, dt) - } - } else { - // Otherwise, inherits the value from scanned data. - val i = nonPartitionColumns.indexOf(name) - val dt = schema(nonPartitionColumns(i)).dataType - (mutableRow: MutableRow, dataRow: InternalRow, ordinal: Int) => { - mutableRow(ordinal) = dataRow.get(i, dt) - } - } + if (requiredColumns != dataColumns) { + // Builds `AttributeReference`s for all partition columns so that we can use them to project + // required partition columns. Note that if a partition column appears in `requiredColumns`, + // we should use the `AttributeReference` in `requiredColumns`. + val requiredColumnMap = requiredColumns.map(a => a.name -> a).toMap + val partitionColumns = partitionColumnSchema.toAttributes.map { a => + requiredColumnMap.getOrElse(a.name, a) } - // Since we know for sure that this closure is serializable, we can avoid the overhead - // of cleaning a closure for each RDD by creating our own MapPartitionsRDD. Functionally - // this is equivalent to calling `dataRows.mapPartitions(mapPartitionsFunc)` (SPARK-7718). val mapPartitionsFunc = (_: TaskContext, _: Int, iterator: Iterator[InternalRow]) => { - val dataTypes = requiredColumns.map(schema(_).dataType) - val mutableRow = new SpecificMutableRow(dataTypes) - iterator.map { dataRow => - var i = 0 - while (i < mutableRow.numFields) { - mergers(i)(mutableRow, dataRow, i) - i += 1 - } - mutableRow.asInstanceOf[InternalRow] - } + val projection = UnsafeProjection.create(requiredColumns, dataColumns ++ partitionColumns) + val mutableJoinedRow = new JoinedRow() + iterator.map(dataRow => projection(mutableJoinedRow(dataRow, partitionValues))) } // This is an internal RDD whose call site the user should not be concerned with @@ -242,7 +218,6 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { Utils.withDummyCallSite(dataRows.sparkContext) { new MapPartitionsRDD(dataRows, mapPartitionsFunc, preservesPartitioning = false) } - } else { dataRows } From 8f74aa639759f400120794355511327fa74905da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jean-Baptiste=20Onofr=C3=A9?= Date: Tue, 20 Oct 2015 08:45:39 +0100 Subject: [PATCH 110/139] [SPARK-10876] Display total uptime for completed applications MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Author: Jean-Baptiste Onofré Closes #9059 from jbonofre/SPARK-10876. --- .../org/apache/spark/ui/jobs/AllJobsPage.scala | 18 +++++++++++------- .../spark/ui/jobs/JobProgressListener.scala | 7 ++++++- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index 041cd55ea483b..d467dd9e1f29d 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -265,6 +265,7 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { val listener = parent.jobProgresslistener listener.synchronized { val startTime = listener.startTime + val endTime = listener.endTime val activeJobs = listener.activeJobs.values.toSeq val completedJobs = listener.completedJobs.reverse.toSeq val failedJobs = listener.failedJobs.reverse.toSeq @@ -289,13 +290,16 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { val summary: NodeSeq =