From b8634df1f1eb6ce909bec779522c9c9912c7d06a Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Fri, 12 Sep 2014 01:37:59 -0700 Subject: [PATCH 01/36] [SPARK-3160] [SPARK-3494] [mllib] DecisionTree: eliminate pre-allocated nodes, parentImpurities arrays. Memory calc bug fix. This PR includes some code simplifications and re-organization which will be helpful for implementing random forests. The main changes are that the nodes and parentImpurities arrays are no longer pre-allocated in the main train() method. Also added 2 bug fixes: * maxMemoryUsage calculation * over-allocation of space for bins in DTStatsAggregator for unordered features. Relation to RFs: * Since RFs will be deeper and will therefore be more likely sparse (not full trees), it could be a cost savings to avoid pre-allocating a full tree. * The associated re-organization also reduces bookkeeping, which will make RFs easier to implement. * The return code doneTraining may be generalized to include cases such as nodes ready for local training. Details: No longer pre-allocate parentImpurities array in main train() method. * parentImpurities values are now stored in individual nodes (in Node.stats.impurity). * These were not really needed. They were used in calculateGainForSplit(), but they can be calculated anyways using parentNodeAgg. No longer using Node.build since tree structure is constructed on-the-fly. * Did not eliminate since it is public (Developer) API. Marked as deprecated. Eliminated pre-allocated nodes array in main train() method. * Nodes are constructed and added to the tree structure as needed during training. * Moved tree construction from main train() method into findBestSplitsPerGroup() since there is no need to keep the (split, gain) array for an entire level of nodes. Only one element of that array is needed at a time, so we do not the array. findBestSplits() now returns 2 items: * rootNode (newly created root node on first iteration, same root node on later iterations) * doneTraining (indicating if all nodes at that level were leafs) Updated DecisionTreeSuite. Notes: * Improved test "Second level node building with vs. without groups" ** generateOrderedLabeledPoints() modified so that it really does require 2 levels of internal nodes. * Related update: Added Node.deepCopy (private[tree]), used for test suite CC: mengxr Author: Joseph K. Bradley Closes #2341 from jkbradley/dt-spark-3160 and squashes the following commits: 07dd1ee [Joseph K. Bradley] Fixed overflow bug with computing maxMemoryUsage in DecisionTree. Also fixed bug with over-allocating space in DTStatsAggregator for unordered features. debe072 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-spark-3160 5c4ac33 [Joseph K. Bradley] Added check in Strategy to make sure minInstancesPerNode >= 1 0dd4d87 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-spark-3160 306120f [Joseph K. Bradley] Fixed typo in DecisionTreeModel.scala doc eaa1dcf [Joseph K. Bradley] Added topNode doc in DecisionTree and scalastyle fix d4d7864 [Joseph K. Bradley] Marked Node.build as deprecated d4dbb99 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-spark-3160 1a8f0ad [Joseph K. Bradley] Eliminated pre-allocated nodes array in main train() method. * Nodes are constructed and added to the tree structure as needed during training. 2ab763b [Joseph K. Bradley] Simplifications to DecisionTree code: --- .../spark/mllib/tree/DecisionTree.scala | 191 +++++------- .../mllib/tree/configuration/Strategy.scala | 3 + .../mllib/tree/impl/DTStatsAggregator.scala | 11 +- .../tree/impl/DecisionTreeMetadata.scala | 3 +- .../mllib/tree/model/DecisionTreeModel.scala | 2 +- .../apache/spark/mllib/tree/model/Node.scala | 37 +++ .../spark/mllib/tree/DecisionTreeSuite.scala | 277 +++++++++--------- 7 files changed, 268 insertions(+), 256 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 98596569b8c95..56bb8812100a7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -87,17 +87,11 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo val maxDepth = strategy.maxDepth require(maxDepth <= 30, s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.") - // Number of nodes to allocate: max number of nodes possible given the depth of the tree, plus 1 - val maxNumNodesPlus1 = Node.startIndexInLevel(maxDepth + 1) - // Initialize an array to hold parent impurity calculations for each node. - val parentImpurities = new Array[Double](maxNumNodesPlus1) - // dummy value for top node (updated during first split calculation) - val nodes = new Array[Node](maxNumNodesPlus1) // Calculate level for single group construction // Max memory usage for aggregates - val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024 + val maxMemoryUsage = strategy.maxMemoryInMB * 1024L * 1024L logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.") // TODO: Calculate memory usage more precisely. val numElementsPerNode = DecisionTree.getElementsPerNode(metadata) @@ -120,81 +114,35 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo * beforehand and is not used in later levels. */ + var topNode: Node = null // set on first iteration var level = 0 var break = false while (level <= maxDepth && !break) { - logDebug("#####################################") logDebug("level = " + level) logDebug("#####################################") // Find best split for all nodes at a level. timer.start("findBestSplits") - val splitsStatsForLevel: Array[(Split, InformationGainStats, Predict)] = - DecisionTree.findBestSplits(treeInput, parentImpurities, - metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer) + val (tmpTopNode: Node, doneTraining: Boolean) = DecisionTree.findBestSplits(treeInput, + metadata, level, topNode, splits, bins, maxLevelForSingleGroup, timer) timer.stop("findBestSplits") - val levelNodeIndexOffset = Node.startIndexInLevel(level) - for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { - val nodeIndex = levelNodeIndexOffset + index - - // Extract info for this node (index) at the current level. - timer.start("extractNodeInfo") - val split = nodeSplitStats._1 - val stats = nodeSplitStats._2 - val predict = nodeSplitStats._3.predict - val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth) - val node = new Node(nodeIndex, predict, isLeaf, Some(split), None, None, Some(stats)) - logDebug("Node = " + node) - nodes(nodeIndex) = node - timer.stop("extractNodeInfo") - - if (level != 0) { - // Set parent. - val parentNodeIndex = Node.parentIndex(nodeIndex) - if (Node.isLeftChild(nodeIndex)) { - nodes(parentNodeIndex).leftNode = Some(nodes(nodeIndex)) - } else { - nodes(parentNodeIndex).rightNode = Some(nodes(nodeIndex)) - } - } - // Extract info for nodes at the next lower level. - timer.start("extractInfoForLowerLevels") - if (level < maxDepth) { - val leftChildIndex = Node.leftChildIndex(nodeIndex) - val leftImpurity = stats.leftImpurity - logDebug("leftChildIndex = " + leftChildIndex + ", impurity = " + leftImpurity) - parentImpurities(leftChildIndex) = leftImpurity - - val rightChildIndex = Node.rightChildIndex(nodeIndex) - val rightImpurity = stats.rightImpurity - logDebug("rightChildIndex = " + rightChildIndex + ", impurity = " + rightImpurity) - parentImpurities(rightChildIndex) = rightImpurity - } - timer.stop("extractInfoForLowerLevels") - logDebug("final best split = " + split) + if (level == 0) { + topNode = tmpTopNode } - require(Node.maxNodesInLevel(level) == splitsStatsForLevel.length) - // Check whether all the nodes at the current level at leaves. - val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0) - logDebug("all leaf = " + allLeaf) - if (allLeaf) { - break = true // no more tree construction - } else { - level += 1 + if (doneTraining) { + break = true + logDebug("done training") } + + level += 1 } logDebug("#####################################") logDebug("Extracting tree model") logDebug("#####################################") - // Initialize the top or root node of the tree. - val topNode = nodes(1) - // Build the full tree using the node info calculated in the level-wise best split calculations. - topNode.build(nodes) - timer.stop("total") logInfo("Internal timing for DecisionTree:") @@ -409,24 +357,26 @@ object DecisionTree extends Serializable with Logging { * multiple groups if the level-wise training task could lead to memory overflow. * * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]] - * @param parentImpurities Impurities for all parent nodes for the current level * @param metadata Learning and dataset metadata * @param level Level of the tree + * @param topNode Root node of the tree (or invalid node when training first level). * @param splits possible splits for all features, indexed (numFeatures)(numSplits) * @param bins possible bins for all features, indexed (numFeatures)(numBins) * @param maxLevelForSingleGroup the deepest level for single-group level-wise computation. - * @return array (over nodes) of splits with best split for each node at a given level. + * @return (root, doneTraining) where: + * root = Root node (which is newly created on the first iteration), + * doneTraining = true if no more internal nodes were created. */ private[tree] def findBestSplits( input: RDD[TreePoint], - parentImpurities: Array[Double], metadata: DecisionTreeMetadata, level: Int, - nodes: Array[Node], + topNode: Node, splits: Array[Array[Split]], bins: Array[Array[Bin]], maxLevelForSingleGroup: Int, - timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats, Predict)] = { + timer: TimeTracker = new TimeTracker): (Node, Boolean) = { + // split into groups to avoid memory overflow during aggregation if (level > maxLevelForSingleGroup) { // When information for all nodes at a given level cannot be stored in memory, @@ -435,18 +385,18 @@ object DecisionTree extends Serializable with Logging { // numGroups is equal to 2 at level 11 and 4 at level 12, respectively. val numGroups = 1 << level - maxLevelForSingleGroup logDebug("numGroups = " + numGroups) - var bestSplits = new Array[(Split, InformationGainStats, Predict)](0) // Iterate over each group of nodes at a level. var groupIndex = 0 + var doneTraining = true while (groupIndex < numGroups) { - val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, metadata, level, - nodes, splits, bins, timer, numGroups, groupIndex) - bestSplits = Array.concat(bestSplits, bestSplitsForGroup) + val (tmpRoot, doneTrainingGroup) = findBestSplitsPerGroup(input, metadata, level, + topNode, splits, bins, timer, numGroups, groupIndex) + doneTraining = doneTraining && doneTrainingGroup groupIndex += 1 } - bestSplits + (topNode, doneTraining) // Not first iteration, so topNode was already set. } else { - findBestSplitsPerGroup(input, parentImpurities, metadata, level, nodes, splits, bins, timer) + findBestSplitsPerGroup(input, metadata, level, topNode, splits, bins, timer) } } @@ -586,27 +536,27 @@ object DecisionTree extends Serializable with Logging { * Returns an array of optimal splits for a group of nodes at a given level * * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]] - * @param parentImpurities Impurities for all parent nodes for the current level * @param metadata Learning and dataset metadata * @param level Level of the tree - * @param nodes Array of all nodes in the tree. Used for matching data points to nodes. + * @param topNode Root node of the tree (or invalid node when training first level). * @param splits possible splits for all features, indexed (numFeatures)(numSplits) * @param bins possible bins for all features, indexed (numFeatures)(numBins) * @param numGroups total number of node groups at the current level. Default value is set to 1. * @param groupIndex index of the node group being processed. Default value is set to 0. - * @return array of splits with best splits for all nodes at a given level. + * @return (root, doneTraining) where: + * root = Root node (which is newly created on the first iteration), + * doneTraining = true if no more internal nodes were created. */ private def findBestSplitsPerGroup( input: RDD[TreePoint], - parentImpurities: Array[Double], metadata: DecisionTreeMetadata, level: Int, - nodes: Array[Node], + topNode: Node, splits: Array[Array[Split]], bins: Array[Array[Bin]], timer: TimeTracker, numGroups: Int = 1, - groupIndex: Int = 0): Array[(Split, InformationGainStats, Predict)] = { + groupIndex: Int = 0): (Node, Boolean) = { /* * The high-level descriptions of the best split optimizations are noted here. @@ -663,7 +613,7 @@ object DecisionTree extends Serializable with Logging { 0 } else { val globalNodeIndex = - predictNodeIndex(nodes(1), treePoint.binnedFeatures, bins, metadata.unorderedFeatures) + predictNodeIndex(topNode, treePoint.binnedFeatures, bins, metadata.unorderedFeatures) globalNodeIndex - globalNodeIndexOffset } } @@ -706,33 +656,63 @@ object DecisionTree extends Serializable with Logging { // Calculate best splits for all nodes at a given level timer.start("chooseSplits") - val bestSplits = new Array[(Split, InformationGainStats, Predict)](numNodes) - // Iterating over all nodes at this level + // On the first iteration, we need to get and return the newly created root node. + var newTopNode: Node = topNode + + // Iterate over all nodes at this level var nodeIndex = 0 + var internalNodeCount = 0 while (nodeIndex < numNodes) { - val nodeImpurity = parentImpurities(globalNodeIndexOffset + nodeIndex) - logDebug("node impurity = " + nodeImpurity) - bestSplits(nodeIndex) = - binsToBestSplit(binAggregates, nodeIndex, nodeImpurity, level, metadata, splits) - logDebug("best split = " + bestSplits(nodeIndex)._1) + val (split: Split, stats: InformationGainStats, predict: Predict) = + binsToBestSplit(binAggregates, nodeIndex, level, metadata, splits) + logDebug("best split = " + split) + + val globalNodeIndex = globalNodeIndexOffset + nodeIndex + + // Extract info for this node at the current level. + val isLeaf = (stats.gain <= 0) || (level == metadata.maxDepth) + val node = + new Node(globalNodeIndex, predict.predict, isLeaf, Some(split), None, None, Some(stats)) + logDebug("Node = " + node) + + if (!isLeaf) { + internalNodeCount += 1 + } + if (level == 0) { + newTopNode = node + } else { + // Set parent. + val parentNode = Node.getNode(Node.parentIndex(globalNodeIndex), topNode) + if (Node.isLeftChild(globalNodeIndex)) { + parentNode.leftNode = Some(node) + } else { + parentNode.rightNode = Some(node) + } + } + if (level < metadata.maxDepth) { + logDebug("leftChildIndex = " + Node.leftChildIndex(globalNodeIndex) + + ", impurity = " + stats.leftImpurity) + logDebug("rightChildIndex = " + Node.rightChildIndex(globalNodeIndex) + + ", impurity = " + stats.rightImpurity) + } + nodeIndex += 1 } timer.stop("chooseSplits") - bestSplits + val doneTraining = internalNodeCount == 0 + (newTopNode, doneTraining) } /** * Calculate the information gain for a given (feature, split) based upon left/right aggregates. * @param leftImpurityCalculator left node aggregates for this (feature, split) * @param rightImpurityCalculator right node aggregate for this (feature, split) - * @param topImpurity impurity of the parent node * @return information gain and statistics for all splits */ private def calculateGainForSplit( leftImpurityCalculator: ImpurityCalculator, rightImpurityCalculator: ImpurityCalculator, - topImpurity: Double, level: Int, metadata: DecisionTreeMetadata): InformationGainStats = { val leftCount = leftImpurityCalculator.count @@ -747,14 +727,10 @@ object DecisionTree extends Serializable with Logging { val totalCount = leftCount + rightCount - // impurity of parent node - val impurity = if (level > 0) { - topImpurity - } else { - val parentNodeAgg = leftImpurityCalculator.copy - parentNodeAgg.add(rightImpurityCalculator) - parentNodeAgg.calculate() - } + val parentNodeAgg = leftImpurityCalculator.copy + parentNodeAgg.add(rightImpurityCalculator) + + val impurity = parentNodeAgg.calculate() val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0 val rightImpurity = rightImpurityCalculator.calculate() @@ -795,19 +771,15 @@ object DecisionTree extends Serializable with Logging { * Find the best split for a node. * @param binAggregates Bin statistics. * @param nodeIndex Index for node to split in this (level, group). - * @param nodeImpurity Impurity of the node (nodeIndex). * @return tuple for best split: (Split, information gain) */ private def binsToBestSplit( binAggregates: DTStatsAggregator, nodeIndex: Int, - nodeImpurity: Double, level: Int, metadata: DecisionTreeMetadata, splits: Array[Array[Split]]): (Split, InformationGainStats, Predict) = { - logDebug("node impurity = " + nodeImpurity) - // calculate predict only once var predict: Option[Predict] = None @@ -831,8 +803,7 @@ object DecisionTree extends Serializable with Logging { val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) rightChildStats.subtract(leftChildStats) predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) - val gainStats = - calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata) + val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, level, metadata) (splitIdx, gainStats) }.maxBy(_._2.gain) (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) @@ -845,8 +816,7 @@ object DecisionTree extends Serializable with Logging { val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex) predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) - val gainStats = - calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata) + val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, level, metadata) (splitIndex, gainStats) }.maxBy(_._2.gain) (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) @@ -917,8 +887,7 @@ object DecisionTree extends Serializable with Logging { binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) rightChildStats.subtract(leftChildStats) predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) - val gainStats = - calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata) + val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, level, metadata) (splitIndex, gainStats) }.maxBy(_._2.gain) val categoriesForSplit = @@ -937,8 +906,8 @@ object DecisionTree extends Serializable with Logging { /** * Get the number of values to be stored per node in the bin aggregates. */ - private def getElementsPerNode(metadata: DecisionTreeMetadata): Int = { - val totalBins = metadata.numBins.sum + private def getElementsPerNode(metadata: DecisionTreeMetadata): Long = { + val totalBins = metadata.numBins.map(_.toLong).sum if (metadata.isClassification) { metadata.numClasses * totalBins } else { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 987fe632c91ed..31d1e8ac30eea 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -75,6 +75,9 @@ class Strategy ( if (algo == Classification) { require(numClassesForClassification >= 2) } + require(minInstancesPerNode >= 1, + s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode") + val isMulticlassClassification = algo == Classification && numClassesForClassification > 2 val isMulticlassWithCategoricalFeatures diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala index 866d85a79bea1..61a94246711bf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala @@ -65,14 +65,7 @@ private[tree] class DTStatsAggregator( * Offset for each feature for calculating indices into the [[allStats]] array. */ private val featureOffsets: Array[Int] = { - def featureOffsetsCalc(total: Int, featureIndex: Int): Int = { - if (isUnordered(featureIndex)) { - total + 2 * numBins(featureIndex) - } else { - total + numBins(featureIndex) - } - } - Range(0, numFeatures).scanLeft(0)(featureOffsetsCalc).map(statsSize * _).toArray + numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins) } /** @@ -149,7 +142,7 @@ private[tree] class DTStatsAggregator( s"DTStatsAggregator.getLeftRightNodeFeatureOffsets is for unordered features only," + s" but was called for ordered feature $featureIndex.") val baseOffset = nodeIndex * nodeStride + featureOffsets(featureIndex) - (baseOffset, baseOffset + numBins(featureIndex) * statsSize) + (baseOffset, baseOffset + (numBins(featureIndex) >> 1) * statsSize) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala index 5ceaa8154d11a..b6d49e5555b1a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -46,6 +46,7 @@ private[tree] class DecisionTreeMetadata( val numBins: Array[Int], val impurity: Impurity, val quantileStrategy: QuantileStrategy, + val maxDepth: Int, val minInstancesPerNode: Int, val minInfoGain: Double) extends Serializable { @@ -129,7 +130,7 @@ private[tree] object DecisionTreeMetadata { new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max, strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins, - strategy.impurity, strategy.quantileCalculationStrategy, + strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth, strategy.minInstancesPerNode, strategy.minInfoGain) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index 0594fd0749d21..271b2c4ad813e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -46,7 +46,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable * Predict values for the given data set using the model trained. * * @param features RDD representing data points to be predicted - * @return RDD[Int] where each entry contains the corresponding prediction + * @return RDD of predictions for each of the given data points */ def predict(features: RDD[Vector]): RDD[Double] = { features.map(x => predict(x)) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 5b8a4cbed2306..5f0095d23c7ed 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -55,6 +55,8 @@ class Node ( * build the left node and right nodes if not leaf * @param nodes array of nodes */ + @deprecated("build should no longer be used since trees are constructed on-the-fly in training", + "1.2.0") def build(nodes: Array[Node]): Unit = { logDebug("building node " + id + " at level " + Node.indexToLevel(id)) logDebug("id = " + id + ", split = " + split) @@ -93,6 +95,23 @@ class Node ( } } + /** + * Returns a deep copy of the subtree rooted at this node. + */ + private[tree] def deepCopy(): Node = { + val leftNodeCopy = if (leftNode.isEmpty) { + None + } else { + Some(leftNode.get.deepCopy()) + } + val rightNodeCopy = if (rightNode.isEmpty) { + None + } else { + Some(rightNode.get.deepCopy()) + } + new Node(id, predict, isLeaf, split, leftNodeCopy, rightNodeCopy, stats) + } + /** * Get the number of nodes in tree below this node, including leaf nodes. * E.g., if this is a leaf, returns 0. If both children are leaves, returns 2. @@ -190,4 +209,22 @@ private[tree] object Node { */ def startIndexInLevel(level: Int): Int = 1 << level + /** + * Traces down from a root node to get the node with the given node index. + * This assumes the node exists. + */ + def getNode(nodeIndex: Int, rootNode: Node): Node = { + var tmpNode: Node = rootNode + var levelsToGo = indexToLevel(nodeIndex) + while (levelsToGo > 0) { + if ((nodeIndex & (1 << levelsToGo - 1)) == 0) { + tmpNode = tmpNode.leftNode.get + } else { + tmpNode = tmpNode.rightNode.get + } + levelsToGo -= 1 + } + tmpNode + } + } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index fd8547c1660fc..1bd7ea05c46c8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -270,19 +270,17 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(0).length === 0) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0, - new Array[Node](0), splits, bins, 10) + val (rootNode: Node, doneTraining: Boolean) = + DecisionTree.findBestSplits(treeInput, metadata, 0, null, splits, bins, 10) - val split = bestSplits(0)._1 + val split = rootNode.split.get assert(split.categories === List(1.0)) assert(split.featureType === Categorical) assert(split.threshold === Double.MinValue) - val stats = bestSplits(0)._2 - val predict = bestSplits(0)._3 + val stats = rootNode.stats.get assert(stats.gain > 0) - assert(predict.predict === 1) - assert(predict.prob === 0.6) + assert(rootNode.predict === 1) assert(stats.impurity > 0.2) } @@ -303,19 +301,18 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0, - new Array[Node](0), splits, bins, 10) + val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, + null, splits, bins, 10) - val split = bestSplits(0)._1 + val split = rootNode.split.get assert(split.categories.length === 1) assert(split.categories.contains(1.0)) assert(split.featureType === Categorical) assert(split.threshold === Double.MinValue) - val stats = bestSplits(0)._2 - val predict = bestSplits(0)._3.predict + val stats = rootNode.stats.get assert(stats.gain > 0) - assert(predict === 0.6) + assert(rootNode.predict === 0.6) assert(stats.impurity > 0.2) } @@ -356,13 +353,16 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(0).length === 100) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0, - new Array[Node](0), splits, bins, 10) - assert(bestSplits.length === 1) - assert(bestSplits(0)._1.feature === 0) - assert(bestSplits(0)._2.gain === 0) - assert(bestSplits(0)._2.leftImpurity === 0) - assert(bestSplits(0)._2.rightImpurity === 0) + val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, + null, splits, bins, 10) + + val split = rootNode.split.get + assert(split.feature === 0) + + val stats = rootNode.stats.get + assert(stats.gain === 0) + assert(stats.leftImpurity === 0) + assert(stats.rightImpurity === 0) } test("Binary classification stump with fixed label 1 for Gini") { @@ -382,14 +382,17 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(0).length === 100) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0, - new Array[Node](0), splits, bins, 10) - assert(bestSplits.length === 1) - assert(bestSplits(0)._1.feature === 0) - assert(bestSplits(0)._2.gain === 0) - assert(bestSplits(0)._2.leftImpurity === 0) - assert(bestSplits(0)._2.rightImpurity === 0) - assert(bestSplits(0)._3.predict === 1) + val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, + null, splits, bins, 10) + + val split = rootNode.split.get + assert(split.feature === 0) + + val stats = rootNode.stats.get + assert(stats.gain === 0) + assert(stats.leftImpurity === 0) + assert(stats.rightImpurity === 0) + assert(rootNode.predict === 1) } test("Binary classification stump with fixed label 0 for Entropy") { @@ -409,14 +412,17 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(0).length === 100) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0, - new Array[Node](0), splits, bins, 10) - assert(bestSplits.length === 1) - assert(bestSplits(0)._1.feature === 0) - assert(bestSplits(0)._2.gain === 0) - assert(bestSplits(0)._2.leftImpurity === 0) - assert(bestSplits(0)._2.rightImpurity === 0) - assert(bestSplits(0)._3.predict === 0) + val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, + null, splits, bins, 10) + + val split = rootNode.split.get + assert(split.feature === 0) + + val stats = rootNode.stats.get + assert(stats.gain === 0) + assert(stats.leftImpurity === 0) + assert(stats.rightImpurity === 0) + assert(rootNode.predict === 0) } test("Binary classification stump with fixed label 1 for Entropy") { @@ -436,14 +442,17 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(0).length === 100) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0, - new Array[Node](0), splits, bins, 10) - assert(bestSplits.length === 1) - assert(bestSplits(0)._1.feature === 0) - assert(bestSplits(0)._2.gain === 0) - assert(bestSplits(0)._2.leftImpurity === 0) - assert(bestSplits(0)._2.rightImpurity === 0) - assert(bestSplits(0)._3.predict === 1) + val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, + null, splits, bins, 10) + + val split = rootNode.split.get + assert(split.feature === 0) + + val stats = rootNode.stats.get + assert(stats.gain === 0) + assert(stats.leftImpurity === 0) + assert(stats.rightImpurity === 0) + assert(rootNode.predict === 1) } test("Second level node building with vs. without groups") { @@ -459,40 +468,46 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(0).length === 100) // Train a 1-node model - val strategyOneNode = new Strategy(Classification, Entropy, 1, 2, 100) + val strategyOneNode = new Strategy(Classification, Entropy, maxDepth = 1, + numClassesForClassification = 2, maxBins = 100) val modelOneNode = DecisionTree.train(rdd, strategyOneNode) - val nodes: Array[Node] = new Array[Node](8) - nodes(1) = modelOneNode.topNode - nodes(1).leftNode = None - nodes(1).rightNode = None - - val parentImpurities = Array(0, 0.5, 0.5, 0.5) + val rootNodeCopy1 = modelOneNode.topNode.deepCopy() + val rootNodeCopy2 = modelOneNode.topNode.deepCopy() // Single group second level tree construction. val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, parentImpurities, metadata, 1, nodes, - splits, bins, 10) - assert(bestSplits.length === 2) - assert(bestSplits(0)._2.gain > 0) - assert(bestSplits(1)._2.gain > 0) + val (rootNode, _) = DecisionTree.findBestSplits(treeInput, metadata, 1, + rootNodeCopy1, splits, bins, 10) + assert(rootNode.leftNode.nonEmpty) + assert(rootNode.rightNode.nonEmpty) + val children1 = new Array[Node](2) + children1(0) = rootNode.leftNode.get + children1(1) = rootNode.rightNode.get // maxLevelForSingleGroup parameter is set to 0 to force splitting into groups for second // level tree construction. - val bestSplitsWithGroups = DecisionTree.findBestSplits(treeInput, parentImpurities, metadata, 1, - nodes, splits, bins, 0) - assert(bestSplitsWithGroups.length === 2) - assert(bestSplitsWithGroups(0)._2.gain > 0) - assert(bestSplitsWithGroups(1)._2.gain > 0) + val (rootNode2, _) = DecisionTree.findBestSplits(treeInput, metadata, 1, + rootNodeCopy2, splits, bins, 0) + assert(rootNode2.leftNode.nonEmpty) + assert(rootNode2.rightNode.nonEmpty) + val children2 = new Array[Node](2) + children2(0) = rootNode2.leftNode.get + children2(1) = rootNode2.rightNode.get // Verify whether the splits obtained using single group and multiple group level // construction strategies are the same. - for (i <- 0 until bestSplits.length) { - assert(bestSplits(i)._1 === bestSplitsWithGroups(i)._1) - assert(bestSplits(i)._2.gain === bestSplitsWithGroups(i)._2.gain) - assert(bestSplits(i)._2.impurity === bestSplitsWithGroups(i)._2.impurity) - assert(bestSplits(i)._2.leftImpurity === bestSplitsWithGroups(i)._2.leftImpurity) - assert(bestSplits(i)._2.rightImpurity === bestSplitsWithGroups(i)._2.rightImpurity) - assert(bestSplits(i)._3.predict === bestSplitsWithGroups(i)._3.predict) + for (i <- 0 until 2) { + assert(children1(i).stats.nonEmpty && children1(i).stats.get.gain > 0) + assert(children2(i).stats.nonEmpty && children2(i).stats.get.gain > 0) + assert(children1(i).split === children2(i).split) + assert(children1(i).stats.nonEmpty && children2(i).stats.nonEmpty) + val stats1 = children1(i).stats.get + val stats2 = children2(i).stats.get + assert(stats1.gain === stats2.gain) + assert(stats1.impurity === stats2.impurity) + assert(stats1.leftImpurity === stats2.leftImpurity) + assert(stats1.rightImpurity === stats2.rightImpurity) + assert(children1(i).predict === children2(i).predict) } } @@ -508,15 +523,14 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0, - new Array[Node](0), splits, bins, 10) + val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, + null, splits, bins, 10) - assert(bestSplits.length === 1) - val bestSplit = bestSplits(0)._1 - assert(bestSplit.feature === 0) - assert(bestSplit.categories.length === 1) - assert(bestSplit.categories.contains(1)) - assert(bestSplit.featureType === Categorical) + val split = rootNode.split.get + assert(split.feature === 0) + assert(split.categories.length === 1) + assert(split.categories.contains(1)) + assert(split.featureType === Categorical) } test("Binary classification stump with 1 continuous feature, to check off-by-1 error") { @@ -573,16 +587,16 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0, - new Array[Node](0), splits, bins, 10) - - assert(bestSplits.length === 1) - val bestSplit = bestSplits(0)._1 - assert(bestSplit.feature === 0) - assert(bestSplit.categories.length === 1) - assert(bestSplit.categories.contains(1)) - assert(bestSplit.featureType === Categorical) - val gain = bestSplits(0)._2 + val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, + null, splits, bins, 10) + + val split = rootNode.split.get + assert(split.feature === 0) + assert(split.categories.length === 1) + assert(split.categories.contains(1)) + assert(split.featureType === Categorical) + + val gain = rootNode.stats.get assert(gain.leftImpurity === 0) assert(gain.rightImpurity === 0) } @@ -600,16 +614,14 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0, - new Array[Node](0), splits, bins, 10) - - assert(bestSplits.length === 1) - val bestSplit = bestSplits(0)._1 + val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, + null, splits, bins, 10) - assert(bestSplit.feature === 1) - assert(bestSplit.featureType === Continuous) - assert(bestSplit.threshold > 1980) - assert(bestSplit.threshold < 2020) + val split = rootNode.split.get + assert(split.feature === 1) + assert(split.featureType === Continuous) + assert(split.threshold > 1980) + assert(split.threshold < 2020) } @@ -627,16 +639,14 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0, - new Array[Node](0), splits, bins, 10) + val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, + null, splits, bins, 10) - assert(bestSplits.length === 1) - val bestSplit = bestSplits(0)._1 - - assert(bestSplit.feature === 1) - assert(bestSplit.featureType === Continuous) - assert(bestSplit.threshold > 1980) - assert(bestSplit.threshold < 2020) + val split = rootNode.split.get + assert(split.feature === 1) + assert(split.featureType === Continuous) + assert(split.threshold > 1980) + assert(split.threshold < 2020) } test("Multiclass classification stump with 10-ary (ordered) categorical features") { @@ -652,15 +662,14 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0, - new Array[Node](0), splits, bins, 10) + val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, + null, splits, bins, 10) - assert(bestSplits.length === 1) - val bestSplit = bestSplits(0)._1 - assert(bestSplit.feature === 0) - assert(bestSplit.categories.length === 1) - assert(bestSplit.categories.contains(1.0)) - assert(bestSplit.featureType === Categorical) + val split = rootNode.split.get + assert(split.feature === 0) + assert(split.categories.length === 1) + assert(split.categories.contains(1.0)) + assert(split.featureType === Categorical) } test("Multiclass classification tree with 10-ary (ordered) categorical features," + @@ -698,12 +707,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0, - new Array[Node](0), splits, bins, 10) + val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, + null, splits, bins, 10) - assert(bestSplits.length == 1) - val bestInfoStats = bestSplits(0)._2 - assert(bestInfoStats == InformationGainStats.invalidInformationGainStats) + val gain = rootNode.stats.get + assert(gain == InformationGainStats.invalidInformationGainStats) } test("don't choose split that doesn't satisfy min instance per node requirements") { @@ -722,14 +730,13 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0, - new Array[Node](0), splits, bins, 10) + val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, + null, splits, bins, 10) - assert(bestSplits.length == 1) - val bestSplit = bestSplits(0)._1 - val bestSplitStats = bestSplits(0)._1 - assert(bestSplit.feature == 1) - assert(bestSplitStats != InformationGainStats.invalidInformationGainStats) + val split = rootNode.split.get + val gain = rootNode.stats.get + assert(split.feature == 1) + assert(gain != InformationGainStats.invalidInformationGainStats) } test("split must satisfy min info gain requirements") { @@ -754,12 +761,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0, - new Array[Node](0), splits, bins, 10) + val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, + null, splits, bins, 10) - assert(bestSplits.length == 1) - val bestInfoStats = bestSplits(0)._2 - assert(bestInfoStats == InformationGainStats.invalidInformationGainStats) + val gain = rootNode.stats.get + assert(gain == InformationGainStats.invalidInformationGainStats) } } @@ -786,13 +792,16 @@ object DecisionTreeSuite { def generateOrderedLabeledPoints(): Array[LabeledPoint] = { val arr = new Array[LabeledPoint](1000) for (i <- 0 until 1000) { - if (i < 600) { - val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i)) - arr(i) = lp + val label = if (i < 100) { + 0.0 + } else if (i < 500) { + 1.0 + } else if (i < 900) { + 0.0 } else { - val lp = new LabeledPoint(1.0, Vectors.dense(i.toDouble, 1000.0 - i)) - arr(i) = lp + 1.0 } + arr(i) = new LabeledPoint(label, Vectors.dense(i.toDouble, 1000.0 - i)) } arr } From f116f76bf1f1610905ca094c8edc53151a78d2f4 Mon Sep 17 00:00:00 2001 From: "Mark G. Whitney" Date: Fri, 12 Sep 2014 08:08:58 -0500 Subject: [PATCH 02/36] [SPARK-2558][DOCS] Add --queue example to YARN doc Put original YARN queue spark-submit arg description in running-on-yarn html table and example command line Author: Mark G. Whitney Closes #2218 from kramimus/2258-yarndoc and squashes the following commits: 4b5d808 [Mark G. Whitney] remove yarn queue config f8cda0d [Mark G. Whitney] [SPARK-2558][DOCS] Add spark.yarn.queue description to YARN doc --- docs/running-on-yarn.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index d8b22f3663d08..212248bcce1c1 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -155,6 +155,7 @@ For example: --driver-memory 4g \ --executor-memory 2g \ --executor-cores 1 \ + --queue thequeue \ lib/spark-examples*.jar \ 10 From 533377621f1e178e18fa0b79d653a11b66e4e250 Mon Sep 17 00:00:00 2001 From: RJ Nowling Date: Fri, 12 Sep 2014 09:46:21 -0700 Subject: [PATCH 03/36] [PySpark] Add blank line so that Python RDD.top() docstring renders correctly Author: RJ Nowling Closes #2370 from rnowling/python_rdd_docstrings and squashes the following commits: 5230574 [RJ Nowling] Add blank line so that Python RDD.top() docstring renders correctly --- python/pyspark/rdd.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 5667154cb84a8..6ad5ab2a2d1ae 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -1060,6 +1060,7 @@ def top(self, num, key=None): Get the top N elements from a RDD. Note: It returns the list sorted in descending order. + >>> sc.parallelize([10, 4, 2, 12, 3]).top(1) [12] >>> sc.parallelize([2, 3, 4, 5, 6], 2).top(2) From 8194fc662c08eb445444c207264e22361def54ea Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Fri, 12 Sep 2014 11:29:30 -0700 Subject: [PATCH 04/36] [SPARK-3481] [SQL] Eliminate the error log in local Hive comparison test Logically, we should remove the Hive Table/Database first and then reset the Hive configuration, repoint to the new data warehouse directory etc. Otherwise it raised exceptions like "Database doesn't not exists: default" in the local testing. Author: Cheng Hao Closes #2352 from chenghao-intel/test_hive and squashes the following commits: 74fd76b [Cheng Hao] eliminate the error log --- .../org/apache/spark/sql/hive/TestHive.scala | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala index 6974f3e581b97..a3bfd3a8f1fd2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -376,15 +376,6 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { log.asInstanceOf[org.apache.log4j.Logger].setLevel(org.apache.log4j.Level.WARN) } - // It is important that we RESET first as broken hooks that might have been set could break - // other sql exec here. - runSqlHive("RESET") - // For some reason, RESET does not reset the following variables... - runSqlHive("set datanucleus.cache.collections=true") - runSqlHive("set datanucleus.cache.collections.lazy=true") - // Lots of tests fail if we do not change the partition whitelist from the default. - runSqlHive("set hive.metastore.partition.name.whitelist.pattern=.*") - loadedTables.clear() catalog.client.getAllTables("default").foreach { t => logDebug(s"Deleting table $t") @@ -410,6 +401,14 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { FunctionRegistry.unregisterTemporaryUDF(udfName) } + // It is important that we RESET first as broken hooks that might have been set could break + // other sql exec here. + runSqlHive("RESET") + // For some reason, RESET does not reset the following variables... + runSqlHive("set datanucleus.cache.collections=true") + runSqlHive("set datanucleus.cache.collections.lazy=true") + // Lots of tests fail if we do not change the partition whitelist from the default. + runSqlHive("set hive.metastore.partition.name.whitelist.pattern=.*") configure() runSqlHive("USE default") From eae81b0bfdf3159be90f507a03853800aec1874a Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Fri, 12 Sep 2014 13:43:29 -0700 Subject: [PATCH 05/36] MAINTENANCE: Automated closing of pull requests. This commit exists to close the following pull requests on Github: Closes #930 (close requested by 'andrewor14') Closes #867 (close requested by 'marmbrus') Closes #1829 (close requested by 'marmbrus') Closes #1131 (close requested by 'JoshRosen') Closes #1571 (close requested by 'andrewor14') Closes #2359 (close requested by 'andrewor14') From 15a564598fe63003652b1e24527c432080b5976c Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Fri, 12 Sep 2014 14:08:38 -0700 Subject: [PATCH 06/36] [SPARK-3427] [GraphX] Avoid active vertex tracking in static PageRank GraphX's current implementation of static (fixed iteration count) PageRank uses the Pregel API. This unnecessarily tracks active vertices, even though in static PageRank all vertices are always active. Active vertex tracking incurs the following costs: 1. A shuffle per iteration to ship the active sets to the edge partitions. 2. A hash table creation per iteration at each partition to index the active sets for lookup. 3. A hash lookup per edge to check whether the source vertex is active. I reimplemented static PageRank using the lower-level GraphX API instead of the Pregel API. In benchmarks on a 16-node m2.4xlarge cluster, this provided a 23% speedup (from 514 s to 397 s, mean over 3 trials) for 10 iterations of PageRank on a synthetic graph with 10M vertices and 1.27B edges. Author: Ankur Dave Closes #2308 from ankurdave/SPARK-3427 and squashes the following commits: 449996a [Ankur Dave] Avoid unnecessary active vertex tracking in static PageRank --- .../apache/spark/graphx/lib/PageRank.scala | 45 ++++++++++++------- 1 file changed, 29 insertions(+), 16 deletions(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index 614555a054dfb..257e2f3a36115 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -79,30 +79,43 @@ object PageRank extends Logging { def run[VD: ClassTag, ED: ClassTag]( graph: Graph[VD, ED], numIter: Int, resetProb: Double = 0.15): Graph[Double, Double] = { - // Initialize the pagerankGraph with each edge attribute having + // Initialize the PageRank graph with each edge attribute having // weight 1/outDegree and each vertex with attribute 1.0. - val pagerankGraph: Graph[Double, Double] = graph + var rankGraph: Graph[Double, Double] = graph // Associate the degree with each vertex .outerJoinVertices(graph.outDegrees) { (vid, vdata, deg) => deg.getOrElse(0) } // Set the weight on the edges based on the degree .mapTriplets( e => 1.0 / e.srcAttr ) // Set the vertex attributes to the initial pagerank values - .mapVertices( (id, attr) => 1.0 ) - .cache() + .mapVertices( (id, attr) => resetProb ) - // Define the three functions needed to implement PageRank in the GraphX - // version of Pregel - def vertexProgram(id: VertexId, attr: Double, msgSum: Double): Double = - resetProb + (1.0 - resetProb) * msgSum - def sendMessage(edge: EdgeTriplet[Double, Double]) = - Iterator((edge.dstId, edge.srcAttr * edge.attr)) - def messageCombiner(a: Double, b: Double): Double = a + b - // The initial message received by all vertices in PageRank - val initialMessage = 0.0 + var iteration = 0 + var prevRankGraph: Graph[Double, Double] = null + while (iteration < numIter) { + rankGraph.cache() - // Execute pregel for a fixed number of iterations. - Pregel(pagerankGraph, initialMessage, numIter, activeDirection = EdgeDirection.Out)( - vertexProgram, sendMessage, messageCombiner) + // Compute the outgoing rank contributions of each vertex, perform local preaggregation, and + // do the final aggregation at the receiving vertices. Requires a shuffle for aggregation. + val rankUpdates = rankGraph.mapReduceTriplets[Double]( + e => Iterator((e.dstId, e.srcAttr * e.attr)), _ + _) + + // Apply the final rank updates to get the new ranks, using join to preserve ranks of vertices + // that didn't receive a message. Requires a shuffle for broadcasting updated ranks to the + // edge partitions. + prevRankGraph = rankGraph + rankGraph = rankGraph.joinVertices(rankUpdates) { + (id, oldRank, msgSum) => resetProb + (1.0 - resetProb) * msgSum + }.cache() + + rankGraph.edges.foreachPartition(x => {}) // also materializes rankGraph.vertices + logInfo(s"PageRank finished iteration $iteration.") + prevRankGraph.vertices.unpersist(false) + prevRankGraph.edges.unpersist(false) + + iteration += 1 + } + + rankGraph } /** From 1d767967e925f1d727957c2d43383ef6ad2c5d5e Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Fri, 12 Sep 2014 16:48:28 -0500 Subject: [PATCH 07/36] SPARK-3014. Log a more informative messages in a couple failure scenario... ...s Author: Sandy Ryza Closes #1934 from sryza/sandy-spark-3014 and squashes the following commits: ae19cc1 [Sandy Ryza] SPARK-3014. Log a more informative messages in a couple failure scenarios --- .../main/scala/org/apache/spark/deploy/SparkSubmit.scala | 6 ++++-- .../org/apache/spark/deploy/yarn/ApplicationMaster.scala | 6 ++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 0fdb5ae3c2e40..5ed3575816a38 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -18,7 +18,7 @@ package org.apache.spark.deploy import java.io.{File, PrintStream} -import java.lang.reflect.InvocationTargetException +import java.lang.reflect.{Modifier, InvocationTargetException} import java.net.URL import scala.collection.mutable.{ArrayBuffer, HashMap, Map} @@ -323,7 +323,9 @@ object SparkSubmit { } val mainMethod = mainClass.getMethod("main", new Array[String](0).getClass) - + if (!Modifier.isStatic(mainMethod.getModifiers)) { + throw new IllegalStateException("The main method in the given main class must be static") + } try { mainMethod.invoke(null, childArgs.toArray) } catch { diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 878b6db546032..735d7723b0ce6 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -283,11 +283,9 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, } val sparkContext = sparkContextRef.get() - assert(sparkContext != null || count >= numTries) if (sparkContext == null) { - logError( - "Unable to retrieve sparkContext inspite of waiting for %d, numTries = %d".format( - count * waitTime, numTries)) + logError(("SparkContext did not initialize after waiting for %d ms. Please check earlier" + + " log output for errors. Failing the application.").format(numTries * waitTime)) } sparkContext } From af2583826c15d2a4e2732017ea20feeff0fb79f6 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 12 Sep 2014 14:54:42 -0700 Subject: [PATCH 08/36] [SPARK-3217] Add Guava to classpath when SPARK_PREPEND_CLASSES is set. When that option is used, the compiled classes from the build directory are prepended to the classpath. Now that we avoid packaging Guava, that means we have classes referencing the original Guava location in the app's classpath, so errors happen. For that case, add Guava manually to the classpath. Note: if Spark is compiled with "-Phadoop-provided", it's tricky to make things work with SPARK_PREPEND_CLASSES, because you need to add the Hadoop classpath using SPARK_CLASSPATH and that means the older Hadoop Guava overrides the newer one Spark needs. So someone using SPARK_PREPEND_CLASSES needs to remember to not use that profile. Author: Marcelo Vanzin Closes #2141 from vanzin/SPARK-3217 and squashes the following commits: b967324 [Marcelo Vanzin] [SPARK-3217] Add Guava to classpath when SPARK_PREPEND_CLASSES is set. --- bin/compute-classpath.sh | 1 + core/pom.xml | 27 +++++++++++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh index 15c6779402994..0f63e36d8aeca 100755 --- a/bin/compute-classpath.sh +++ b/bin/compute-classpath.sh @@ -43,6 +43,7 @@ if [ -n "$SPARK_PREPEND_CLASSES" ]; then echo "NOTE: SPARK_PREPEND_CLASSES is set, placing locally compiled Spark"\ "classes ahead of assembly." >&2 CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/core/target/jars/*" CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SCALA_VERSION/classes" CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SCALA_VERSION/classes" CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SCALA_VERSION/classes" diff --git a/core/pom.xml b/core/pom.xml index b2b788a4bc13b..2a81f6df289c0 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -351,6 +351,33 @@ + + + org.apache.maven.plugins + maven-dependency-plugin + + + copy-dependencies + package + + copy-dependencies + + + ${project.build.directory} + false + false + true + true + guava + true + + + + From 25311c2c545a60eb9dcf704814d4600987852155 Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Fri, 12 Sep 2014 20:31:11 -0500 Subject: [PATCH 09/36] [SPARK-3456] YarnAllocator on alpha can lose container requests to RM Author: Thomas Graves Closes #2373 from tgravescs/SPARK-3456 and squashes the following commits: 77e9532 [Thomas Graves] [SPARK-3456] YarnAllocator on alpha can lose container requests to RM --- .../spark/deploy/yarn/YarnAllocationHandler.scala | 11 ++++++----- .../org/apache/spark/deploy/yarn/YarnAllocator.scala | 8 ++++++-- .../spark/deploy/yarn/YarnAllocationHandler.scala | 3 ++- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala index 5a1b42c1e17d5..6c93d8582330b 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala @@ -48,16 +48,17 @@ private[yarn] class YarnAllocationHandler( private val lastResponseId = new AtomicInteger() private val releaseList: CopyOnWriteArrayList[ContainerId] = new CopyOnWriteArrayList() - override protected def allocateContainers(count: Int): YarnAllocateResponse = { + override protected def allocateContainers(count: Int, pending: Int): YarnAllocateResponse = { var resourceRequests: List[ResourceRequest] = null - logDebug("numExecutors: " + count) + logDebug("asking for additional executors: " + count + " with already pending: " + pending) + val totalNumAsk = count + pending if (count <= 0) { resourceRequests = List() } else if (preferredHostToCount.isEmpty) { logDebug("host preferences is empty") resourceRequests = List(createResourceRequest( - AllocationType.ANY, null, count, YarnSparkHadoopUtil.RM_REQUEST_PRIORITY)) + AllocationType.ANY, null, totalNumAsk, YarnSparkHadoopUtil.RM_REQUEST_PRIORITY)) } else { // request for all hosts in preferred nodes and for numExecutors - // candidates.size, request by default allocation policy. @@ -80,7 +81,7 @@ private[yarn] class YarnAllocationHandler( val anyContainerRequests: ResourceRequest = createResourceRequest( AllocationType.ANY, resource = null, - count, + totalNumAsk, YarnSparkHadoopUtil.RM_REQUEST_PRIORITY) val containerRequests: ArrayBuffer[ResourceRequest] = new ArrayBuffer[ResourceRequest]( @@ -103,7 +104,7 @@ private[yarn] class YarnAllocationHandler( req.addAllReleases(releasedContainerList) if (count > 0) { - logInfo("Allocating %d executor containers with %d of memory each.".format(count, + logInfo("Allocating %d executor containers with %d of memory each.".format(totalNumAsk, executorMemory + memoryOverhead)) } else { logDebug("Empty allocation req .. release : " + releasedContainerList) diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 0b8744f4b8bdf..299e38a5eb9c0 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -112,6 +112,9 @@ private[yarn] abstract class YarnAllocator( def allocateResources() = { val missing = maxExecutors - numPendingAllocate.get() - numExecutorsRunning.get() + // this is needed by alpha, do it here since we add numPending right after this + val executorsPending = numPendingAllocate.get() + if (missing > 0) { numPendingAllocate.addAndGet(missing) logInfo("Will Allocate %d executor containers, each with %d memory".format( @@ -121,7 +124,7 @@ private[yarn] abstract class YarnAllocator( logDebug("Empty allocation request ...") } - val allocateResponse = allocateContainers(missing) + val allocateResponse = allocateContainers(missing, executorsPending) val allocatedContainers = allocateResponse.getAllocatedContainers() if (allocatedContainers.size > 0) { @@ -435,9 +438,10 @@ private[yarn] abstract class YarnAllocator( * * @param count Number of containers to allocate. * If zero, should still contact RM (as a heartbeat). + * @param pending Number of containers pending allocate. Only used on alpha. * @return Response to the allocation request. */ - protected def allocateContainers(count: Int): YarnAllocateResponse + protected def allocateContainers(count: Int, pending: Int): YarnAllocateResponse /** Called to release a previously allocated container. */ protected def releaseContainer(container: Container): Unit diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala index 5438f151ac0ad..e44a8db41b97e 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala @@ -47,7 +47,8 @@ private[yarn] class YarnAllocationHandler( amClient.releaseAssignedContainer(container.getId()) } - override protected def allocateContainers(count: Int): YarnAllocateResponse = { + // pending isn't used on stable as the AMRMClient handles incremental asks + override protected def allocateContainers(count: Int, pending: Int): YarnAllocateResponse = { addResourceRequests(count) // We have already set the container request. Poll the ResourceManager for a response. From 71af030b46a89aaa9a87f18f56b9e1f1cd8ce2e7 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 12 Sep 2014 18:42:50 -0700 Subject: [PATCH 10/36] [SPARK-3094] [PySpark] compatitable with PyPy After this patch, we can run PySpark in PyPy (testing with PyPy 2.3.1 in Mac 10.9), for example: ``` PYSPARK_PYTHON=pypy ./bin/spark-submit wordcount.py ``` The performance speed up will depend on work load (from 20% to 3000%). Here are some benchmarks: Job | CPython 2.7 | PyPy 2.3.1 | Speed up ------- | ------------ | ------------- | ------- Word Count | 41s | 15s | 2.7x Sort | 46s | 44s | 1.05x Stats | 174s | 3.6s | 48x Here is the code used for benchmark: ```python rdd = sc.textFile("text") def wordcount(): rdd.flatMap(lambda x:x.split('/'))\ .map(lambda x:(x,1)).reduceByKey(lambda x,y:x+y).collectAsMap() def sort(): rdd.sortBy(lambda x:x, 1).count() def stats(): sc.parallelize(range(1024), 20).flatMap(lambda x: xrange(5024)).stats() ``` Author: Davies Liu Closes #2144 from davies/pypy and squashes the following commits: 9aed6c5 [Davies Liu] use protocol 2 in CloudPickle 4bc1f04 [Davies Liu] refactor b20ab3a [Davies Liu] pickle sys.stdout and stderr in portable way 3ca2351 [Davies Liu] Merge branch 'master' into pypy fae8b19 [Davies Liu] improve attrgetter, add tests 591f830 [Davies Liu] try to run tests with PyPy in run-tests c8d62ba [Davies Liu] cleanup f651fd0 [Davies Liu] fix tests using array with PyPy 1b98fb3 [Davies Liu] serialize itemgetter/attrgetter in portable ways 3c1dbfe [Davies Liu] Merge branch 'master' into pypy 42fb5fa [Davies Liu] Merge branch 'master' into pypy cb2d724 [Davies Liu] fix tests 9986692 [Davies Liu] Merge branch 'master' into pypy 25b4ca7 [Davies Liu] support PyPy --- python/pyspark/cloudpickle.py | 168 ++++++++++++++-------------------- python/pyspark/daemon.py | 6 +- python/pyspark/serializers.py | 10 +- python/pyspark/tests.py | 85 +++++++++++++++-- python/run-tests | 21 +++++ 5 files changed, 172 insertions(+), 118 deletions(-) diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index 80e51d1a583a0..32dda3888c62d 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -52,35 +52,19 @@ import itertools from copy_reg import _extension_registry, _inverted_registry, _extension_cache import new -import dis import traceback +import platform -#relevant opcodes -STORE_GLOBAL = chr(dis.opname.index('STORE_GLOBAL')) -DELETE_GLOBAL = chr(dis.opname.index('DELETE_GLOBAL')) -LOAD_GLOBAL = chr(dis.opname.index('LOAD_GLOBAL')) -GLOBAL_OPS = [STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL] +PyImp = platform.python_implementation() -HAVE_ARGUMENT = chr(dis.HAVE_ARGUMENT) -EXTENDED_ARG = chr(dis.EXTENDED_ARG) import logging cloudLog = logging.getLogger("Cloud.Transport") -try: - import ctypes -except (MemoryError, ImportError): - logging.warning('Exception raised on importing ctypes. Likely python bug.. some functionality will be disabled', exc_info = True) - ctypes = None - PyObject_HEAD = None -else: - - # for reading internal structures - PyObject_HEAD = [ - ('ob_refcnt', ctypes.c_size_t), - ('ob_type', ctypes.c_void_p), - ] +if PyImp == "PyPy": + # register builtin type in `new` + new.method = types.MethodType try: from cStringIO import StringIO @@ -225,6 +209,8 @@ def save_function(self, obj, name=None, pack=struct.pack): if themodule: self.modules.add(themodule) + if getattr(themodule, name, None) is obj: + return self.save_global(obj, name) if not self.savedDjangoEnv: #hack for django - if we detect the settings module, we transport it @@ -306,44 +292,28 @@ def save_function_tuple(self, func, forced_imports): # create a skeleton function object and memoize it save(_make_skel_func) - save((code, len(closure), base_globals)) + save((code, closure, base_globals)) write(pickle.REDUCE) self.memoize(func) # save the rest of the func data needed by _fill_function save(f_globals) save(defaults) - save(closure) save(dct) write(pickle.TUPLE) write(pickle.REDUCE) # applies _fill_function on the tuple @staticmethod - def extract_code_globals(co): + def extract_code_globals(code): """ Find all globals names read or written to by codeblock co """ - code = co.co_code - names = co.co_names - out_names = set() - - n = len(code) - i = 0 - extended_arg = 0 - while i < n: - op = code[i] - - i = i+1 - if op >= HAVE_ARGUMENT: - oparg = ord(code[i]) + ord(code[i+1])*256 + extended_arg - extended_arg = 0 - i = i+2 - if op == EXTENDED_ARG: - extended_arg = oparg*65536L - if op in GLOBAL_OPS: - out_names.add(names[oparg]) - #print 'extracted', out_names, ' from ', names - return out_names + names = set(code.co_names) + if code.co_consts: # see if nested function have any global refs + for const in code.co_consts: + if type(const) is types.CodeType: + names |= CloudPickler.extract_code_globals(const) + return names def extract_func_data(self, func): """ @@ -354,10 +324,7 @@ def extract_func_data(self, func): # extract all global ref's func_global_refs = CloudPickler.extract_code_globals(code) - if code.co_consts: # see if nested function have any global refs - for const in code.co_consts: - if type(const) is types.CodeType and const.co_names: - func_global_refs = func_global_refs.union( CloudPickler.extract_code_globals(const)) + # process all variables referenced by global environment f_globals = {} for var in func_global_refs: @@ -396,6 +363,12 @@ def get_contents(cell): return (code, f_globals, defaults, closure, dct, base_globals) + def save_builtin_function(self, obj): + if obj.__module__ is "__builtin__": + return self.save_global(obj) + return self.save_function(obj) + dispatch[types.BuiltinFunctionType] = save_builtin_function + def save_global(self, obj, name=None, pack=struct.pack): write = self.write memo = self.memo @@ -435,7 +408,7 @@ def save_global(self, obj, name=None, pack=struct.pack): try: klass = getattr(themodule, name) except AttributeError, a: - #print themodule, name, obj, type(obj) + # print themodule, name, obj, type(obj) raise pickle.PicklingError("Can't pickle builtin %s" % obj) else: raise @@ -480,7 +453,6 @@ def save_global(self, obj, name=None, pack=struct.pack): write(pickle.GLOBAL + modname + '\n' + name + '\n') self.memoize(obj) dispatch[types.ClassType] = save_global - dispatch[types.BuiltinFunctionType] = save_global dispatch[types.TypeType] = save_global def save_instancemethod(self, obj): @@ -551,23 +523,39 @@ def save_property(self, obj): dispatch[property] = save_property def save_itemgetter(self, obj): - """itemgetter serializer (needed for namedtuple support) - a bit of a pain as we need to read ctypes internals""" - class ItemGetterType(ctypes.Structure): - _fields_ = PyObject_HEAD + [ - ('nitems', ctypes.c_size_t), - ('item', ctypes.py_object) - ] - - - obj = ctypes.cast(ctypes.c_void_p(id(obj)), ctypes.POINTER(ItemGetterType)).contents - return self.save_reduce(operator.itemgetter, - obj.item if obj.nitems > 1 else (obj.item,)) - - if PyObject_HEAD: + """itemgetter serializer (needed for namedtuple support)""" + class Dummy: + def __getitem__(self, item): + return item + items = obj(Dummy()) + if not isinstance(items, tuple): + items = (items, ) + return self.save_reduce(operator.itemgetter, items) + + if type(operator.itemgetter) is type: dispatch[operator.itemgetter] = save_itemgetter + def save_attrgetter(self, obj): + """attrgetter serializer""" + class Dummy(object): + def __init__(self, attrs, index=None): + self.attrs = attrs + self.index = index + def __getattribute__(self, item): + attrs = object.__getattribute__(self, "attrs") + index = object.__getattribute__(self, "index") + if index is None: + index = len(attrs) + attrs.append(item) + else: + attrs[index] = ".".join([attrs[index], item]) + return type(self)(attrs, index) + attrs = [] + obj(Dummy(attrs)) + return self.save_reduce(operator.attrgetter, tuple(attrs)) + if type(operator.attrgetter) is type: + dispatch[operator.attrgetter] = save_attrgetter def save_reduce(self, func, args, state=None, listitems=None, dictitems=None, obj=None): @@ -660,11 +648,11 @@ def save_file(self, obj): if not hasattr(obj, 'name') or not hasattr(obj, 'mode'): raise pickle.PicklingError("Cannot pickle files that do not map to an actual file") - if obj.name == '': + if obj is sys.stdout: return self.save_reduce(getattr, (sys,'stdout'), obj=obj) - if obj.name == '': + if obj is sys.stderr: return self.save_reduce(getattr, (sys,'stderr'), obj=obj) - if obj.name == '': + if obj is sys.stdin: raise pickle.PicklingError("Cannot pickle standard input") if hasattr(obj, 'isatty') and obj.isatty(): raise pickle.PicklingError("Cannot pickle files that map to tty objects") @@ -873,8 +861,7 @@ def _genpartial(func, args, kwds): kwds = {} return partial(func, *args, **kwds) - -def _fill_function(func, globals, defaults, closure, dict): +def _fill_function(func, globals, defaults, dict): """ Fills in the rest of function data into the skeleton function object that were created via _make_skel_func(). """ @@ -882,49 +869,28 @@ def _fill_function(func, globals, defaults, closure, dict): func.func_defaults = defaults func.func_dict = dict - if len(closure) != len(func.func_closure): - raise pickle.UnpicklingError("closure lengths don't match up") - for i in range(len(closure)): - _change_cell_value(func.func_closure[i], closure[i]) - return func -def _make_skel_func(code, num_closures, base_globals = None): +def _make_cell(value): + return (lambda: value).func_closure[0] + +def _reconstruct_closure(values): + return tuple([_make_cell(v) for v in values]) + +def _make_skel_func(code, closures, base_globals = None): """ Creates a skeleton function object that contains just the provided code and the correct number of cells in func_closure. All other func attributes (e.g. func_globals) are empty. """ - #build closure (cells): - if not ctypes: - raise Exception('ctypes failed to import; cannot build function') - - cellnew = ctypes.pythonapi.PyCell_New - cellnew.restype = ctypes.py_object - cellnew.argtypes = (ctypes.py_object,) - dummy_closure = tuple(map(lambda i: cellnew(None), range(num_closures))) + closure = _reconstruct_closure(closures) if closures else None if base_globals is None: base_globals = {} base_globals['__builtins__'] = __builtins__ return types.FunctionType(code, base_globals, - None, None, dummy_closure) - -# this piece of opaque code is needed below to modify 'cell' contents -cell_changer_code = new.code( - 1, 1, 2, 0, - ''.join([ - chr(dis.opmap['LOAD_FAST']), '\x00\x00', - chr(dis.opmap['DUP_TOP']), - chr(dis.opmap['STORE_DEREF']), '\x00\x00', - chr(dis.opmap['RETURN_VALUE']) - ]), - (), (), ('newval',), '', 'cell_changer', 1, '', ('c',), () -) - -def _change_cell_value(cell, newval): - """ Changes the contents of 'cell' object to newval """ - return new.function(cell_changer_code, {}, None, (), (cell,))(newval) + None, None, closure) + """Constructors for 3rd party libraries Note: These can never be renamed due to client compatibility issues""" diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index 22ab8d30c0ae3..15445abf67147 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -42,10 +42,6 @@ def worker(sock): """ Called by a worker process after the fork(). """ - # Redirect stdout to stderr - os.dup2(2, 1) - sys.stdout = sys.stderr # The sys.stdout object is different from file descriptor 1 - signal.signal(SIGHUP, SIG_DFL) signal.signal(SIGCHLD, SIG_DFL) signal.signal(SIGTERM, SIG_DFL) @@ -102,6 +98,7 @@ def manager(): listen_sock.listen(max(1024, SOMAXCONN)) listen_host, listen_port = listen_sock.getsockname() write_int(listen_port, sys.stdout) + sys.stdout.flush() def shutdown(code): signal.signal(SIGTERM, SIG_DFL) @@ -115,7 +112,6 @@ def handle_sigterm(*args): signal.signal(SIGHUP, SIG_IGN) # Don't die on SIGHUP # Initialization complete - sys.stdout.close() try: while True: try: diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 7b2710b913128..a5f9341e819a9 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -355,7 +355,8 @@ class PickleSerializer(FramedSerializer): def dumps(self, obj): return cPickle.dumps(obj, 2) - loads = cPickle.loads + def loads(self, obj): + return cPickle.loads(obj) class CloudPickleSerializer(PickleSerializer): @@ -374,8 +375,11 @@ class MarshalSerializer(FramedSerializer): This serializer is faster than PickleSerializer but supports fewer datatypes. """ - dumps = marshal.dumps - loads = marshal.loads + def dumps(self, obj): + return marshal.dumps(obj) + + def loads(self, obj): + return marshal.loads(obj) class AutoSerializer(FramedSerializer): diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index bb84ebe72cb24..2e7c2750a8bb6 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -31,6 +31,7 @@ import time import zipfile import random +from platform import python_implementation if sys.version_info[:2] <= (2, 6): import unittest2 as unittest @@ -41,7 +42,8 @@ from pyspark.conf import SparkConf from pyspark.context import SparkContext from pyspark.files import SparkFiles -from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer +from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \ + CloudPickleSerializer from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter from pyspark.sql import SQLContext, IntegerType @@ -168,15 +170,46 @@ def test_namedtuple(self): p2 = loads(dumps(p1, 2)) self.assertEquals(p1, p2) - -# Regression test for SPARK-3415 -class CloudPickleTest(unittest.TestCase): + def test_itemgetter(self): + from operator import itemgetter + ser = CloudPickleSerializer() + d = range(10) + getter = itemgetter(1) + getter2 = ser.loads(ser.dumps(getter)) + self.assertEqual(getter(d), getter2(d)) + + getter = itemgetter(0, 3) + getter2 = ser.loads(ser.dumps(getter)) + self.assertEqual(getter(d), getter2(d)) + + def test_attrgetter(self): + from operator import attrgetter + ser = CloudPickleSerializer() + + class C(object): + def __getattr__(self, item): + return item + d = C() + getter = attrgetter("a") + getter2 = ser.loads(ser.dumps(getter)) + self.assertEqual(getter(d), getter2(d)) + getter = attrgetter("a", "b") + getter2 = ser.loads(ser.dumps(getter)) + self.assertEqual(getter(d), getter2(d)) + + d.e = C() + getter = attrgetter("e.a") + getter2 = ser.loads(ser.dumps(getter)) + self.assertEqual(getter(d), getter2(d)) + getter = attrgetter("e.a", "e.b") + getter2 = ser.loads(ser.dumps(getter)) + self.assertEqual(getter(d), getter2(d)) + + # Regression test for SPARK-3415 def test_pickling_file_handles(self): - from pyspark.cloudpickle import dumps - from StringIO import StringIO - from pickle import load + ser = CloudPickleSerializer() out1 = sys.stderr - out2 = load(StringIO(dumps(out1))) + out2 = ser.loads(ser.dumps(out1)) self.assertEquals(out1, out2) @@ -861,8 +894,42 @@ def test_oldhadoop(self): conf=input_conf).collect()) self.assertEqual(old_dataset, dict_data) - @unittest.skipIf(sys.version_info[:2] <= (2, 6), "Skipped on 2.6 until SPARK-2951 is fixed") def test_newhadoop(self): + basepath = self.tempdir.name + data = [(1, ""), + (1, "a"), + (2, "bcdf")] + self.sc.parallelize(data).saveAsNewAPIHadoopFile( + basepath + "/newhadoop/", + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text") + result = sorted(self.sc.newAPIHadoopFile( + basepath + "/newhadoop/", + "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text").collect()) + self.assertEqual(result, data) + + conf = { + "mapreduce.outputformat.class": + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", + "mapred.output.key.class": "org.apache.hadoop.io.IntWritable", + "mapred.output.value.class": "org.apache.hadoop.io.Text", + "mapred.output.dir": basepath + "/newdataset/" + } + self.sc.parallelize(data).saveAsNewAPIHadoopDataset(conf) + input_conf = {"mapred.input.dir": basepath + "/newdataset/"} + new_dataset = sorted(self.sc.newAPIHadoopRDD( + "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text", + conf=input_conf).collect()) + self.assertEqual(new_dataset, data) + + @unittest.skipIf(sys.version_info[:2] <= (2, 6) or python_implementation() == "PyPy", + "Skipped on 2.6 and PyPy until SPARK-2951 is fixed") + def test_newhadoop_with_array(self): basepath = self.tempdir.name # use custom ArrayWritable types and converters to handle arrays array_data = [(1, array('d')), diff --git a/python/run-tests b/python/run-tests index d98840de59d2c..a67e5a99fbdcc 100755 --- a/python/run-tests +++ b/python/run-tests @@ -85,6 +85,27 @@ run_test "pyspark/mllib/tests.py" run_test "pyspark/mllib/tree.py" run_test "pyspark/mllib/util.py" +# Try to test with PyPy +if [ $(which pypy) ]; then + export PYSPARK_PYTHON="pypy" + echo "Testing with PyPy version:" + $PYSPARK_PYTHON --version + + run_test "pyspark/rdd.py" + run_test "pyspark/context.py" + run_test "pyspark/conf.py" + run_test "pyspark/sql.py" + # These tests are included in the module-level docs, and so must + # be handled on a higher level rather than within the python file. + export PYSPARK_DOC_TEST=1 + run_test "pyspark/broadcast.py" + run_test "pyspark/accumulators.py" + run_test "pyspark/serializers.py" + unset PYSPARK_DOC_TEST + run_test "pyspark/shuffle.py" + run_test "pyspark/tests.py" +fi + if [[ $FAILED == 0 ]]; then echo -en "\033[32m" # Green echo "Tests passed." From 885d1621bc06bc1f009c9707c3452eac26baf828 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 12 Sep 2014 19:05:39 -0700 Subject: [PATCH 11/36] [SPARK-3500] [SQL] use JavaSchemaRDD as SchemaRDD._jschema_rdd Currently, SchemaRDD._jschema_rdd is SchemaRDD, the Scala API (coalesce(), repartition()) can not been called in Python easily, there is no way to specify the implicit parameter `ord`. The _jrdd is an JavaRDD, so _jschema_rdd should also be JavaSchemaRDD. In this patch, change _schema_rdd to JavaSchemaRDD, also added an assert for it. If some methods are missing from JavaSchemaRDD, then it's called by _schema_rdd.baseSchemaRDD().xxx(). BTW, Do we need JavaSQLContext? Author: Davies Liu Closes #2369 from davies/fix_schemardd and squashes the following commits: abee159 [Davies Liu] use JavaSchemaRDD as SchemaRDD._jschema_rdd --- python/pyspark/sql.py | 38 ++++++++++++++++++-------------------- python/pyspark/tests.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 20 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 53eea6d6cf3ba..fc9310fef318c 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -1122,7 +1122,7 @@ def applySchema(self, rdd, schema): batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer) jrdd = self._pythonToJava(rdd._jrdd, batched) srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), str(schema)) - return SchemaRDD(srdd, self) + return SchemaRDD(srdd.toJavaSchemaRDD(), self) def registerRDDAsTable(self, rdd, tableName): """Registers the given RDD as a temporary table in the catalog. @@ -1134,8 +1134,8 @@ def registerRDDAsTable(self, rdd, tableName): >>> sqlCtx.registerRDDAsTable(srdd, "table1") """ if (rdd.__class__ is SchemaRDD): - jschema_rdd = rdd._jschema_rdd - self._ssql_ctx.registerRDDAsTable(jschema_rdd, tableName) + srdd = rdd._jschema_rdd.baseSchemaRDD() + self._ssql_ctx.registerRDDAsTable(srdd, tableName) else: raise ValueError("Can only register SchemaRDD as table") @@ -1151,7 +1151,7 @@ def parquetFile(self, path): >>> sorted(srdd.collect()) == sorted(srdd2.collect()) True """ - jschema_rdd = self._ssql_ctx.parquetFile(path) + jschema_rdd = self._ssql_ctx.parquetFile(path).toJavaSchemaRDD() return SchemaRDD(jschema_rdd, self) def jsonFile(self, path, schema=None): @@ -1207,11 +1207,11 @@ def jsonFile(self, path, schema=None): [Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)] """ if schema is None: - jschema_rdd = self._ssql_ctx.jsonFile(path) + srdd = self._ssql_ctx.jsonFile(path) else: scala_datatype = self._ssql_ctx.parseDataType(str(schema)) - jschema_rdd = self._ssql_ctx.jsonFile(path, scala_datatype) - return SchemaRDD(jschema_rdd, self) + srdd = self._ssql_ctx.jsonFile(path, scala_datatype) + return SchemaRDD(srdd.toJavaSchemaRDD(), self) def jsonRDD(self, rdd, schema=None): """Loads an RDD storing one JSON object per string as a L{SchemaRDD}. @@ -1275,11 +1275,11 @@ def func(iterator): keyed._bypass_serializer = True jrdd = keyed._jrdd.map(self._jvm.BytesToString()) if schema is None: - jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd()) + srdd = self._ssql_ctx.jsonRDD(jrdd.rdd()) else: scala_datatype = self._ssql_ctx.parseDataType(str(schema)) - jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) - return SchemaRDD(jschema_rdd, self) + srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) + return SchemaRDD(srdd.toJavaSchemaRDD(), self) def sql(self, sqlQuery): """Return a L{SchemaRDD} representing the result of the given query. @@ -1290,7 +1290,7 @@ def sql(self, sqlQuery): >>> srdd2.collect() [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')] """ - return SchemaRDD(self._ssql_ctx.sql(sqlQuery), self) + return SchemaRDD(self._ssql_ctx.sql(sqlQuery).toJavaSchemaRDD(), self) def table(self, tableName): """Returns the specified table as a L{SchemaRDD}. @@ -1301,7 +1301,7 @@ def table(self, tableName): >>> sorted(srdd.collect()) == sorted(srdd2.collect()) True """ - return SchemaRDD(self._ssql_ctx.table(tableName), self) + return SchemaRDD(self._ssql_ctx.table(tableName).toJavaSchemaRDD(), self) def cacheTable(self, tableName): """Caches the specified table in-memory.""" @@ -1353,7 +1353,7 @@ def hiveql(self, hqlQuery): warnings.warn("hiveql() is deprecated as the sql function now parses using HiveQL by" + "default. The SQL dialect for parsing can be set using 'spark.sql.dialect'", DeprecationWarning) - return SchemaRDD(self._ssql_ctx.hiveql(hqlQuery), self) + return SchemaRDD(self._ssql_ctx.hiveql(hqlQuery).toJavaSchemaRDD(), self) def hql(self, hqlQuery): """ @@ -1524,6 +1524,8 @@ class SchemaRDD(RDD): def __init__(self, jschema_rdd, sql_ctx): self.sql_ctx = sql_ctx self._sc = sql_ctx._sc + clsName = jschema_rdd.getClass().getName() + assert clsName.endswith("JavaSchemaRDD"), "jschema_rdd must be JavaSchemaRDD" self._jschema_rdd = jschema_rdd self._id = None self.is_cached = False @@ -1540,7 +1542,7 @@ def _jrdd(self): L{pyspark.rdd.RDD} super class (map, filter, etc.). """ if not hasattr(self, '_lazy_jrdd'): - self._lazy_jrdd = self._jschema_rdd.javaToPython() + self._lazy_jrdd = self._jschema_rdd.baseSchemaRDD().javaToPython() return self._lazy_jrdd def id(self): @@ -1598,7 +1600,7 @@ def saveAsTable(self, tableName): def schema(self): """Returns the schema of this SchemaRDD (represented by a L{StructType}).""" - return _parse_datatype_string(self._jschema_rdd.schema().toString()) + return _parse_datatype_string(self._jschema_rdd.baseSchemaRDD().schema().toString()) def schemaString(self): """Returns the output schema in the tree format.""" @@ -1649,8 +1651,6 @@ def mapPartitionsWithIndex(self, f, preservesPartitioning=False): rdd = RDD(self._jrdd, self._sc, self._jrdd_deserializer) schema = self.schema() - import pickle - pickle.loads(pickle.dumps(schema)) def applySchema(_, it): cls = _create_cls(schema) @@ -1687,10 +1687,8 @@ def isCheckpointed(self): def getCheckpointFile(self): checkpointFile = self._jschema_rdd.getCheckpointFile() - if checkpointFile.isDefined(): + if checkpointFile.isPresent(): return checkpointFile.get() - else: - return None def coalesce(self, numPartitions, shuffle=False): rdd = self._jschema_rdd.coalesce(numPartitions, shuffle) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 2e7c2750a8bb6..b687d695b01c4 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -607,6 +607,34 @@ def test_broadcast_in_udf(self): [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect() self.assertEqual("", res[0]) + def test_basic_functions(self): + rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) + srdd = self.sqlCtx.jsonRDD(rdd) + srdd.count() + srdd.collect() + srdd.schemaString() + srdd.schema() + + # cache and checkpoint + self.assertFalse(srdd.is_cached) + srdd.persist() + srdd.unpersist() + srdd.cache() + self.assertTrue(srdd.is_cached) + self.assertFalse(srdd.isCheckpointed()) + self.assertEqual(None, srdd.getCheckpointFile()) + + srdd = srdd.coalesce(2, True) + srdd = srdd.repartition(3) + srdd = srdd.distinct() + srdd.intersection(srdd) + self.assertEqual(2, srdd.count()) + + srdd.registerTempTable("temp") + srdd = self.sqlCtx.sql("select foo from temp") + srdd.count() + srdd.collect() + class TestIO(PySparkTestCase): From 6d887db7891be643f0131b136e82191b5f6eb407 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Fri, 12 Sep 2014 20:14:09 -0700 Subject: [PATCH 12/36] [SPARK-3515][SQL] Moves test suite setup code to beforeAll rather than in constructor Please refer to the JIRA ticket for details. **NOTE** We should check all test suites that do similar initialization-like side effects in their constructors. This PR only fixes `ParquetMetastoreSuite` because it breaks our Jenkins Maven build. Author: Cheng Lian Closes #2375 from liancheng/say-no-to-constructor and squashes the following commits: 0ceb75b [Cheng Lian] Moves test suite setup code to beforeAll rather than in constructor --- .../sql/parquet/ParquetMetastoreSuite.scala | 53 +++++++++---------- 1 file changed, 24 insertions(+), 29 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/ParquetMetastoreSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/ParquetMetastoreSuite.scala index 0723be7298e15..e380280f301c1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/ParquetMetastoreSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/ParquetMetastoreSuite.scala @@ -20,14 +20,10 @@ package org.apache.spark.sql.parquet import java.io.File -import org.apache.spark.sql.hive.execution.HiveTableScan import org.scalatest.BeforeAndAfterAll -import scala.reflect.ClassTag - -import org.apache.spark.sql.{SQLConf, QueryTest} -import org.apache.spark.sql.execution.{BroadcastHashJoin, ShuffledHashJoin} -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.hive.execution.HiveTableScan import org.apache.spark.sql.hive.test.TestHive._ case class ParquetData(intField: Int, stringField: String) @@ -36,27 +32,19 @@ case class ParquetData(intField: Int, stringField: String) * Tests for our SerDe -> Native parquet scan conversion. */ class ParquetMetastoreSuite extends QueryTest with BeforeAndAfterAll { - override def beforeAll(): Unit = { - setConf("spark.sql.hive.convertMetastoreParquet", "true") - } - - override def afterAll(): Unit = { - setConf("spark.sql.hive.convertMetastoreParquet", "false") - } - - val partitionedTableDir = File.createTempFile("parquettests", "sparksql") - partitionedTableDir.delete() - partitionedTableDir.mkdir() - - (1 to 10).foreach { p => - val partDir = new File(partitionedTableDir, s"p=$p") - sparkContext.makeRDD(1 to 10) - .map(i => ParquetData(i, s"part-$p")) - .saveAsParquetFile(partDir.getCanonicalPath) - } - - sql(s""" + val partitionedTableDir = File.createTempFile("parquettests", "sparksql") + partitionedTableDir.delete() + partitionedTableDir.mkdir() + + (1 to 10).foreach { p => + val partDir = new File(partitionedTableDir, s"p=$p") + sparkContext.makeRDD(1 to 10) + .map(i => ParquetData(i, s"part-$p")) + .saveAsParquetFile(partDir.getCanonicalPath) + } + + sql(s""" create external table partitioned_parquet ( intField INT, @@ -70,7 +58,7 @@ class ParquetMetastoreSuite extends QueryTest with BeforeAndAfterAll { location '${partitionedTableDir.getCanonicalPath}' """) - sql(s""" + sql(s""" create external table normal_parquet ( intField INT, @@ -83,8 +71,15 @@ class ParquetMetastoreSuite extends QueryTest with BeforeAndAfterAll { location '${new File(partitionedTableDir, "p=1").getCanonicalPath}' """) - (1 to 10).foreach { p => - sql(s"ALTER TABLE partitioned_parquet ADD PARTITION (p=$p)") + (1 to 10).foreach { p => + sql(s"ALTER TABLE partitioned_parquet ADD PARTITION (p=$p)") + } + + setConf("spark.sql.hive.convertMetastoreParquet", "true") + } + + override def afterAll(): Unit = { + setConf("spark.sql.hive.convertMetastoreParquet", "false") } test("project the partitioning column") { From 2584ea5b23b1c5a4df9549b94bfc9b8e0900532e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 12 Sep 2014 21:55:39 -0700 Subject: [PATCH 13/36] [SPARK-3469] Make sure all TaskCompletionListener are called even with failures This is necessary because we rely on this callback interface to clean resources up. The old behavior would lead to resource leaks. Note that this also changes the fault semantics of TaskCompletionListener. Previously failures in TaskCompletionListeners would result in the task being reported immediately. With this change, we report the exception at the end, and the reported exception is a TaskCompletionListenerException that contains all the exception messages. Author: Reynold Xin Closes #2343 from rxin/taskcontext-callback and squashes the following commits: a3845b2 [Reynold Xin] Mark TaskCompletionListenerException as private[spark]. ac5baea [Reynold Xin] Removed obsolete comment. aa68ea4 [Reynold Xin] Throw an exception if task completion callback fails. 29b6162 [Reynold Xin] oops compilation failed. 1cb444d [Reynold Xin] [SPARK-3469] Call all TaskCompletionListeners even if some fail. --- .../scala/org/apache/spark/TaskContext.scala | 18 ++++++++-- .../TaskCompletionListenerException.scala | 34 +++++++++++++++++++ .../spark/scheduler/TaskContextSuite.scala | 22 ++++++++++-- 3 files changed, 69 insertions(+), 5 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/util/TaskCompletionListenerException.scala diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 2b99b8a5af250..51b3e4d5e0936 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -21,7 +21,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics -import org.apache.spark.util.TaskCompletionListener +import org.apache.spark.util.{TaskCompletionListenerException, TaskCompletionListener} /** @@ -41,7 +41,7 @@ class TaskContext( val attemptId: Long, val runningLocally: Boolean = false, private[spark] val taskMetrics: TaskMetrics = TaskMetrics.empty) - extends Serializable { + extends Serializable with Logging { @deprecated("use partitionId", "0.8.1") def splitId = partitionId @@ -103,8 +103,20 @@ class TaskContext( /** Marks the task as completed and triggers the listeners. */ private[spark] def markTaskCompleted(): Unit = { completed = true + val errorMsgs = new ArrayBuffer[String](2) // Process complete callbacks in the reverse order of registration - onCompleteCallbacks.reverse.foreach { _.onTaskCompletion(this) } + onCompleteCallbacks.reverse.foreach { listener => + try { + listener.onTaskCompletion(this) + } catch { + case e: Throwable => + errorMsgs += e.getMessage + logError("Error in TaskCompletionListener", e) + } + } + if (errorMsgs.nonEmpty) { + throw new TaskCompletionListenerException(errorMsgs) + } } /** Marks the task for interruption, i.e. cancellation. */ diff --git a/core/src/main/scala/org/apache/spark/util/TaskCompletionListenerException.scala b/core/src/main/scala/org/apache/spark/util/TaskCompletionListenerException.scala new file mode 100644 index 0000000000000..f64e069cd1724 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/TaskCompletionListenerException.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +/** + * Exception thrown when there is an exception in + * executing the callback in TaskCompletionListener. + */ +private[spark] +class TaskCompletionListenerException(errorMessages: Seq[String]) extends Exception { + + override def getMessage: String = { + if (errorMessages.size == 1) { + errorMessages.head + } else { + errorMessages.zipWithIndex.map { case (msg, i) => s"Exception $i: $msg" }.mkString("\n") + } + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index db2ad829a48f9..faba5508c906c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -17,16 +17,20 @@ package org.apache.spark.scheduler +import org.mockito.Mockito._ +import org.mockito.Matchers.any + import org.scalatest.FunSuite import org.scalatest.BeforeAndAfter import org.apache.spark._ import org.apache.spark.rdd.RDD -import org.apache.spark.util.Utils +import org.apache.spark.util.{TaskCompletionListenerException, TaskCompletionListener} + class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { - test("Calls executeOnCompleteCallbacks after failure") { + test("calls TaskCompletionListener after failure") { TaskContextSuite.completed = false sc = new SparkContext("local", "test") val rdd = new RDD[String](sc, List()) { @@ -45,6 +49,20 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte } assert(TaskContextSuite.completed === true) } + + test("all TaskCompletionListeners should be called even if some fail") { + val context = new TaskContext(0, 0, 0) + val listener = mock(classOf[TaskCompletionListener]) + context.addTaskCompletionListener(_ => throw new Exception("blah")) + context.addTaskCompletionListener(listener) + context.addTaskCompletionListener(_ => throw new Exception("blah")) + + intercept[TaskCompletionListenerException] { + context.markTaskCompleted() + } + + verify(listener, times(1)).onTaskCompletion(any()) + } } private object TaskContextSuite { From e11eeb71fa3a5fe7ddacb94d5b93b173d4d901a8 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Fri, 12 Sep 2014 21:58:02 -0700 Subject: [PATCH 14/36] [SQL][Docs] Update SQL programming guide to show the correct default value of containsNull in an ArrayType After #1889, the default value of `containsNull` in an `ArrayType` is `true`. Author: Yin Huai Closes #2374 from yhuai/containsNull and squashes the following commits: dc609a3 [Yin Huai] Update the SQL programming guide to show the correct default value of containsNull in an ArrayType (the default value is true instead of false). --- docs/sql-programming-guide.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index d83efa4bab324..3159d52787d5a 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1232,7 +1232,7 @@ import org.apache.spark.sql._ scala.collection.Seq ArrayType(elementType, [containsNull])
- Note: The default value of containsNull is false. + Note: The default value of containsNull is true. @@ -1358,7 +1358,7 @@ please use factory methods provided in java.util.List DataType.createArrayType(elementType)
- Note: The value of containsNull will be false
+ Note: The value of containsNull will be true
DataType.createArrayType(elementType, containsNull). @@ -1505,7 +1505,7 @@ from pyspark.sql import * list, tuple, or array ArrayType(elementType, [containsNull])
- Note: The default value of containsNull is False. + Note: The default value of containsNull is True. From feaa3706f17e44efcdac9f0a543a5b91232771ce Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Fri, 12 Sep 2014 22:50:37 -0700 Subject: [PATCH 15/36] SPARK-3470 [CORE] [STREAMING] Add Closeable / close() to Java context objects ... that expose a stop() lifecycle method. This doesn't add `AutoCloseable`, which is Java 7+ only. But it should be possible to use try-with-resources on a `Closeable` in Java 7, as long as the `close()` does not throw a checked exception, and these don't. Q.E.D. Author: Sean Owen Closes #2346 from srowen/SPARK-3470 and squashes the following commits: 612c21d [Sean Owen] Add Closeable / close() to Java context objects that expose a stop() lifecycle method --- .../scala/org/apache/spark/api/java/JavaSparkContext.scala | 7 ++++++- .../spark/streaming/api/java/JavaStreamingContext.scala | 7 +++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 8e178bc8480f7..23f7e6be81a90 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -17,6 +17,7 @@ package org.apache.spark.api.java +import java.io.Closeable import java.util import java.util.{Map => JMap} @@ -40,7 +41,9 @@ import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, NewHadoopRDD, RDD} * A Java-friendly version of [[org.apache.spark.SparkContext]] that returns * [[org.apache.spark.api.java.JavaRDD]]s and works with Java collections instead of Scala ones. */ -class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWorkaround { +class JavaSparkContext(val sc: SparkContext) + extends JavaSparkContextVarargsWorkaround with Closeable { + /** * Create a JavaSparkContext that loads settings from system properties (for instance, when * launching with ./bin/spark-submit). @@ -534,6 +537,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork sc.stop() } + override def close(): Unit = stop() + /** * Get Spark's home location from either a value set through the constructor, * or the spark.home Java property, or the SPARK_HOME environment variable diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 18605cac7006c..9dc26dc6b32a1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -21,7 +21,7 @@ package org.apache.spark.streaming.api.java import scala.collection.JavaConversions._ import scala.reflect.ClassTag -import java.io.InputStream +import java.io.{Closeable, InputStream} import java.util.{List => JList, Map => JMap} import akka.actor.{Props, SupervisorStrategy} @@ -49,7 +49,7 @@ import org.apache.spark.streaming.receiver.Receiver * respectively. `context.awaitTransformation()` allows the current thread to wait for the * termination of a context by `stop()` or by an exception. */ -class JavaStreamingContext(val ssc: StreamingContext) { +class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { /** * Create a StreamingContext. @@ -540,6 +540,9 @@ class JavaStreamingContext(val ssc: StreamingContext) { def stop(stopSparkContext: Boolean, stopGracefully: Boolean) = { ssc.stop(stopSparkContext, stopGracefully) } + + override def close(): Unit = stop() + } /** From b4dded40fbecb485f1ddfd8316b44d42a1554d64 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 12 Sep 2014 22:51:25 -0700 Subject: [PATCH 16/36] Proper indent for the previous commit. --- .../main/scala/org/apache/spark/api/java/JavaSparkContext.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 23f7e6be81a90..791d853a015a1 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -42,7 +42,7 @@ import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, NewHadoopRDD, RDD} * [[org.apache.spark.api.java.JavaRDD]]s and works with Java collections instead of Scala ones. */ class JavaSparkContext(val sc: SparkContext) - extends JavaSparkContextVarargsWorkaround with Closeable { + extends JavaSparkContextVarargsWorkaround with Closeable { /** * Create a JavaSparkContext that loads settings from system properties (for instance, when From a523ceaf159733dabcef84c7adc1463546679f65 Mon Sep 17 00:00:00 2001 From: Nicholas Chammas Date: Sat, 13 Sep 2014 12:34:20 -0700 Subject: [PATCH 17/36] [SQL] [Docs] typo fixes * Fixed random typo * Added in missing description for DecimalType Author: Nicholas Chammas Closes #2367 from nchammas/patch-1 and squashes the following commits: aa528be [Nicholas Chammas] doc fix for SQL DecimalType 3247ac1 [Nicholas Chammas] [SQL] [Docs] typo fixes --- docs/sql-programming-guide.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 3159d52787d5a..8d41fdec699e9 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -918,7 +918,6 @@ options. ## Migration Guide for Shark User ### Scheduling -s To set a [Fair Scheduler](job-scheduling.html#fair-scheduler-pools) pool for a JDBC client session, users can set the `spark.sql.thriftserver.scheduler.pool` variable: @@ -1110,7 +1109,7 @@ evaluated by the SQL execution engine. A full list of the functions supported c The range of numbers is from `-9223372036854775808` to `9223372036854775807`. - `FloatType`: Represents 4-byte single-precision floating point numbers. - `DoubleType`: Represents 8-byte double-precision floating point numbers. - - `DecimalType`: + - `DecimalType`: Represents arbitrary-precision signed decimal numbers. Backed internally by `java.math.BigDecimal`. A `BigDecimal` consists of an arbitrary precision integer unscaled value and a 32-bit integer scale. * String type - `StringType`: Represents character string values. * Binary type From 184cd51c4207c23726da97f907f2d912a5a44845 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sat, 13 Sep 2014 12:35:40 -0700 Subject: [PATCH 18/36] [SPARK-3481][SQL] Removes the evil MINOR HACK This is a follow up of #2352. Now we can finally remove the evil "MINOR HACK", which covered up the eldest bug in the history of Spark SQL (see details [here](https://github.com/apache/spark/pull/2352#issuecomment-55440621)). Author: Cheng Lian Closes #2377 from liancheng/remove-evil-minor-hack and squashes the following commits: 0869c78 [Cheng Lian] Removes the evil MINOR HACK --- .../spark/sql/hive/execution/HiveComparisonTest.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 671c3b162f875..79cc7a3fcc7d6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -250,9 +250,9 @@ abstract class HiveComparisonTest } try { - // MINOR HACK: You must run a query before calling reset the first time. - TestHive.sql("SHOW TABLES") - if (reset) { TestHive.reset() } + if (reset) { + TestHive.reset() + } val hiveCacheFiles = queryList.zipWithIndex.map { case (queryString, i) => From 74049249abb952ad061c0e221c22ff894a9e9c8d Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sat, 13 Sep 2014 15:08:30 -0700 Subject: [PATCH 19/36] [SPARK-3294][SQL] Eliminates boxing costs from in-memory columnar storage This is a major refactoring of the in-memory columnar storage implementation, aims to eliminate boxing costs from critical paths (building/accessing column buffers) as much as possible. The basic idea is to refactor all major interfaces into a row-based form and use them together with `SpecificMutableRow`. The difficult part is how to adapt all compression schemes, esp. `RunLengthEncoding` and `DictionaryEncoding`, to this design. Since in-memory compression is disabled by default for now, and this PR should be strictly better than before no matter in-memory compression is enabled or not, maybe I'll finish that part in another PR. **UPDATE** This PR also took the chance to optimize `HiveTableScan` by 1. leveraging `SpecificMutableRow` to avoid boxing cost, and 1. building specific `Writable` unwrapper functions a head of time to avoid per row pattern matching and branching costs. TODO - [x] Benchmark - [ ] ~~Eliminate boxing costs in `RunLengthEncoding`~~ (left to future PRs) - [ ] ~~Eliminate boxing costs in `DictionaryEncoding` (seems not easy to do without specializing `DictionaryEncoding` for every supported column type)~~ (left to future PRs) ## Micro benchmark The benchmark uses a 10 million line CSV table consists of bytes, shorts, integers, longs, floats and doubles, measures the time to build the in-memory version of this table, and the time to scan the whole in-memory table. Benchmark code can be found [here](https://gist.github.com/liancheng/fe70a148de82e77bd2c8#file-hivetablescanbenchmark-scala). Script used to generate the input table can be found [here](https://gist.github.com/liancheng/fe70a148de82e77bd2c8#file-tablegen-scala). Speedup: - Hive table scanning + column buffer building: **18.74%** The original benchmark uses 1K as in-memory batch size, when increased to 10K, it can be 28.32% faster. - In-memory table scanning: **7.95%** Before: | Building | Scanning ------- | -------- | -------- 1 | 16472 | 525 2 | 16168 | 530 3 | 16386 | 529 4 | 16184 | 538 5 | 16209 | 521 Average | 16283.8 | 528.6 After: | Building | Scanning ------- | -------- | -------- 1 | 13124 | 458 2 | 13260 | 529 3 | 12981 | 463 4 | 13214 | 483 5 | 13583 | 500 Average | 13232.4 | 486.6 Author: Cheng Lian Closes #2327 from liancheng/prevent-boxing/unboxing and squashes the following commits: 4419fe4 [Cheng Lian] Addressing comments e5d2cf2 [Cheng Lian] Bug fix: should call setNullAt when field value is null to avoid NPE 8b8552b [Cheng Lian] Only checks for partition batch pruning flag once 489f97b [Cheng Lian] Bug fix: TableReader.fillObject uses wrong ordinals 97bbc4e [Cheng Lian] Optimizes hive.TableReader by by providing specific Writable unwrappers a head of time 3dc1f94 [Cheng Lian] Minor changes to eliminate row object creation 5b39cb9 [Cheng Lian] Lowers log level of compression scheme details f2a7890 [Cheng Lian] Use SpecificMutableRow in InMemoryColumnarTableScan to avoid boxing 9cf30b0 [Cheng Lian] Added row based ColumnType.append/extract 456c366 [Cheng Lian] Made compression decoder row based edac3cd [Cheng Lian] Makes ColumnAccessor.extractSingle row based 8216936 [Cheng Lian] Removes boxing cost in IntDelta and LongDelta by providing specialized implementations b70d519 [Cheng Lian] Made some in-memory columnar storage interfaces row-based --- .../catalyst/expressions/SpecificRow.scala | 2 +- .../spark/sql/columnar/ColumnAccessor.scala | 8 +- .../spark/sql/columnar/ColumnBuilder.scala | 27 +- .../spark/sql/columnar/ColumnStats.scala | 16 +- .../spark/sql/columnar/ColumnType.scala | 178 ++++++++++-- .../columnar/InMemoryColumnarTableScan.scala | 92 +++--- .../sql/columnar/NullableColumnAccessor.scala | 4 +- .../sql/columnar/NullableColumnBuilder.scala | 8 +- .../CompressibleColumnAccessor.scala | 7 +- .../CompressibleColumnBuilder.scala | 24 +- .../compression/CompressionScheme.scala | 16 +- .../compression/compressionSchemes.scala | 264 +++++++++++------- .../spark/sql/columnar/ColumnStatsSuite.scala | 2 +- .../spark/sql/columnar/ColumnTypeSuite.scala | 11 +- .../columnar/InMemoryColumnarQuerySuite.scala | 2 +- .../NullableColumnAccessorSuite.scala | 4 +- .../columnar/NullableColumnBuilderSuite.scala | 4 +- .../columnar/PartitionBatchPruningSuite.scala | 6 +- .../compression/BooleanBitSetSuite.scala | 7 +- .../compression/DictionaryEncodingSuite.scala | 9 +- .../compression/IntegralDeltaSuite.scala | 9 +- .../compression/RunLengthEncodingSuite.scala | 9 +- .../apache/spark/sql/hive/TableReader.scala | 119 ++++---- .../sql/hive/execution/HiveQuerySuite.scala | 18 +- 24 files changed, 554 insertions(+), 292 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala index 088f11ee4aa53..9cbab3d5d0d0d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala @@ -171,7 +171,7 @@ final class MutableByte extends MutableValue { } final class MutableAny extends MutableValue { - var value: Any = 0 + var value: Any = _ def boxed = if (isNull) null else value def update(v: Any) = value = { isNull = false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala index 42a5a9a84f362..c9faf0852142a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala @@ -50,11 +50,13 @@ private[sql] abstract class BasicColumnAccessor[T <: DataType, JvmType]( def hasNext = buffer.hasRemaining - def extractTo(row: MutableRow, ordinal: Int) { - columnType.setField(row, ordinal, extractSingle(buffer)) + def extractTo(row: MutableRow, ordinal: Int): Unit = { + extractSingle(row, ordinal) } - def extractSingle(buffer: ByteBuffer): JvmType = columnType.extract(buffer) + def extractSingle(row: MutableRow, ordinal: Int): Unit = { + columnType.extract(buffer, row, ordinal) + } protected def underlyingBuffer = buffer } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala index b3ec5ded22422..2e61a981375aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala @@ -68,10 +68,9 @@ private[sql] class BasicColumnBuilder[T <: DataType, JvmType]( buffer.order(ByteOrder.nativeOrder()).putInt(columnType.typeId) } - override def appendFrom(row: Row, ordinal: Int) { - val field = columnType.getField(row, ordinal) - buffer = ensureFreeSpace(buffer, columnType.actualSize(field)) - columnType.append(field, buffer) + override def appendFrom(row: Row, ordinal: Int): Unit = { + buffer = ensureFreeSpace(buffer, columnType.actualSize(row, ordinal)) + columnType.append(row, ordinal, buffer) } override def build() = { @@ -142,16 +141,16 @@ private[sql] object ColumnBuilder { useCompression: Boolean = false): ColumnBuilder = { val builder = (typeId match { - case INT.typeId => new IntColumnBuilder - case LONG.typeId => new LongColumnBuilder - case FLOAT.typeId => new FloatColumnBuilder - case DOUBLE.typeId => new DoubleColumnBuilder - case BOOLEAN.typeId => new BooleanColumnBuilder - case BYTE.typeId => new ByteColumnBuilder - case SHORT.typeId => new ShortColumnBuilder - case STRING.typeId => new StringColumnBuilder - case BINARY.typeId => new BinaryColumnBuilder - case GENERIC.typeId => new GenericColumnBuilder + case INT.typeId => new IntColumnBuilder + case LONG.typeId => new LongColumnBuilder + case FLOAT.typeId => new FloatColumnBuilder + case DOUBLE.typeId => new DoubleColumnBuilder + case BOOLEAN.typeId => new BooleanColumnBuilder + case BYTE.typeId => new ByteColumnBuilder + case SHORT.typeId => new ShortColumnBuilder + case STRING.typeId => new StringColumnBuilder + case BINARY.typeId => new BinaryColumnBuilder + case GENERIC.typeId => new GenericColumnBuilder case TIMESTAMP.typeId => new TimestampColumnBuilder }).asInstanceOf[ColumnBuilder] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index fc343ccb995c2..203a714e03c97 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -69,7 +69,7 @@ private[sql] class ByteColumnStats extends ColumnStats { var lower = Byte.MaxValue var nullCount = 0 - override def gatherStats(row: Row, ordinal: Int) { + override def gatherStats(row: Row, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { val value = row.getByte(ordinal) if (value > upper) upper = value @@ -87,7 +87,7 @@ private[sql] class ShortColumnStats extends ColumnStats { var lower = Short.MaxValue var nullCount = 0 - override def gatherStats(row: Row, ordinal: Int) { + override def gatherStats(row: Row, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { val value = row.getShort(ordinal) if (value > upper) upper = value @@ -105,7 +105,7 @@ private[sql] class LongColumnStats extends ColumnStats { var lower = Long.MaxValue var nullCount = 0 - override def gatherStats(row: Row, ordinal: Int) { + override def gatherStats(row: Row, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { val value = row.getLong(ordinal) if (value > upper) upper = value @@ -123,7 +123,7 @@ private[sql] class DoubleColumnStats extends ColumnStats { var lower = Double.MaxValue var nullCount = 0 - override def gatherStats(row: Row, ordinal: Int) { + override def gatherStats(row: Row, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { val value = row.getDouble(ordinal) if (value > upper) upper = value @@ -141,7 +141,7 @@ private[sql] class FloatColumnStats extends ColumnStats { var lower = Float.MaxValue var nullCount = 0 - override def gatherStats(row: Row, ordinal: Int) { + override def gatherStats(row: Row, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { val value = row.getFloat(ordinal) if (value > upper) upper = value @@ -159,7 +159,7 @@ private[sql] class IntColumnStats extends ColumnStats { var lower = Int.MaxValue var nullCount = 0 - override def gatherStats(row: Row, ordinal: Int) { + override def gatherStats(row: Row, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { val value = row.getInt(ordinal) if (value > upper) upper = value @@ -177,7 +177,7 @@ private[sql] class StringColumnStats extends ColumnStats { var lower: String = null var nullCount = 0 - override def gatherStats(row: Row, ordinal: Int) { + override def gatherStats(row: Row, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { val value = row.getString(ordinal) if (upper == null || value.compareTo(upper) > 0) upper = value @@ -195,7 +195,7 @@ private[sql] class TimestampColumnStats extends ColumnStats { var lower: Timestamp = null var nullCount = 0 - override def gatherStats(row: Row, ordinal: Int) { + override def gatherStats(row: Row, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { val value = row(ordinal).asInstanceOf[Timestamp] if (upper == null || value.compareTo(upper) > 0) upper = value diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index 9a61600115872..198b5756676aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -18,11 +18,10 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer +import java.sql.Timestamp import scala.reflect.runtime.universe.TypeTag -import java.sql.Timestamp - import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.MutableRow import org.apache.spark.sql.catalyst.types._ @@ -46,16 +45,33 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType]( */ def extract(buffer: ByteBuffer): JvmType + /** + * Extracts a value out of the buffer at the buffer's current position and stores in + * `row(ordinal)`. Subclasses should override this method to avoid boxing/unboxing costs whenever + * possible. + */ + def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + setField(row, ordinal, extract(buffer)) + } + /** * Appends the given value v of type T into the given ByteBuffer. */ - def append(v: JvmType, buffer: ByteBuffer) + def append(v: JvmType, buffer: ByteBuffer): Unit + + /** + * Appends `row(ordinal)` of type T into the given ByteBuffer. Subclasses should override this + * method to avoid boxing/unboxing costs whenever possible. + */ + def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + append(getField(row, ordinal), buffer) + } /** - * Returns the size of the value. This is used to calculate the size of variable length types - * such as byte arrays and strings. + * Returns the size of the value `row(ordinal)`. This is used to calculate the size of variable + * length types such as byte arrays and strings. */ - def actualSize(v: JvmType): Int = defaultSize + def actualSize(row: Row, ordinal: Int): Int = defaultSize /** * Returns `row(ordinal)`. Subclasses should override this method to avoid boxing/unboxing costs @@ -67,7 +83,15 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType]( * Sets `row(ordinal)` to `field`. Subclasses should override this method to avoid boxing/unboxing * costs whenever possible. */ - def setField(row: MutableRow, ordinal: Int, value: JvmType) + def setField(row: MutableRow, ordinal: Int, value: JvmType): Unit + + /** + * Copies `from(fromOrdinal)` to `to(toOrdinal)`. Subclasses should override this method to avoid + * boxing/unboxing costs whenever possible. + */ + def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + to(toOrdinal) = from(fromOrdinal) + } /** * Creates a duplicated copy of the value. @@ -90,119 +114,205 @@ private[sql] abstract class NativeColumnType[T <: NativeType]( } private[sql] object INT extends NativeColumnType(IntegerType, 0, 4) { - def append(v: Int, buffer: ByteBuffer) { + def append(v: Int, buffer: ByteBuffer): Unit = { buffer.putInt(v) } + override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + buffer.putInt(row.getInt(ordinal)) + } + def extract(buffer: ByteBuffer) = { buffer.getInt() } - override def setField(row: MutableRow, ordinal: Int, value: Int) { + override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + row.setInt(ordinal, buffer.getInt()) + } + + override def setField(row: MutableRow, ordinal: Int, value: Int): Unit = { row.setInt(ordinal, value) } override def getField(row: Row, ordinal: Int) = row.getInt(ordinal) + + override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + to.setInt(toOrdinal, from.getInt(fromOrdinal)) + } } private[sql] object LONG extends NativeColumnType(LongType, 1, 8) { - override def append(v: Long, buffer: ByteBuffer) { + override def append(v: Long, buffer: ByteBuffer): Unit = { buffer.putLong(v) } + override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + buffer.putLong(row.getLong(ordinal)) + } + override def extract(buffer: ByteBuffer) = { buffer.getLong() } - override def setField(row: MutableRow, ordinal: Int, value: Long) { + override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + row.setLong(ordinal, buffer.getLong()) + } + + override def setField(row: MutableRow, ordinal: Int, value: Long): Unit = { row.setLong(ordinal, value) } override def getField(row: Row, ordinal: Int) = row.getLong(ordinal) + + override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + to.setLong(toOrdinal, from.getLong(fromOrdinal)) + } } private[sql] object FLOAT extends NativeColumnType(FloatType, 2, 4) { - override def append(v: Float, buffer: ByteBuffer) { + override def append(v: Float, buffer: ByteBuffer): Unit = { buffer.putFloat(v) } + override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + buffer.putFloat(row.getFloat(ordinal)) + } + override def extract(buffer: ByteBuffer) = { buffer.getFloat() } - override def setField(row: MutableRow, ordinal: Int, value: Float) { + override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + row.setFloat(ordinal, buffer.getFloat()) + } + + override def setField(row: MutableRow, ordinal: Int, value: Float): Unit = { row.setFloat(ordinal, value) } override def getField(row: Row, ordinal: Int) = row.getFloat(ordinal) + + override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + to.setFloat(toOrdinal, from.getFloat(fromOrdinal)) + } } private[sql] object DOUBLE extends NativeColumnType(DoubleType, 3, 8) { - override def append(v: Double, buffer: ByteBuffer) { + override def append(v: Double, buffer: ByteBuffer): Unit = { buffer.putDouble(v) } + override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + buffer.putDouble(row.getDouble(ordinal)) + } + override def extract(buffer: ByteBuffer) = { buffer.getDouble() } - override def setField(row: MutableRow, ordinal: Int, value: Double) { + override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + row.setDouble(ordinal, buffer.getDouble()) + } + + override def setField(row: MutableRow, ordinal: Int, value: Double): Unit = { row.setDouble(ordinal, value) } override def getField(row: Row, ordinal: Int) = row.getDouble(ordinal) + + override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + to.setDouble(toOrdinal, from.getDouble(fromOrdinal)) + } } private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 4, 1) { - override def append(v: Boolean, buffer: ByteBuffer) { - buffer.put(if (v) 1.toByte else 0.toByte) + override def append(v: Boolean, buffer: ByteBuffer): Unit = { + buffer.put(if (v) 1: Byte else 0: Byte) + } + + override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + buffer.put(if (row.getBoolean(ordinal)) 1: Byte else 0: Byte) } override def extract(buffer: ByteBuffer) = buffer.get() == 1 - override def setField(row: MutableRow, ordinal: Int, value: Boolean) { + override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + row.setBoolean(ordinal, buffer.get() == 1) + } + + override def setField(row: MutableRow, ordinal: Int, value: Boolean): Unit = { row.setBoolean(ordinal, value) } override def getField(row: Row, ordinal: Int) = row.getBoolean(ordinal) + + override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + to.setBoolean(toOrdinal, from.getBoolean(fromOrdinal)) + } } private[sql] object BYTE extends NativeColumnType(ByteType, 5, 1) { - override def append(v: Byte, buffer: ByteBuffer) { + override def append(v: Byte, buffer: ByteBuffer): Unit = { buffer.put(v) } + override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + buffer.put(row.getByte(ordinal)) + } + override def extract(buffer: ByteBuffer) = { buffer.get() } - override def setField(row: MutableRow, ordinal: Int, value: Byte) { + override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + row.setByte(ordinal, buffer.get()) + } + + override def setField(row: MutableRow, ordinal: Int, value: Byte): Unit = { row.setByte(ordinal, value) } override def getField(row: Row, ordinal: Int) = row.getByte(ordinal) + + override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + to.setByte(toOrdinal, from.getByte(fromOrdinal)) + } } private[sql] object SHORT extends NativeColumnType(ShortType, 6, 2) { - override def append(v: Short, buffer: ByteBuffer) { + override def append(v: Short, buffer: ByteBuffer): Unit = { buffer.putShort(v) } + override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + buffer.putShort(row.getShort(ordinal)) + } + override def extract(buffer: ByteBuffer) = { buffer.getShort() } - override def setField(row: MutableRow, ordinal: Int, value: Short) { + override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + row.setShort(ordinal, buffer.getShort()) + } + + override def setField(row: MutableRow, ordinal: Int, value: Short): Unit = { row.setShort(ordinal, value) } override def getField(row: Row, ordinal: Int) = row.getShort(ordinal) + + override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + to.setShort(toOrdinal, from.getShort(fromOrdinal)) + } } private[sql] object STRING extends NativeColumnType(StringType, 7, 8) { - override def actualSize(v: String): Int = v.getBytes("utf-8").length + 4 + override def actualSize(row: Row, ordinal: Int): Int = { + row.getString(ordinal).getBytes("utf-8").length + 4 + } - override def append(v: String, buffer: ByteBuffer) { + override def append(v: String, buffer: ByteBuffer): Unit = { val stringBytes = v.getBytes("utf-8") buffer.putInt(stringBytes.length).put(stringBytes, 0, stringBytes.length) } @@ -214,11 +324,15 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) { new String(stringBytes, "utf-8") } - override def setField(row: MutableRow, ordinal: Int, value: String) { + override def setField(row: MutableRow, ordinal: Int, value: String): Unit = { row.setString(ordinal, value) } override def getField(row: Row, ordinal: Int) = row.getString(ordinal) + + override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + to.setString(toOrdinal, from.getString(fromOrdinal)) + } } private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 8, 12) { @@ -228,7 +342,7 @@ private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 8, 12) { timestamp } - override def append(v: Timestamp, buffer: ByteBuffer) { + override def append(v: Timestamp, buffer: ByteBuffer): Unit = { buffer.putLong(v.getTime).putInt(v.getNanos) } @@ -236,7 +350,7 @@ private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 8, 12) { row(ordinal).asInstanceOf[Timestamp] } - override def setField(row: MutableRow, ordinal: Int, value: Timestamp) { + override def setField(row: MutableRow, ordinal: Int, value: Timestamp): Unit = { row(ordinal) = value } } @@ -246,9 +360,11 @@ private[sql] sealed abstract class ByteArrayColumnType[T <: DataType]( defaultSize: Int) extends ColumnType[T, Array[Byte]](typeId, defaultSize) { - override def actualSize(v: Array[Byte]) = v.length + 4 + override def actualSize(row: Row, ordinal: Int) = { + getField(row, ordinal).length + 4 + } - override def append(v: Array[Byte], buffer: ByteBuffer) { + override def append(v: Array[Byte], buffer: ByteBuffer): Unit = { buffer.putInt(v.length).put(v, 0, v.length) } @@ -261,7 +377,7 @@ private[sql] sealed abstract class ByteArrayColumnType[T <: DataType]( } private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](9, 16) { - override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]) { + override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = { row(ordinal) = value } @@ -272,7 +388,7 @@ private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](9, 16) { // serialized first before appending to the column `ByteBuffer`, and is also extracted as serialized // byte array. private[sql] object GENERIC extends ByteArrayColumnType[DataType](10, 16) { - override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]) { + override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = { row(ordinal) = SparkSqlSerializer.deserialize[Any](value) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 6eab2f23c18e1..8a3612cdf19be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -52,7 +52,7 @@ private[sql] case class InMemoryRelation( // As in Spark, the actual work of caching is lazy. if (_cachedColumnBuffers == null) { val output = child.output - val cached = child.execute().mapPartitions { baseIterator => + val cached = child.execute().mapPartitions { rowIterator => new Iterator[CachedBatch] { def next() = { val columnBuilders = output.map { attribute => @@ -61,11 +61,9 @@ private[sql] case class InMemoryRelation( ColumnBuilder(columnType.typeId, initialBufferSize, attribute.name, useCompression) }.toArray - var row: Row = null var rowCount = 0 - - while (baseIterator.hasNext && rowCount < batchSize) { - row = baseIterator.next() + while (rowIterator.hasNext && rowCount < batchSize) { + val row = rowIterator.next() var i = 0 while (i < row.length) { columnBuilders(i).appendFrom(row, i) @@ -80,7 +78,7 @@ private[sql] case class InMemoryRelation( CachedBatch(columnBuilders.map(_.build()), stats) } - def hasNext = baseIterator.hasNext + def hasNext = rowIterator.hasNext } }.cache() @@ -182,6 +180,7 @@ private[sql] case class InMemoryColumnarTableScan( } } + // Accumulators used for testing purposes val readPartitions = sparkContext.accumulator(0) val readBatches = sparkContext.accumulator(0) @@ -191,40 +190,36 @@ private[sql] case class InMemoryColumnarTableScan( readPartitions.setValue(0) readBatches.setValue(0) - relation.cachedColumnBuffers.mapPartitions { iterator => + relation.cachedColumnBuffers.mapPartitions { cachedBatchIterator => val partitionFilter = newPredicate( partitionFilters.reduceOption(And).getOrElse(Literal(true)), relation.partitionStatistics.schema) - // Find the ordinals of the requested columns. If none are requested, use the first. - val requestedColumns = if (attributes.isEmpty) { - Seq(0) + // Find the ordinals and data types of the requested columns. If none are requested, use the + // narrowest (the field with minimum default element size). + val (requestedColumnIndices, requestedColumnDataTypes) = if (attributes.isEmpty) { + val (narrowestOrdinal, narrowestDataType) = + relation.output.zipWithIndex.map { case (a, ordinal) => + ordinal -> a.dataType + } minBy { case (_, dataType) => + ColumnType(dataType).defaultSize + } + Seq(narrowestOrdinal) -> Seq(narrowestDataType) } else { - attributes.map(a => relation.output.indexWhere(_.exprId == a.exprId)) + attributes.map { a => + relation.output.indexWhere(_.exprId == a.exprId) -> a.dataType + }.unzip } - val rows = iterator - // Skip pruned batches - .filter { cachedBatch => - if (inMemoryPartitionPruningEnabled && !partitionFilter(cachedBatch.stats)) { - def statsString = relation.partitionStatistics.schema - .zip(cachedBatch.stats) - .map { case (a, s) => s"${a.name}: $s" } - .mkString(", ") - logInfo(s"Skipping partition based on stats $statsString") - false - } else { - readBatches += 1 - true - } - } - // Build column accessors - .map { cachedBatch => - requestedColumns.map(cachedBatch.buffers(_)).map(ColumnAccessor(_)) - } - // Extract rows via column accessors - .flatMap { columnAccessors => - val nextRow = new GenericMutableRow(columnAccessors.length) + val nextRow = new SpecificMutableRow(requestedColumnDataTypes) + + def cachedBatchesToRows(cacheBatches: Iterator[CachedBatch]) = { + val rows = cacheBatches.flatMap { cachedBatch => + // Build column accessors + val columnAccessors = + requestedColumnIndices.map(cachedBatch.buffers(_)).map(ColumnAccessor(_)) + + // Extract rows via column accessors new Iterator[Row] { override def next() = { var i = 0 @@ -235,15 +230,38 @@ private[sql] case class InMemoryColumnarTableScan( nextRow } - override def hasNext = columnAccessors.head.hasNext + override def hasNext = columnAccessors(0).hasNext } } - if (rows.hasNext) { - readPartitions += 1 + if (rows.hasNext) { + readPartitions += 1 + } + + rows } - rows + // Do partition batch pruning if enabled + val cachedBatchesToScan = + if (inMemoryPartitionPruningEnabled) { + cachedBatchIterator.filter { cachedBatch => + if (!partitionFilter(cachedBatch.stats)) { + def statsString = relation.partitionStatistics.schema + .zip(cachedBatch.stats) + .map { case (a, s) => s"${a.name}: $s" } + .mkString(", ") + logInfo(s"Skipping partition based on stats $statsString") + false + } else { + readBatches += 1 + true + } + } + } else { + cachedBatchIterator + } + + cachedBatchesToRows(cachedBatchesToScan) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala index b7f8826861a2c..965782a40031b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala @@ -29,7 +29,7 @@ private[sql] trait NullableColumnAccessor extends ColumnAccessor { private var nextNullIndex: Int = _ private var pos: Int = 0 - abstract override protected def initialize() { + abstract override protected def initialize(): Unit = { nullsBuffer = underlyingBuffer.duplicate().order(ByteOrder.nativeOrder()) nullCount = nullsBuffer.getInt() nextNullIndex = if (nullCount > 0) nullsBuffer.getInt() else -1 @@ -39,7 +39,7 @@ private[sql] trait NullableColumnAccessor extends ColumnAccessor { super.initialize() } - abstract override def extractTo(row: MutableRow, ordinal: Int) { + abstract override def extractTo(row: MutableRow, ordinal: Int): Unit = { if (pos == nextNullIndex) { seenNulls += 1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala index a72970eef7aa4..f1f494ac26d0c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala @@ -40,7 +40,11 @@ private[sql] trait NullableColumnBuilder extends ColumnBuilder { protected var nullCount: Int = _ private var pos: Int = _ - abstract override def initialize(initialSize: Int, columnName: String, useCompression: Boolean) { + abstract override def initialize( + initialSize: Int, + columnName: String, + useCompression: Boolean): Unit = { + nulls = ByteBuffer.allocate(1024) nulls.order(ByteOrder.nativeOrder()) pos = 0 @@ -48,7 +52,7 @@ private[sql] trait NullableColumnBuilder extends ColumnBuilder { super.initialize(initialSize, columnName, useCompression) } - abstract override def appendFrom(row: Row, ordinal: Int) { + abstract override def appendFrom(row: Row, ordinal: Int): Unit = { columnStats.gatherStats(row, ordinal) if (row.isNullAt(ordinal)) { nulls = ColumnBuilder.ensureFreeSpace(nulls, 4) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala index b4120a3d4368b..27ac5f4dbdbbc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.columnar.compression -import java.nio.ByteBuffer - +import org.apache.spark.sql.catalyst.expressions.MutableRow import org.apache.spark.sql.catalyst.types.NativeType import org.apache.spark.sql.columnar.{ColumnAccessor, NativeColumnAccessor} @@ -34,5 +33,7 @@ private[sql] trait CompressibleColumnAccessor[T <: NativeType] extends ColumnAcc abstract override def hasNext = super.hasNext || decoder.hasNext - override def extractSingle(buffer: ByteBuffer): T#JvmType = decoder.next() + override def extractSingle(row: MutableRow, ordinal: Int): Unit = { + decoder.next(row, ordinal) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala index a5826bb033e41..628d9cec41d6b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala @@ -48,12 +48,16 @@ private[sql] trait CompressibleColumnBuilder[T <: NativeType] var compressionEncoders: Seq[Encoder[T]] = _ - abstract override def initialize(initialSize: Int, columnName: String, useCompression: Boolean) { + abstract override def initialize( + initialSize: Int, + columnName: String, + useCompression: Boolean): Unit = { + compressionEncoders = if (useCompression) { - schemes.filter(_.supports(columnType)).map(_.encoder[T]) + schemes.filter(_.supports(columnType)).map(_.encoder[T](columnType)) } else { - Seq(PassThrough.encoder) + Seq(PassThrough.encoder(columnType)) } super.initialize(initialSize, columnName, useCompression) } @@ -62,17 +66,15 @@ private[sql] trait CompressibleColumnBuilder[T <: NativeType] encoder.compressionRatio < 0.8 } - private def gatherCompressibilityStats(row: Row, ordinal: Int) { - val field = columnType.getField(row, ordinal) - + private def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = { var i = 0 while (i < compressionEncoders.length) { - compressionEncoders(i).gatherCompressibilityStats(field, columnType) + compressionEncoders(i).gatherCompressibilityStats(row, ordinal) i += 1 } } - abstract override def appendFrom(row: Row, ordinal: Int) { + abstract override def appendFrom(row: Row, ordinal: Int): Unit = { super.appendFrom(row, ordinal) if (!row.isNullAt(ordinal)) { gatherCompressibilityStats(row, ordinal) @@ -84,7 +86,7 @@ private[sql] trait CompressibleColumnBuilder[T <: NativeType] val typeId = nonNullBuffer.getInt() val encoder: Encoder[T] = { val candidate = compressionEncoders.minBy(_.compressionRatio) - if (isWorthCompressing(candidate)) candidate else PassThrough.encoder + if (isWorthCompressing(candidate)) candidate else PassThrough.encoder(columnType) } // Header = column type ID + null count + null positions @@ -104,7 +106,7 @@ private[sql] trait CompressibleColumnBuilder[T <: NativeType] .putInt(nullCount) .put(nulls) - logInfo(s"Compressor for [$columnName]: $encoder, ratio: ${encoder.compressionRatio}") - encoder.compress(nonNullBuffer, compressedBuffer, columnType) + logDebug(s"Compressor for [$columnName]: $encoder, ratio: ${encoder.compressionRatio}") + encoder.compress(nonNullBuffer, compressedBuffer) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala index 7797f75177893..acb06cb5376b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala @@ -17,13 +17,15 @@ package org.apache.spark.sql.columnar.compression -import java.nio.{ByteOrder, ByteBuffer} +import java.nio.{ByteBuffer, ByteOrder} +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.MutableRow import org.apache.spark.sql.catalyst.types.NativeType import org.apache.spark.sql.columnar.{ColumnType, NativeColumnType} private[sql] trait Encoder[T <: NativeType] { - def gatherCompressibilityStats(value: T#JvmType, columnType: NativeColumnType[T]) {} + def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = {} def compressedSize: Int @@ -33,17 +35,21 @@ private[sql] trait Encoder[T <: NativeType] { if (uncompressedSize > 0) compressedSize.toDouble / uncompressedSize else 1.0 } - def compress(from: ByteBuffer, to: ByteBuffer, columnType: NativeColumnType[T]): ByteBuffer + def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer } -private[sql] trait Decoder[T <: NativeType] extends Iterator[T#JvmType] +private[sql] trait Decoder[T <: NativeType] { + def next(row: MutableRow, ordinal: Int): Unit + + def hasNext: Boolean +} private[sql] trait CompressionScheme { def typeId: Int def supports(columnType: ColumnType[_, _]): Boolean - def encoder[T <: NativeType]: Encoder[T] + def encoder[T <: NativeType](columnType: NativeColumnType[T]): Encoder[T] def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]): Decoder[T] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala index 8cf9ec74ca2de..29edcf17242c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala @@ -23,7 +23,8 @@ import scala.collection.mutable import scala.reflect.ClassTag import scala.reflect.runtime.universe.runtimeMirror -import org.apache.spark.sql.catalyst.expressions.GenericMutableRow +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow} import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.columnar._ import org.apache.spark.util.Utils @@ -33,18 +34,20 @@ private[sql] case object PassThrough extends CompressionScheme { override def supports(columnType: ColumnType[_, _]) = true - override def encoder[T <: NativeType] = new this.Encoder[T] + override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = { + new this.Encoder[T](columnType) + } override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = { new this.Decoder(buffer, columnType) } - class Encoder[T <: NativeType] extends compression.Encoder[T] { + class Encoder[T <: NativeType](columnType: NativeColumnType[T]) extends compression.Encoder[T] { override def uncompressedSize = 0 override def compressedSize = 0 - override def compress(from: ByteBuffer, to: ByteBuffer, columnType: NativeColumnType[T]) = { + override def compress(from: ByteBuffer, to: ByteBuffer) = { // Writes compression type ID and copies raw contents to.putInt(PassThrough.typeId).put(from).rewind() to @@ -54,7 +57,9 @@ private[sql] case object PassThrough extends CompressionScheme { class Decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) extends compression.Decoder[T] { - override def next() = columnType.extract(buffer) + override def next(row: MutableRow, ordinal: Int): Unit = { + columnType.extract(buffer, row, ordinal) + } override def hasNext = buffer.hasRemaining } @@ -63,7 +68,9 @@ private[sql] case object PassThrough extends CompressionScheme { private[sql] case object RunLengthEncoding extends CompressionScheme { override val typeId = 1 - override def encoder[T <: NativeType] = new this.Encoder[T] + override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = { + new this.Encoder[T](columnType) + } override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = { new this.Decoder(buffer, columnType) @@ -74,24 +81,25 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { case _ => false } - class Encoder[T <: NativeType] extends compression.Encoder[T] { + class Encoder[T <: NativeType](columnType: NativeColumnType[T]) extends compression.Encoder[T] { private var _uncompressedSize = 0 private var _compressedSize = 0 // Using `MutableRow` to store the last value to avoid boxing/unboxing cost. - private val lastValue = new GenericMutableRow(1) + private val lastValue = new SpecificMutableRow(Seq(columnType.dataType)) private var lastRun = 0 override def uncompressedSize = _uncompressedSize override def compressedSize = _compressedSize - override def gatherCompressibilityStats(value: T#JvmType, columnType: NativeColumnType[T]) { - val actualSize = columnType.actualSize(value) + override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = { + val value = columnType.getField(row, ordinal) + val actualSize = columnType.actualSize(row, ordinal) _uncompressedSize += actualSize if (lastValue.isNullAt(0)) { - columnType.setField(lastValue, 0, value) + columnType.copyField(row, ordinal, lastValue, 0) lastRun = 1 _compressedSize += actualSize + 4 } else { @@ -99,37 +107,40 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { lastRun += 1 } else { _compressedSize += actualSize + 4 - columnType.setField(lastValue, 0, value) + columnType.copyField(row, ordinal, lastValue, 0) lastRun = 1 } } } - override def compress(from: ByteBuffer, to: ByteBuffer, columnType: NativeColumnType[T]) = { + override def compress(from: ByteBuffer, to: ByteBuffer) = { to.putInt(RunLengthEncoding.typeId) if (from.hasRemaining) { - var currentValue = columnType.extract(from) + val currentValue = new SpecificMutableRow(Seq(columnType.dataType)) var currentRun = 1 + val value = new SpecificMutableRow(Seq(columnType.dataType)) + + columnType.extract(from, currentValue, 0) while (from.hasRemaining) { - val value = columnType.extract(from) + columnType.extract(from, value, 0) - if (value == currentValue) { + if (value.head == currentValue.head) { currentRun += 1 } else { // Writes current run - columnType.append(currentValue, to) + columnType.append(currentValue, 0, to) to.putInt(currentRun) // Resets current run - currentValue = value + columnType.copyField(value, 0, currentValue, 0) currentRun = 1 } } // Writes the last run - columnType.append(currentValue, to) + columnType.append(currentValue, 0, to) to.putInt(currentRun) } @@ -145,7 +156,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { private var valueCount = 0 private var currentValue: T#JvmType = _ - override def next() = { + override def next(row: MutableRow, ordinal: Int): Unit = { if (valueCount == run) { currentValue = columnType.extract(buffer) run = buffer.getInt() @@ -154,7 +165,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { valueCount += 1 } - currentValue + columnType.setField(row, ordinal, currentValue) } override def hasNext = valueCount < run || buffer.hasRemaining @@ -171,14 +182,16 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { new this.Decoder(buffer, columnType) } - override def encoder[T <: NativeType] = new this.Encoder[T] + override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = { + new this.Encoder[T](columnType) + } override def supports(columnType: ColumnType[_, _]) = columnType match { case INT | LONG | STRING => true case _ => false } - class Encoder[T <: NativeType] extends compression.Encoder[T] { + class Encoder[T <: NativeType](columnType: NativeColumnType[T]) extends compression.Encoder[T] { // Size of the input, uncompressed, in bytes. Note that we only count until the dictionary // overflows. private var _uncompressedSize = 0 @@ -200,9 +213,11 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { // to store dictionary element count. private var dictionarySize = 4 - override def gatherCompressibilityStats(value: T#JvmType, columnType: NativeColumnType[T]) { + override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = { + val value = columnType.getField(row, ordinal) + if (!overflow) { - val actualSize = columnType.actualSize(value) + val actualSize = columnType.actualSize(row, ordinal) count += 1 _uncompressedSize += actualSize @@ -221,7 +236,7 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { } } - override def compress(from: ByteBuffer, to: ByteBuffer, columnType: NativeColumnType[T]) = { + override def compress(from: ByteBuffer, to: ByteBuffer) = { if (overflow) { throw new IllegalStateException( "Dictionary encoding should not be used because of dictionary overflow.") @@ -264,7 +279,9 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { } } - override def next() = dictionary(buffer.getShort()) + override def next(row: MutableRow, ordinal: Int): Unit = { + columnType.setField(row, ordinal, dictionary(buffer.getShort())) + } override def hasNext = buffer.hasRemaining } @@ -279,25 +296,20 @@ private[sql] case object BooleanBitSet extends CompressionScheme { new this.Decoder(buffer).asInstanceOf[compression.Decoder[T]] } - override def encoder[T <: NativeType] = (new this.Encoder).asInstanceOf[compression.Encoder[T]] + override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = { + (new this.Encoder).asInstanceOf[compression.Encoder[T]] + } override def supports(columnType: ColumnType[_, _]) = columnType == BOOLEAN class Encoder extends compression.Encoder[BooleanType.type] { private var _uncompressedSize = 0 - override def gatherCompressibilityStats( - value: Boolean, - columnType: NativeColumnType[BooleanType.type]) { - + override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = { _uncompressedSize += BOOLEAN.defaultSize } - override def compress( - from: ByteBuffer, - to: ByteBuffer, - columnType: NativeColumnType[BooleanType.type]) = { - + override def compress(from: ByteBuffer, to: ByteBuffer) = { to.putInt(BooleanBitSet.typeId) // Total element count (1 byte per Boolean value) .putInt(from.remaining) @@ -349,7 +361,7 @@ private[sql] case object BooleanBitSet extends CompressionScheme { private var visited: Int = 0 - override def next(): Boolean = { + override def next(row: MutableRow, ordinal: Int): Unit = { val bit = visited % BITS_PER_LONG visited += 1 @@ -357,123 +369,167 @@ private[sql] case object BooleanBitSet extends CompressionScheme { currentWord = buffer.getLong() } - ((currentWord >> bit) & 1) != 0 + row.setBoolean(ordinal, ((currentWord >> bit) & 1) != 0) } override def hasNext: Boolean = visited < count } } -private[sql] sealed abstract class IntegralDelta[I <: IntegralType] extends CompressionScheme { +private[sql] case object IntDelta extends CompressionScheme { + override def typeId: Int = 4 + override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = { - new this.Decoder(buffer, columnType.asInstanceOf[NativeColumnType[I]]) - .asInstanceOf[compression.Decoder[T]] + new Decoder(buffer, INT).asInstanceOf[compression.Decoder[T]] } - override def encoder[T <: NativeType] = (new this.Encoder).asInstanceOf[compression.Encoder[T]] - - /** - * Computes `delta = x - y`, returns `(true, delta)` if `delta` can fit into a single byte, or - * `(false, 0: Byte)` otherwise. - */ - protected def byteSizedDelta(x: I#JvmType, y: I#JvmType): (Boolean, Byte) + override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = { + (new Encoder).asInstanceOf[compression.Encoder[T]] + } - /** - * Simply computes `x + delta` - */ - protected def addDelta(x: I#JvmType, delta: Byte): I#JvmType + override def supports(columnType: ColumnType[_, _]) = columnType == INT - class Encoder extends compression.Encoder[I] { - private var _compressedSize: Int = 0 + class Encoder extends compression.Encoder[IntegerType.type] { + protected var _compressedSize: Int = 0 + protected var _uncompressedSize: Int = 0 - private var _uncompressedSize: Int = 0 + override def compressedSize = _compressedSize + override def uncompressedSize = _uncompressedSize - private var prev: I#JvmType = _ + private var prevValue: Int = _ - private var initial = true + override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = { + val value = row.getInt(ordinal) + val delta = value - prevValue - override def gatherCompressibilityStats(value: I#JvmType, columnType: NativeColumnType[I]) { - _uncompressedSize += columnType.defaultSize + _compressedSize += 1 - if (initial) { - initial = false - _compressedSize += 1 + columnType.defaultSize - } else { - val (smallEnough, _) = byteSizedDelta(value, prev) - _compressedSize += (if (smallEnough) 1 else 1 + columnType.defaultSize) + // If this is the first integer to be compressed, or the delta is out of byte range, then give + // up compressing this integer. + if (_uncompressedSize == 0 || delta <= Byte.MinValue || delta > Byte.MaxValue) { + _compressedSize += INT.defaultSize } - prev = value + _uncompressedSize += INT.defaultSize + prevValue = value } - override def compress(from: ByteBuffer, to: ByteBuffer, columnType: NativeColumnType[I]) = { + override def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer = { to.putInt(typeId) if (from.hasRemaining) { - var prev = columnType.extract(from) + var prev = from.getInt() to.put(Byte.MinValue) - columnType.append(prev, to) + to.putInt(prev) while (from.hasRemaining) { - val current = columnType.extract(from) - val (smallEnough, delta) = byteSizedDelta(current, prev) + val current = from.getInt() + val delta = current - prev prev = current - if (smallEnough) { - to.put(delta) + if (Byte.MinValue < delta && delta <= Byte.MaxValue) { + to.put(delta.toByte) } else { to.put(Byte.MinValue) - columnType.append(current, to) + to.putInt(current) } } } - to.rewind() - to + to.rewind().asInstanceOf[ByteBuffer] } - - override def uncompressedSize = _uncompressedSize - - override def compressedSize = _compressedSize } - class Decoder(buffer: ByteBuffer, columnType: NativeColumnType[I]) - extends compression.Decoder[I] { + class Decoder(buffer: ByteBuffer, columnType: NativeColumnType[IntegerType.type]) + extends compression.Decoder[IntegerType.type] { + + private var prev: Int = _ - private var prev: I#JvmType = _ + override def hasNext: Boolean = buffer.hasRemaining - override def next() = { + override def next(row: MutableRow, ordinal: Int): Unit = { val delta = buffer.get() - prev = if (delta > Byte.MinValue) addDelta(prev, delta) else columnType.extract(buffer) - prev + prev = if (delta > Byte.MinValue) prev + delta else buffer.getInt() + row.setInt(ordinal, prev) } - - override def hasNext = buffer.hasRemaining } } -private[sql] case object IntDelta extends IntegralDelta[IntegerType.type] { - override val typeId = 4 +private[sql] case object LongDelta extends CompressionScheme { + override def typeId: Int = 5 - override def supports(columnType: ColumnType[_, _]) = columnType == INT + override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = { + new Decoder(buffer, LONG).asInstanceOf[compression.Decoder[T]] + } + + override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = { + (new Encoder).asInstanceOf[compression.Encoder[T]] + } - override protected def addDelta(x: Int, delta: Byte) = x + delta + override def supports(columnType: ColumnType[_, _]) = columnType == LONG + + class Encoder extends compression.Encoder[LongType.type] { + protected var _compressedSize: Int = 0 + protected var _uncompressedSize: Int = 0 + + override def compressedSize = _compressedSize + override def uncompressedSize = _uncompressedSize + + private var prevValue: Long = _ + + override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = { + val value = row.getLong(ordinal) + val delta = value - prevValue + + _compressedSize += 1 - override protected def byteSizedDelta(x: Int, y: Int): (Boolean, Byte) = { - val delta = x - y - if (math.abs(delta) <= Byte.MaxValue) (true, delta.toByte) else (false, 0: Byte) + // If this is the first long integer to be compressed, or the delta is out of byte range, then + // give up compressing this long integer. + if (_uncompressedSize == 0 || delta <= Byte.MinValue || delta > Byte.MaxValue) { + _compressedSize += LONG.defaultSize + } + + _uncompressedSize += LONG.defaultSize + prevValue = value + } + + override def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer = { + to.putInt(typeId) + + if (from.hasRemaining) { + var prev = from.getLong() + to.put(Byte.MinValue) + to.putLong(prev) + + while (from.hasRemaining) { + val current = from.getLong() + val delta = current - prev + prev = current + + if (Byte.MinValue < delta && delta <= Byte.MaxValue) { + to.put(delta.toByte) + } else { + to.put(Byte.MinValue) + to.putLong(current) + } + } + } + + to.rewind().asInstanceOf[ByteBuffer] + } } -} -private[sql] case object LongDelta extends IntegralDelta[LongType.type] { - override val typeId = 5 + class Decoder(buffer: ByteBuffer, columnType: NativeColumnType[LongType.type]) + extends compression.Decoder[LongType.type] { - override def supports(columnType: ColumnType[_, _]) = columnType == LONG + private var prev: Long = _ - override protected def addDelta(x: Long, delta: Byte) = x + delta + override def hasNext: Boolean = buffer.hasRemaining - override protected def byteSizedDelta(x: Long, y: Long): (Boolean, Byte) = { - val delta = x - y - if (math.abs(delta) <= Byte.MaxValue) (true, delta.toByte) else (false, 0: Byte) + override def next(row: MutableRow, ordinal: Int): Unit = { + val delta = buffer.get() + prev = if (delta > Byte.MinValue) prev + delta else buffer.getLong() + row.setLong(ordinal, prev) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index cde91ceb68c98..0cdbb3167ce36 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -35,7 +35,7 @@ class ColumnStatsSuite extends FunSuite { def testColumnStats[T <: NativeType, U <: ColumnStats]( columnStatsClass: Class[U], columnType: NativeColumnType[T], - initialStatistics: Row) { + initialStatistics: Row): Unit = { val columnStatsName = columnStatsClass.getSimpleName diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index 75f653f3280bd..4fb1ecf1d532b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -23,6 +23,7 @@ import java.sql.Timestamp import org.scalatest.FunSuite import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.execution.SparkSqlSerializer @@ -46,10 +47,12 @@ class ColumnTypeSuite extends FunSuite with Logging { def checkActualSize[T <: DataType, JvmType]( columnType: ColumnType[T, JvmType], value: JvmType, - expected: Int) { + expected: Int): Unit = { assertResult(expected, s"Wrong actualSize for $columnType") { - columnType.actualSize(value) + val row = new GenericMutableRow(1) + columnType.setField(row, 0, value) + columnType.actualSize(row, 0) } } @@ -147,7 +150,7 @@ class ColumnTypeSuite extends FunSuite with Logging { def testNativeColumnType[T <: NativeType]( columnType: NativeColumnType[T], putter: (ByteBuffer, T#JvmType) => Unit, - getter: (ByteBuffer) => T#JvmType) { + getter: (ByteBuffer) => T#JvmType): Unit = { testColumnType[T, T#JvmType](columnType, putter, getter) } @@ -155,7 +158,7 @@ class ColumnTypeSuite extends FunSuite with Logging { def testColumnType[T <: DataType, JvmType]( columnType: ColumnType[T, JvmType], putter: (ByteBuffer, JvmType) => Unit, - getter: (ByteBuffer) => JvmType) { + getter: (ByteBuffer) => JvmType): Unit = { val buffer = ByteBuffer.allocate(DEFAULT_BUFFER_SIZE) val seq = (0 until 4).map(_ => makeRandomValue(columnType)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index 0e3c67f5eed29..c1278248ef655 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.columnar import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.{SQLConf, QueryTest, TestData} +import org.apache.spark.sql.{QueryTest, TestData} class InMemoryColumnarQuerySuite extends QueryTest { import org.apache.spark.sql.TestData._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala index 3baa6f8ec0c83..6c9a9ab6c3418 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala @@ -45,7 +45,9 @@ class NullableColumnAccessorSuite extends FunSuite { testNullableColumnAccessor(_) } - def testNullableColumnAccessor[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]) { + def testNullableColumnAccessor[T <: DataType, JvmType]( + columnType: ColumnType[T, JvmType]): Unit = { + val typeName = columnType.getClass.getSimpleName.stripSuffix("$") val nullRow = makeNullRow(1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala index a77262534a352..f54a21eb4fbb1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala @@ -41,7 +41,9 @@ class NullableColumnBuilderSuite extends FunSuite { testNullableColumnBuilder(_) } - def testNullableColumnBuilder[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]) { + def testNullableColumnBuilder[T <: DataType, JvmType]( + columnType: ColumnType[T, JvmType]): Unit = { + val typeName = columnType.getClass.getSimpleName.stripSuffix("$") test(s"$typeName column builder: empty column") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index 5d2fd4959197c..69e0adbd3ee0d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -28,7 +28,7 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be val originalColumnBatchSize = columnBatchSize val originalInMemoryPartitionPruning = inMemoryPartitionPruning - override protected def beforeAll() { + override protected def beforeAll(): Unit = { // Make a table with 5 partitions, 2 batches per partition, 10 elements per batch setConf(SQLConf.COLUMN_BATCH_SIZE, "10") val rawData = sparkContext.makeRDD(1 to 100, 5).map(IntegerData) @@ -38,7 +38,7 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true") } - override protected def afterAll() { + override protected def afterAll(): Unit = { setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize.toString) setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning.toString) } @@ -76,7 +76,7 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be filter: String, expectedQueryResult: Seq[Int], expectedReadPartitions: Int, - expectedReadBatches: Int) { + expectedReadBatches: Int): Unit = { test(filter) { val query = sql(s"SELECT * FROM intData WHERE $filter") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala index e01cc8b4d20f2..d9e488e0ffd16 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.columnar.compression import org.scalatest.FunSuite import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.columnar.{NoopColumnStats, BOOLEAN} import org.apache.spark.sql.columnar.ColumnarTestUtils._ @@ -72,10 +73,14 @@ class BooleanBitSetSuite extends FunSuite { buffer.rewind().position(headerSize + 4) val decoder = BooleanBitSet.decoder(buffer, BOOLEAN) + val mutableRow = new GenericMutableRow(1) if (values.nonEmpty) { values.foreach { assert(decoder.hasNext) - assertResult(_, "Wrong decoded value")(decoder.next()) + assertResult(_, "Wrong decoded value") { + decoder.next(mutableRow, 0) + mutableRow.getBoolean(0) + } } } assert(!decoder.hasNext) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala index d2969d906c943..1cdb909146d57 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala @@ -21,6 +21,7 @@ import java.nio.ByteBuffer import org.scalatest.FunSuite +import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.catalyst.types.NativeType import org.apache.spark.sql.columnar._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ @@ -67,7 +68,7 @@ class DictionaryEncodingSuite extends FunSuite { val buffer = builder.build() val headerSize = CompressionScheme.columnHeaderSize(buffer) // 4 extra bytes for dictionary size - val dictionarySize = 4 + values.map(columnType.actualSize).sum + val dictionarySize = 4 + rows.map(columnType.actualSize(_, 0)).sum // 2 bytes for each `Short` val compressedSize = 4 + dictionarySize + 2 * inputSeq.length // 4 extra bytes for compression scheme type ID @@ -97,11 +98,15 @@ class DictionaryEncodingSuite extends FunSuite { buffer.rewind().position(headerSize + 4) val decoder = DictionaryEncoding.decoder(buffer, columnType) + val mutableRow = new GenericMutableRow(1) if (inputSeq.nonEmpty) { inputSeq.foreach { i => assert(decoder.hasNext) - assertResult(values(i), "Wrong decoded value")(decoder.next()) + assertResult(values(i), "Wrong decoded value") { + decoder.next(mutableRow, 0) + columnType.getField(mutableRow, 0) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala index 322f447c24840..73f31c0233343 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala @@ -31,7 +31,7 @@ class IntegralDeltaSuite extends FunSuite { def testIntegralDelta[I <: IntegralType]( columnStats: ColumnStats, columnType: NativeColumnType[I], - scheme: IntegralDelta[I]) { + scheme: CompressionScheme) { def skeleton(input: Seq[I#JvmType]) { // ------------- @@ -96,10 +96,15 @@ class IntegralDeltaSuite extends FunSuite { buffer.rewind().position(headerSize + 4) val decoder = scheme.decoder(buffer, columnType) + val mutableRow = new GenericMutableRow(1) + if (input.nonEmpty) { input.foreach{ assert(decoder.hasNext) - assertResult(_, "Wrong decoded value")(decoder.next()) + assertResult(_, "Wrong decoded value") { + decoder.next(mutableRow, 0) + columnType.getField(mutableRow, 0) + } } } assert(!decoder.hasNext) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala index 218c09ac26362..4ce2552112c92 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.columnar.compression import org.scalatest.FunSuite +import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.catalyst.types.NativeType import org.apache.spark.sql.columnar._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ @@ -57,7 +58,7 @@ class RunLengthEncodingSuite extends FunSuite { // Compression scheme ID + compressed contents val compressedSize = 4 + inputRuns.map { case (index, _) => // 4 extra bytes each run for run length - columnType.actualSize(values(index)) + 4 + columnType.actualSize(rows(index), 0) + 4 }.sum // 4 extra bytes for compression scheme type ID @@ -80,11 +81,15 @@ class RunLengthEncodingSuite extends FunSuite { buffer.rewind().position(headerSize + 4) val decoder = RunLengthEncoding.decoder(buffer, columnType) + val mutableRow = new GenericMutableRow(1) if (inputSeq.nonEmpty) { inputSeq.foreach { i => assert(decoder.hasNext) - assertResult(values(i), "Wrong decoded value")(decoder.next()) + assertResult(values(i), "Wrong decoded value") { + decoder.next(mutableRow, 0) + columnType.getField(mutableRow, 0) + } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 329f80cad471e..84fafcde63d05 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -25,16 +25,14 @@ import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => import org.apache.hadoop.hive.ql.plan.{PlanUtils, TableDesc} import org.apache.hadoop.hive.serde2.Deserializer import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector - +import org.apache.hadoop.hive.serde2.objectinspector.primitive._ import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf} import org.apache.spark.SerializableWritable import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD} - -import org.apache.spark.sql.catalyst.expressions.{Attribute, Row, GenericMutableRow, Literal, Cast} -import org.apache.spark.sql.catalyst.types.DataType +import org.apache.spark.sql.catalyst.expressions._ /** * A trait for subclasses that handle table scans. @@ -108,12 +106,12 @@ class HadoopTableReader( val hadoopRDD = createHadoopRdd(tableDesc, inputPathStr, ifc) val attrsWithIndex = attributes.zipWithIndex - val mutableRow = new GenericMutableRow(attrsWithIndex.length) + val mutableRow = new SpecificMutableRow(attributes.map(_.dataType)) + val deserializedHadoopRDD = hadoopRDD.mapPartitions { iter => val hconf = broadcastedHiveConf.value.value val deserializer = deserializerClass.newInstance() deserializer.initialize(hconf, tableDesc.getProperties) - HadoopTableReader.fillObject(iter, deserializer, attrsWithIndex, mutableRow) } @@ -164,33 +162,32 @@ class HadoopTableReader( val tableDesc = relation.tableDesc val broadcastedHiveConf = _broadcastedHiveConf val localDeserializer = partDeserializer - val mutableRow = new GenericMutableRow(attributes.length) - - // split the attributes (output schema) into 2 categories: - // (partition keys, ordinal), (normal attributes, ordinal), the ordinal mean the - // index of the attribute in the output Row. - val (partitionKeys, attrs) = attributes.zipWithIndex.partition(attr => { - relation.partitionKeys.indexOf(attr._1) >= 0 - }) - - def fillPartitionKeys(parts: Array[String], row: GenericMutableRow) = { - partitionKeys.foreach { case (attr, ordinal) => - // get partition key ordinal for a given attribute - val partOridinal = relation.partitionKeys.indexOf(attr) - row(ordinal) = Cast(Literal(parts(partOridinal)), attr.dataType).eval(null) + val mutableRow = new SpecificMutableRow(attributes.map(_.dataType)) + + // Splits all attributes into two groups, partition key attributes and those that are not. + // Attached indices indicate the position of each attribute in the output schema. + val (partitionKeyAttrs, nonPartitionKeyAttrs) = + attributes.zipWithIndex.partition { case (attr, _) => + relation.partitionKeys.contains(attr) + } + + def fillPartitionKeys(rawPartValues: Array[String], row: MutableRow) = { + partitionKeyAttrs.foreach { case (attr, ordinal) => + val partOrdinal = relation.partitionKeys.indexOf(attr) + row(ordinal) = Cast(Literal(rawPartValues(partOrdinal)), attr.dataType).eval(null) } } - // fill the partition key for the given MutableRow Object + + // Fill all partition keys to the given MutableRow object fillPartitionKeys(partValues, mutableRow) - val hivePartitionRDD = createHadoopRdd(tableDesc, inputPathStr, ifc) - hivePartitionRDD.mapPartitions { iter => + createHadoopRdd(tableDesc, inputPathStr, ifc).mapPartitions { iter => val hconf = broadcastedHiveConf.value.value val deserializer = localDeserializer.newInstance() deserializer.initialize(hconf, partProps) - // fill the non partition key attributes - HadoopTableReader.fillObject(iter, deserializer, attrs, mutableRow) + // fill the non partition key attributes + HadoopTableReader.fillObject(iter, deserializer, nonPartitionKeyAttrs, mutableRow) } }.toSeq @@ -257,38 +254,64 @@ private[hive] object HadoopTableReader extends HiveInspectors { } /** - * Transform the raw data(Writable object) into the Row object for an iterable input - * @param iter Iterable input which represented as Writable object - * @param deserializer Deserializer associated with the input writable object - * @param attrs Represents the row attribute names and its zero-based position in the MutableRow - * @param row reusable MutableRow object - * - * @return Iterable Row object that transformed from the given iterable input. + * Transform all given raw `Writable`s into `Row`s. + * + * @param iterator Iterator of all `Writable`s to be transformed + * @param deserializer The `Deserializer` associated with the input `Writable` + * @param nonPartitionKeyAttrs Attributes that should be filled together with their corresponding + * positions in the output schema + * @param mutableRow A reusable `MutableRow` that should be filled + * @return An `Iterator[Row]` transformed from `iterator` */ def fillObject( - iter: Iterator[Writable], + iterator: Iterator[Writable], deserializer: Deserializer, - attrs: Seq[(Attribute, Int)], - row: GenericMutableRow): Iterator[Row] = { + nonPartitionKeyAttrs: Seq[(Attribute, Int)], + mutableRow: MutableRow): Iterator[Row] = { + val soi = deserializer.getObjectInspector().asInstanceOf[StructObjectInspector] - // get the field references according to the attributes(output of the reader) required - val fieldRefs = attrs.map { case (attr, idx) => (soi.getStructFieldRef(attr.name), idx) } + val (fieldRefs, fieldOrdinals) = nonPartitionKeyAttrs.map { case (attr, ordinal) => + soi.getStructFieldRef(attr.name) -> ordinal + }.unzip + + // Builds specific unwrappers ahead of time according to object inspector types to avoid pattern + // matching and branching costs per row. + val unwrappers: Seq[(Any, MutableRow, Int) => Unit] = fieldRefs.map { + _.getFieldObjectInspector match { + case oi: BooleanObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setBoolean(ordinal, oi.get(value)) + case oi: ByteObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setByte(ordinal, oi.get(value)) + case oi: ShortObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setShort(ordinal, oi.get(value)) + case oi: IntObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setInt(ordinal, oi.get(value)) + case oi: LongObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setLong(ordinal, oi.get(value)) + case oi: FloatObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setFloat(ordinal, oi.get(value)) + case oi: DoubleObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setDouble(ordinal, oi.get(value)) + case oi => + (value: Any, row: MutableRow, ordinal: Int) => row(ordinal) = unwrapData(value, oi) + } + } // Map each tuple to a row object - iter.map { value => + iterator.map { value => val raw = deserializer.deserialize(value) - var idx = 0; - while (idx < fieldRefs.length) { - val fieldRef = fieldRefs(idx)._1 - val fieldIdx = fieldRefs(idx)._2 - val fieldValue = soi.getStructFieldData(raw, fieldRef) - - row(fieldIdx) = unwrapData(fieldValue, fieldRef.getFieldObjectInspector()) - - idx += 1 + var i = 0 + while (i < fieldRefs.length) { + val fieldValue = soi.getStructFieldData(raw, fieldRefs(i)) + if (fieldValue == null) { + mutableRow.setNullAt(fieldOrdinals(i)) + } else { + unwrappers(i)(fieldValue, mutableRow, fieldOrdinals(i)) + } + i += 1 } - row: Row + mutableRow: Row } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 6bf8d18a5c32c..8c8a8b124ac69 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -295,8 +295,16 @@ class HiveQuerySuite extends HiveComparisonTest { "SELECT (CASE WHEN key > 2 THEN 3 WHEN 2 > key THEN 2 ELSE 0 END) FROM src WHERE key < 15") test("implement identity function using case statement") { - val actual = sql("SELECT (CASE key WHEN key THEN key END) FROM src").collect().toSet - val expected = sql("SELECT key FROM src").collect().toSet + val actual = sql("SELECT (CASE key WHEN key THEN key END) FROM src") + .map { case Row(i: Int) => i } + .collect() + .toSet + + val expected = sql("SELECT key FROM src") + .map { case Row(i: Int) => i } + .collect() + .toSet + assert(actual === expected) } @@ -559,9 +567,9 @@ class HiveQuerySuite extends HiveComparisonTest { val testVal = "test.val.0" val nonexistentKey = "nonexistent" val KV = "([^=]+)=([^=]*)".r - def collectResults(rdd: SchemaRDD): Set[(String, String)] = - rdd.collect().map { - case Row(key: String, value: String) => key -> value + def collectResults(rdd: SchemaRDD): Set[(String, String)] = + rdd.collect().map { + case Row(key: String, value: String) => key -> value case Row(KV(key, value)) => key -> value }.toSet clear() From 0f8c4edf4e750e3d11da27cc22c40b0489da7f37 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sat, 13 Sep 2014 16:08:04 -0700 Subject: [PATCH 20/36] [SQL] Decrease partitions when testing Author: Michael Armbrust Closes #2164 from marmbrus/shufflePartitions and squashes the following commits: 0da1e8c [Michael Armbrust] test hax ef2d985 [Michael Armbrust] more test hacks. 2dabae3 [Michael Armbrust] more test fixes 0bdbf21 [Michael Armbrust] Make parquet tests less order dependent b42eeab [Michael Armbrust] increase test parallelism 80453d5 [Michael Armbrust] Decrease partitions when testing --- .../spark/sql/test/TestSQLContext.scala | 9 +- .../spark/sql/parquet/ParquetQuerySuite.scala | 142 +++++------------- .../org/apache/spark/sql/hive/TestHive.scala | 7 +- 3 files changed, 51 insertions(+), 107 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala index f2389f8f0591e..265b67737c475 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -18,8 +18,13 @@ package org.apache.spark.sql.test import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.{SQLConf, SQLContext} /** A SQLContext that can be used for local testing. */ object TestSQLContext - extends SQLContext(new SparkContext("local", "TestSQLContext", new SparkConf())) + extends SQLContext(new SparkContext("local[2]", "TestSQLContext", new SparkConf())) { + + /** Fewer partitions to speed up testing. */ + override private[spark] def numShufflePartitions: Int = + getConf(SQLConf.SHUFFLE_PARTITIONS, "5").toInt +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index b0a06cd3ca090..08f7358446b29 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -58,8 +58,7 @@ case class AllDataTypes( doubleField: Double, shortField: Short, byteField: Byte, - booleanField: Boolean, - binaryField: Array[Byte]) + booleanField: Boolean) case class AllDataTypesWithNonPrimitiveType( stringField: String, @@ -70,13 +69,14 @@ case class AllDataTypesWithNonPrimitiveType( shortField: Short, byteField: Byte, booleanField: Boolean, - binaryField: Array[Byte], array: Seq[Int], arrayContainsNull: Seq[Option[Int]], map: Map[Int, Long], mapValueContainsNull: Map[Int, Option[Long]], data: Data) +case class BinaryData(binaryData: Array[Byte]) + class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll { TestData // Load test data tables. @@ -108,26 +108,26 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA test("Read/Write All Types") { val tempDir = getTempFilePath("parquetTest").getCanonicalPath val range = (0 to 255) - TestSQLContext.sparkContext.parallelize(range) - .map(x => AllDataTypes(s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0, - (0 to x).map(_.toByte).toArray)) - .saveAsParquetFile(tempDir) - val result = parquetFile(tempDir).collect() - range.foreach { - i => - assert(result(i).getString(0) == s"$i", s"row $i String field did not match, got ${result(i).getString(0)}") - assert(result(i).getInt(1) === i) - assert(result(i).getLong(2) === i.toLong) - assert(result(i).getFloat(3) === i.toFloat) - assert(result(i).getDouble(4) === i.toDouble) - assert(result(i).getShort(5) === i.toShort) - assert(result(i).getByte(6) === i.toByte) - assert(result(i).getBoolean(7) === (i % 2 == 0)) - assert(result(i)(8) === (0 to i).map(_.toByte).toArray) - } + val data = sparkContext.parallelize(range) + .map(x => AllDataTypes(s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0)) + + data.saveAsParquetFile(tempDir) + + checkAnswer( + parquetFile(tempDir), + data.toSchemaRDD.collect().toSeq) } - test("Treat binary as string") { + test("read/write binary data") { + // Since equality for Array[Byte] is broken we test this separately. + val tempDir = getTempFilePath("parquetTest").getCanonicalPath + sparkContext.parallelize(BinaryData("test".getBytes("utf8")) :: Nil).saveAsParquetFile(tempDir) + parquetFile(tempDir) + .map(r => new String(r(0).asInstanceOf[Array[Byte]], "utf8")) + .collect().toSeq == Seq("test") + } + + ignore("Treat binary as string") { val oldIsParquetBinaryAsString = TestSQLContext.isParquetBinaryAsString // Create the test file. @@ -142,37 +142,16 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA StructField("c2", BinaryType, false) :: Nil) val schemaRDD1 = applySchema(rowRDD, schema) schemaRDD1.saveAsParquetFile(path) - val resultWithBinary = parquetFile(path).collect - range.foreach { - i => - assert(resultWithBinary(i).getInt(0) === i) - assert(resultWithBinary(i)(1) === s"val_$i".getBytes) - } - - TestSQLContext.setConf(SQLConf.PARQUET_BINARY_AS_STRING, "true") - // This ParquetRelation always use Parquet types to derive output. - val parquetRelation = new ParquetRelation( - path.toString, - Some(TestSQLContext.sparkContext.hadoopConfiguration), - TestSQLContext) { - override val output = - ParquetTypesConverter.convertToAttributes( - ParquetTypesConverter.readMetaData(new Path(path), conf).getFileMetaData.getSchema, - TestSQLContext.isParquetBinaryAsString) - } - val schemaRDD = new SchemaRDD(TestSQLContext, parquetRelation) - val resultWithString = schemaRDD.collect - range.foreach { - i => - assert(resultWithString(i).getInt(0) === i) - assert(resultWithString(i)(1) === s"val_$i") - } + checkAnswer( + parquetFile(path).select('c1, 'c2.cast(StringType)), + schemaRDD1.select('c1, 'c2.cast(StringType)).collect().toSeq) - schemaRDD.registerTempTable("tmp") + setConf(SQLConf.PARQUET_BINARY_AS_STRING, "true") + parquetFile(path).printSchema() checkAnswer( - sql("SELECT c1, c2 FROM tmp WHERE c2 = 'val_5' OR c2 = 'val_7'"), - (5, "val_5") :: - (7, "val_7") :: Nil) + parquetFile(path), + schemaRDD1.select('c1, 'c2.cast(StringType)).collect().toSeq) + // Set it back. TestSQLContext.setConf(SQLConf.PARQUET_BINARY_AS_STRING, oldIsParquetBinaryAsString.toString) @@ -275,34 +254,19 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA test("Read/Write All Types with non-primitive type") { val tempDir = getTempFilePath("parquetTest").getCanonicalPath val range = (0 to 255) - TestSQLContext.sparkContext.parallelize(range) + val data = sparkContext.parallelize(range) .map(x => AllDataTypesWithNonPrimitiveType( s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0, - (0 to x).map(_.toByte).toArray, (0 until x), (0 until x).map(Option(_).filter(_ % 3 == 0)), (0 until x).map(i => i -> i.toLong).toMap, (0 until x).map(i => i -> Option(i.toLong)).toMap + (x -> None), Data((0 until x), Nested(x, s"$x")))) - .saveAsParquetFile(tempDir) - val result = parquetFile(tempDir).collect() - range.foreach { - i => - assert(result(i).getString(0) == s"$i", s"row $i String field did not match, got ${result(i).getString(0)}") - assert(result(i).getInt(1) === i) - assert(result(i).getLong(2) === i.toLong) - assert(result(i).getFloat(3) === i.toFloat) - assert(result(i).getDouble(4) === i.toDouble) - assert(result(i).getShort(5) === i.toShort) - assert(result(i).getByte(6) === i.toByte) - assert(result(i).getBoolean(7) === (i % 2 == 0)) - assert(result(i)(8) === (0 to i).map(_.toByte).toArray) - assert(result(i)(9) === (0 until i)) - assert(result(i)(10) === (0 until i).map(i => if (i % 3 == 0) i else null)) - assert(result(i)(11) === (0 until i).map(i => i -> i.toLong).toMap) - assert(result(i)(12) === (0 until i).map(i => i -> i.toLong).toMap + (i -> null)) - assert(result(i)(13) === new GenericRow(Array[Any]((0 until i), new GenericRow(Array[Any](i, s"$i"))))) - } + data.saveAsParquetFile(tempDir) + + checkAnswer( + parquetFile(tempDir), + data.toSchemaRDD.collect().toSeq) } test("self-join parquet files") { @@ -399,23 +363,6 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA } } - test("Saving case class RDD table to file and reading it back in") { - val file = getTempFilePath("parquet") - val path = file.toString - val rdd = TestSQLContext.sparkContext.parallelize((1 to 100)) - .map(i => TestRDDEntry(i, s"val_$i")) - rdd.saveAsParquetFile(path) - val readFile = parquetFile(path) - readFile.registerTempTable("tmpx") - val rdd_copy = sql("SELECT * FROM tmpx").collect() - val rdd_orig = rdd.collect() - for(i <- 0 to 99) { - assert(rdd_copy(i).apply(0) === rdd_orig(i).key, s"key error in line $i") - assert(rdd_copy(i).apply(1) === rdd_orig(i).value, s"value error in line $i") - } - Utils.deleteRecursively(file) - } - test("Read a parquet file instead of a directory") { val file = getTempFilePath("parquet") val path = file.toString @@ -448,32 +395,19 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA sql("INSERT OVERWRITE INTO dest SELECT * FROM source").collect() val rdd_copy1 = sql("SELECT * FROM dest").collect() assert(rdd_copy1.size === 100) - assert(rdd_copy1(0).apply(0) === 1) - assert(rdd_copy1(0).apply(1) === "val_1") - // TODO: why does collecting break things? It seems InsertIntoParquet::execute() is - // executed twice otherwise?! + sql("INSERT INTO dest SELECT * FROM source") - val rdd_copy2 = sql("SELECT * FROM dest").collect() + val rdd_copy2 = sql("SELECT * FROM dest").collect().sortBy(_.getInt(0)) assert(rdd_copy2.size === 200) - assert(rdd_copy2(0).apply(0) === 1) - assert(rdd_copy2(0).apply(1) === "val_1") - assert(rdd_copy2(99).apply(0) === 100) - assert(rdd_copy2(99).apply(1) === "val_100") - assert(rdd_copy2(100).apply(0) === 1) - assert(rdd_copy2(100).apply(1) === "val_1") Utils.deleteRecursively(dirname) } test("Insert (appending) to same table via Scala API") { - // TODO: why does collecting break things? It seems InsertIntoParquet::execute() is - // executed twice otherwise?! sql("INSERT INTO testsource SELECT * FROM testsource") val double_rdd = sql("SELECT * FROM testsource").collect() assert(double_rdd != null) assert(double_rdd.size === 30) - for(i <- (0 to 14)) { - assert(double_rdd(i) === double_rdd(i+15), s"error: lines $i and ${i+15} to not match") - } + // let's restore the original test data Utils.deleteRecursively(ParquetTestData.testDir) ParquetTestData.writeFile() diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala index a3bfd3a8f1fd2..70fb15259e7d7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -35,12 +35,13 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.plans.logical.{CacheCommand, LogicalPlan, NativeCommand} import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.hive._ +import org.apache.spark.sql.SQLConf /* Implicit conversions */ import scala.collection.JavaConversions._ object TestHive - extends TestHiveContext(new SparkContext("local", "TestSQLContext", new SparkConf())) + extends TestHiveContext(new SparkContext("local[2]", "TestSQLContext", new SparkConf())) /** * A locally running test instance of Spark's Hive execution engine. @@ -90,6 +91,10 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { override def executePlan(plan: LogicalPlan): this.QueryExecution = new this.QueryExecution { val logical = plan } + /** Fewer partitions to speed up testing. */ + override private[spark] def numShufflePartitions: Int = + getConf(SQLConf.SHUFFLE_PARTITIONS, "5").toInt + /** * Returns the value of specified environmental variable as a [[java.io.File]] after checking * to ensure it exists From 2aea0da84c58a179917311290083456dfa043db7 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sat, 13 Sep 2014 16:22:04 -0700 Subject: [PATCH 21/36] [SPARK-3030] [PySpark] Reuse Python worker Reuse Python worker to avoid the overhead of fork() Python process for each tasks. It also tracks the broadcasts for each worker, avoid sending repeated broadcasts. This can reduce the time for dummy task from 22ms to 13ms (-40%). It can help to reduce the latency for Spark Streaming. For a job with broadcast (43M after compress): ``` b = sc.broadcast(set(range(30000000))) print sc.parallelize(range(24000), 100).filter(lambda x: x in b.value).count() ``` It will finish in 281s without reused worker, and it will finish in 65s with reused worker(4 CPUs). After reusing the worker, it can save about 9 seconds for transfer and deserialize the broadcast for each tasks. It's enabled by default, could be disabled by `spark.python.worker.reuse = false`. Author: Davies Liu Closes #2259 from davies/reuse-worker and squashes the following commits: f11f617 [Davies Liu] Merge branch 'master' into reuse-worker 3939f20 [Davies Liu] fix bug in serializer in mllib cf1c55e [Davies Liu] address comments 3133a60 [Davies Liu] fix accumulator with reused worker 760ab1f [Davies Liu] do not reuse worker if there are any exceptions 7abb224 [Davies Liu] refactor: sychronized with itself ac3206e [Davies Liu] renaming 8911f44 [Davies Liu] synchronized getWorkerBroadcasts() 6325fc1 [Davies Liu] bugfix: bid >= 0 e0131a2 [Davies Liu] fix name of config 583716e [Davies Liu] only reuse completed and not interrupted worker ace2917 [Davies Liu] kill python worker after timeout 6123d0f [Davies Liu] track broadcasts for each worker 8d2f08c [Davies Liu] reuse python worker --- .../scala/org/apache/spark/SparkEnv.scala | 8 ++ .../apache/spark/api/python/PythonRDD.scala | 58 ++++++++++--- .../api/python/PythonWorkerFactory.scala | 85 ++++++++++++++++--- docs/configuration.md | 10 +++ python/pyspark/daemon.py | 38 ++++----- python/pyspark/mllib/_common.py | 12 ++- python/pyspark/serializers.py | 4 + python/pyspark/tests.py | 35 ++++++++ python/pyspark/worker.py | 9 +- 9 files changed, 208 insertions(+), 51 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index dd95e406f2a8e..009ed64775844 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -108,6 +108,14 @@ class SparkEnv ( pythonWorkers.get(key).foreach(_.stopWorker(worker)) } } + + private[spark] + def releasePythonWorker(pythonExec: String, envVars: Map[String, String], worker: Socket) { + synchronized { + val key = (pythonExec, envVars) + pythonWorkers.get(key).foreach(_.releaseWorker(worker)) + } + } } object SparkEnv extends Logging { diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index ae8010300a500..ca8eef5f99edf 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -23,6 +23,7 @@ import java.nio.charset.Charset import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections} import scala.collection.JavaConversions._ +import scala.collection.mutable import scala.language.existentials import scala.reflect.ClassTag import scala.util.{Try, Success, Failure} @@ -52,6 +53,7 @@ private[spark] class PythonRDD( extends RDD[Array[Byte]](parent) { val bufferSize = conf.getInt("spark.buffer.size", 65536) + val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true) override def getPartitions = parent.partitions @@ -63,19 +65,26 @@ private[spark] class PythonRDD( val localdir = env.blockManager.diskBlockManager.localDirs.map( f => f.getPath()).mkString(",") envVars += ("SPARK_LOCAL_DIRS" -> localdir) // it's also used in monitor thread + if (reuse_worker) { + envVars += ("SPARK_REUSE_WORKER" -> "1") + } val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap) // Start a thread to feed the process input from our parent's iterator val writerThread = new WriterThread(env, worker, split, context) + var complete_cleanly = false context.addTaskCompletionListener { context => writerThread.shutdownOnTaskCompletion() - - // Cleanup the worker socket. This will also cause the Python worker to exit. - try { - worker.close() - } catch { - case e: Exception => logWarning("Failed to close worker socket", e) + if (reuse_worker && complete_cleanly) { + env.releasePythonWorker(pythonExec, envVars.toMap, worker) + } else { + try { + worker.close() + } catch { + case e: Exception => + logWarning("Failed to close worker socket", e) + } } } @@ -133,6 +142,7 @@ private[spark] class PythonRDD( stream.readFully(update) accumulator += Collections.singletonList(update) } + complete_cleanly = true null } } catch { @@ -195,11 +205,26 @@ private[spark] class PythonRDD( PythonRDD.writeUTF(include, dataOut) } // Broadcast variables - dataOut.writeInt(broadcastVars.length) + val oldBids = PythonRDD.getWorkerBroadcasts(worker) + val newBids = broadcastVars.map(_.id).toSet + // number of different broadcasts + val cnt = oldBids.diff(newBids).size + newBids.diff(oldBids).size + dataOut.writeInt(cnt) + for (bid <- oldBids) { + if (!newBids.contains(bid)) { + // remove the broadcast from worker + dataOut.writeLong(- bid - 1) // bid >= 0 + oldBids.remove(bid) + } + } for (broadcast <- broadcastVars) { - dataOut.writeLong(broadcast.id) - dataOut.writeInt(broadcast.value.length) - dataOut.write(broadcast.value) + if (!oldBids.contains(broadcast.id)) { + // send new broadcast + dataOut.writeLong(broadcast.id) + dataOut.writeInt(broadcast.value.length) + dataOut.write(broadcast.value) + oldBids.add(broadcast.id) + } } dataOut.flush() // Serialized command: @@ -207,17 +232,18 @@ private[spark] class PythonRDD( dataOut.write(command) // Data values PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut) + dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) dataOut.flush() } catch { case e: Exception if context.isCompleted || context.isInterrupted => logDebug("Exception thrown after task completion (likely due to cleanup)", e) + worker.shutdownOutput() case e: Exception => // We must avoid throwing exceptions here, because the thread uncaught exception handler // will kill the whole executor (see org.apache.spark.executor.Executor). _exception = e - } finally { - Try(worker.shutdownOutput()) // kill Python worker process + worker.shutdownOutput() } } } @@ -278,6 +304,14 @@ private object SpecialLengths { private[spark] object PythonRDD extends Logging { val UTF8 = Charset.forName("UTF-8") + // remember the broadcasts sent to each worker + private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]() + private def getWorkerBroadcasts(worker: Socket) = { + synchronized { + workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]()) + } + } + /** * Adapter for calling SparkContext#runJob from Python. * diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 4c4796f6c59ba..71bdf0fe1b917 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -40,7 +40,10 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String var daemon: Process = null val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1)) var daemonPort: Int = 0 - var daemonWorkers = new mutable.WeakHashMap[Socket, Int]() + val daemonWorkers = new mutable.WeakHashMap[Socket, Int]() + val idleWorkers = new mutable.Queue[Socket]() + var lastActivity = 0L + new MonitorThread().start() var simpleWorkers = new mutable.WeakHashMap[Socket, Process]() @@ -51,6 +54,11 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String def create(): Socket = { if (useDaemon) { + synchronized { + if (idleWorkers.size > 0) { + return idleWorkers.dequeue() + } + } createThroughDaemon() } else { createSimpleWorker() @@ -199,9 +207,44 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String } } + /** + * Monitor all the idle workers, kill them after timeout. + */ + private class MonitorThread extends Thread(s"Idle Worker Monitor for $pythonExec") { + + setDaemon(true) + + override def run() { + while (true) { + synchronized { + if (lastActivity + IDLE_WORKER_TIMEOUT_MS < System.currentTimeMillis()) { + cleanupIdleWorkers() + lastActivity = System.currentTimeMillis() + } + } + Thread.sleep(10000) + } + } + } + + private def cleanupIdleWorkers() { + while (idleWorkers.length > 0) { + val worker = idleWorkers.dequeue() + try { + // the worker will exit after closing the socket + worker.close() + } catch { + case e: Exception => + logWarning("Failed to close worker socket", e) + } + } + } + private def stopDaemon() { synchronized { if (useDaemon) { + cleanupIdleWorkers() + // Request shutdown of existing daemon by sending SIGTERM if (daemon != null) { daemon.destroy() @@ -220,23 +263,43 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String } def stopWorker(worker: Socket) { - if (useDaemon) { - if (daemon != null) { - daemonWorkers.get(worker).foreach { pid => - // tell daemon to kill worker by pid - val output = new DataOutputStream(daemon.getOutputStream) - output.writeInt(pid) - output.flush() - daemon.getOutputStream.flush() + synchronized { + if (useDaemon) { + if (daemon != null) { + daemonWorkers.get(worker).foreach { pid => + // tell daemon to kill worker by pid + val output = new DataOutputStream(daemon.getOutputStream) + output.writeInt(pid) + output.flush() + daemon.getOutputStream.flush() + } } + } else { + simpleWorkers.get(worker).foreach(_.destroy()) } - } else { - simpleWorkers.get(worker).foreach(_.destroy()) } worker.close() } + + def releaseWorker(worker: Socket) { + if (useDaemon) { + synchronized { + lastActivity = System.currentTimeMillis() + idleWorkers.enqueue(worker) + } + } else { + // Cleanup the worker socket. This will also cause the Python worker to exit. + try { + worker.close() + } catch { + case e: Exception => + logWarning("Failed to close worker socket", e) + } + } + } } private object PythonWorkerFactory { val PROCESS_WAIT_TIMEOUT_MS = 10000 + val IDLE_WORKER_TIMEOUT_MS = 60000 // kill idle workers after 1 minute } diff --git a/docs/configuration.md b/docs/configuration.md index 36178efb97103..af16489a44281 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -206,6 +206,16 @@ Apart from these, the following properties are also available, and may be useful used during aggregation goes above this amount, it will spill the data into disks. + + spark.python.worker.reuse + true + + Reuse Python worker or not. If yes, it will use a fixed number of Python workers, + does not need to fork() a Python process for every tasks. It will be very useful + if there is large broadcast, then the broadcast will not be needed to transfered + from JVM to Python worker for every task. + + spark.executorEnv.[EnvironmentVariableName] (none) diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index 15445abf67147..64d6202acb27d 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -23,6 +23,7 @@ import sys import traceback import time +import gc from errno import EINTR, ECHILD, EAGAIN from socket import AF_INET, SOCK_STREAM, SOMAXCONN from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN @@ -46,17 +47,6 @@ def worker(sock): signal.signal(SIGCHLD, SIG_DFL) signal.signal(SIGTERM, SIG_DFL) - # Blocks until the socket is closed by draining the input stream - # until it raises an exception or returns EOF. - def waitSocketClose(sock): - try: - while True: - # Empty string is returned upon EOF (and only then). - if sock.recv(4096) == '': - return - except: - pass - # Read the socket using fdopen instead of socket.makefile() because the latter # seems to be very slow; note that we need to dup() the file descriptor because # otherwise writes also cause a seek that makes us miss data on the read side. @@ -64,17 +54,13 @@ def waitSocketClose(sock): outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536) exit_code = 0 try: - # Acknowledge that the fork was successful - write_int(os.getpid(), outfile) - outfile.flush() worker_main(infile, outfile) except SystemExit as exc: - exit_code = exc.code + exit_code = compute_real_exit_code(exc.code) finally: outfile.flush() - # The Scala side will close the socket upon task completion. - waitSocketClose(sock) - os._exit(compute_real_exit_code(exit_code)) + if exit_code: + os._exit(exit_code) # Cleanup zombie children @@ -111,6 +97,8 @@ def handle_sigterm(*args): signal.signal(SIGTERM, handle_sigterm) # Gracefully exit on SIGTERM signal.signal(SIGHUP, SIG_IGN) # Don't die on SIGHUP + reuse = os.environ.get("SPARK_REUSE_WORKER") + # Initialization complete try: while True: @@ -163,7 +151,19 @@ def handle_sigterm(*args): # in child process listen_sock.close() try: - worker(sock) + # Acknowledge that the fork was successful + outfile = sock.makefile("w") + write_int(os.getpid(), outfile) + outfile.flush() + outfile.close() + while True: + worker(sock) + if not reuse: + # wait for closing + while sock.recv(1024): + pass + break + gc.collect() except: traceback.print_exc() os._exit(1) diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py index bb60d3d0c8463..68f6033616726 100644 --- a/python/pyspark/mllib/_common.py +++ b/python/pyspark/mllib/_common.py @@ -21,7 +21,7 @@ from numpy import ndarray, float64, int64, int32, array_equal, array from pyspark import SparkContext, RDD from pyspark.mllib.linalg import SparseVector -from pyspark.serializers import Serializer +from pyspark.serializers import FramedSerializer """ @@ -451,18 +451,16 @@ def _serialize_rating(r): return ba -class RatingDeserializer(Serializer): +class RatingDeserializer(FramedSerializer): - def loads(self, stream): - length = struct.unpack("!i", stream.read(4))[0] - ba = stream.read(length) - res = ndarray(shape=(3, ), buffer=ba, dtype=float64, offset=4) + def loads(self, string): + res = ndarray(shape=(3, ), buffer=string, dtype=float64, offset=4) return int(res[0]), int(res[1]), res[2] def load_stream(self, stream): while True: try: - yield self.loads(stream) + yield self._read_with_length(stream) except struct.error: return except EOFError: diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index a5f9341e819a9..ec3c6f055441d 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -144,6 +144,8 @@ def _write_with_length(self, obj, stream): def _read_with_length(self, stream): length = read_int(stream) + if length == SpecialLengths.END_OF_DATA_SECTION: + raise EOFError obj = stream.read(length) if obj == "": raise EOFError @@ -438,6 +440,8 @@ def __init__(self, use_unicode=False): def loads(self, stream): length = read_int(stream) + if length == SpecialLengths.END_OF_DATA_SECTION: + raise EOFError s = stream.read(length) return s.decode("utf-8") if self.use_unicode else s diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index b687d695b01c4..747cd1767de7b 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -1222,11 +1222,46 @@ def run(): except OSError: self.fail("daemon had been killed") + # run a normal job + rdd = self.sc.parallelize(range(100), 1) + self.assertEqual(100, rdd.map(str).count()) + def test_fd_leak(self): N = 1100 # fd limit is 1024 by default rdd = self.sc.parallelize(range(N), N) self.assertEquals(N, rdd.count()) + def test_after_exception(self): + def raise_exception(_): + raise Exception() + rdd = self.sc.parallelize(range(100), 1) + self.assertRaises(Exception, lambda: rdd.foreach(raise_exception)) + self.assertEqual(100, rdd.map(str).count()) + + def test_after_jvm_exception(self): + tempFile = tempfile.NamedTemporaryFile(delete=False) + tempFile.write("Hello World!") + tempFile.close() + data = self.sc.textFile(tempFile.name, 1) + filtered_data = data.filter(lambda x: True) + self.assertEqual(1, filtered_data.count()) + os.unlink(tempFile.name) + self.assertRaises(Exception, lambda: filtered_data.count()) + + rdd = self.sc.parallelize(range(100), 1) + self.assertEqual(100, rdd.map(str).count()) + + def test_accumulator_when_reuse_worker(self): + from pyspark.accumulators import INT_ACCUMULATOR_PARAM + acc1 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM) + self.sc.parallelize(range(100), 20).foreach(lambda x: acc1.add(x)) + self.assertEqual(sum(range(100)), acc1.value) + + acc2 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM) + self.sc.parallelize(range(100), 20).foreach(lambda x: acc2.add(x)) + self.assertEqual(sum(range(100)), acc2.value) + self.assertEqual(sum(range(100)), acc1.value) + class TestSparkSubmit(unittest.TestCase): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 6805063e06798..61b8a74d060e8 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -69,9 +69,14 @@ def main(infile, outfile): ser = CompressedSerializer(pickleSer) for _ in range(num_broadcast_variables): bid = read_long(infile) - value = ser._read_with_length(infile) - _broadcastRegistry[bid] = Broadcast(bid, value) + if bid >= 0: + value = ser._read_with_length(infile) + _broadcastRegistry[bid] = Broadcast(bid, value) + else: + bid = - bid - 1 + _broadcastRegistry.remove(bid) + _accumulatorRegistry.clear() command = pickleSer._read_with_length(infile) (func, deserializer, serializer) = command init_time = time.time() From 4e3fbe8cdb6c6291e219195abb272f3c81f0ed63 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sat, 13 Sep 2014 22:31:21 -0700 Subject: [PATCH 22/36] [SPARK-3463] [PySpark] aggregate and show spilled bytes in Python Aggregate the number of bytes spilled into disks during aggregation or sorting, show them in Web UI. ![spilled](https://cloud.githubusercontent.com/assets/40902/4209758/4b995562-386d-11e4-97c1-8e838ee1d4e3.png) This patch is blocked by SPARK-3465. (It includes a fix for that). Author: Davies Liu Closes #2336 from davies/metrics and squashes the following commits: e37df38 [Davies Liu] remove outdated comments 1245eb7 [Davies Liu] remove the temporary fix ebd2f43 [Davies Liu] Merge branch 'master' into metrics 7e4ad04 [Davies Liu] Merge branch 'master' into metrics fbe9029 [Davies Liu] show spilled bytes in Python in web ui --- .../apache/spark/api/python/PythonRDD.scala | 4 ++++ python/pyspark/shuffle.py | 19 ++++++++++++++++--- python/pyspark/tests.py | 15 ++++++++------- python/pyspark/worker.py | 14 ++++++++++---- 4 files changed, 38 insertions(+), 14 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index ca8eef5f99edf..d5002fa02992b 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -124,6 +124,10 @@ private[spark] class PythonRDD( val total = finishTime - startTime logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, init, finish)) + val memoryBytesSpilled = stream.readLong() + val diskBytesSpilled = stream.readLong() + context.taskMetrics.memoryBytesSpilled += memoryBytesSpilled + context.taskMetrics.diskBytesSpilled += diskBytesSpilled read() case SpecialLengths.PYTHON_EXCEPTION_THROWN => // Signals that an exception has been thrown in python diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 49829f5280a5f..ce597cbe91e15 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -68,6 +68,11 @@ def _get_local_dirs(sub): return [os.path.join(d, "python", str(os.getpid()), sub) for d in dirs] +# global stats +MemoryBytesSpilled = 0L +DiskBytesSpilled = 0L + + class Aggregator(object): """ @@ -313,10 +318,12 @@ def _spill(self): It will dump the data in batch for better performance. """ + global MemoryBytesSpilled, DiskBytesSpilled path = self._get_spill_dir(self.spills) if not os.path.exists(path): os.makedirs(path) + used_memory = get_used_memory() if not self.pdata: # The data has not been partitioned, it will iterator the # dataset once, write them into different files, has no @@ -334,6 +341,7 @@ def _spill(self): self.serializer.dump_stream([(k, v)], streams[h]) for s in streams: + DiskBytesSpilled += s.tell() s.close() self.data.clear() @@ -346,9 +354,11 @@ def _spill(self): # dump items in batch self.serializer.dump_stream(self.pdata[i].iteritems(), f) self.pdata[i].clear() + DiskBytesSpilled += os.path.getsize(p) self.spills += 1 gc.collect() # release the memory as much as possible + MemoryBytesSpilled += (used_memory - get_used_memory()) << 20 def iteritems(self): """ Return all merged items as iterator """ @@ -462,7 +472,6 @@ def __init__(self, memory_limit, serializer=None): self.memory_limit = memory_limit self.local_dirs = _get_local_dirs("sort") self.serializer = serializer or BatchedSerializer(PickleSerializer(), 1024) - self._spilled_bytes = 0 def _get_path(self, n): """ Choose one directory for spill by number n """ @@ -476,6 +485,7 @@ def sorted(self, iterator, key=None, reverse=False): Sort the elements in iterator, do external sort when the memory goes above the limit. """ + global MemoryBytesSpilled, DiskBytesSpilled batch = 10 chunks, current_chunk = [], [] iterator = iter(iterator) @@ -486,15 +496,18 @@ def sorted(self, iterator, key=None, reverse=False): if len(chunk) < batch: break - if get_used_memory() > self.memory_limit: + used_memory = get_used_memory() + if used_memory > self.memory_limit: # sort them inplace will save memory current_chunk.sort(key=key, reverse=reverse) path = self._get_path(len(chunks)) with open(path, 'w') as f: self.serializer.dump_stream(current_chunk, f) - self._spilled_bytes += os.path.getsize(path) chunks.append(self.serializer.load_stream(open(path))) current_chunk = [] + gc.collect() + MemoryBytesSpilled += (used_memory - get_used_memory()) << 20 + DiskBytesSpilled += os.path.getsize(path) elif not chunks: batch = min(batch * 2, 10000) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 747cd1767de7b..f3309a20fcffb 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -46,6 +46,7 @@ CloudPickleSerializer from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter from pyspark.sql import SQLContext, IntegerType +from pyspark import shuffle _have_scipy = False _have_numpy = False @@ -138,17 +139,17 @@ def test_external_sort(self): random.shuffle(l) sorter = ExternalSorter(1) self.assertEquals(sorted(l), list(sorter.sorted(l))) - self.assertGreater(sorter._spilled_bytes, 0) - last = sorter._spilled_bytes + self.assertGreater(shuffle.DiskBytesSpilled, 0) + last = shuffle.DiskBytesSpilled self.assertEquals(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True))) - self.assertGreater(sorter._spilled_bytes, last) - last = sorter._spilled_bytes + self.assertGreater(shuffle.DiskBytesSpilled, last) + last = shuffle.DiskBytesSpilled self.assertEquals(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x))) - self.assertGreater(sorter._spilled_bytes, last) - last = sorter._spilled_bytes + self.assertGreater(shuffle.DiskBytesSpilled, last) + last = shuffle.DiskBytesSpilled self.assertEquals(sorted(l, key=lambda x: -x, reverse=True), list(sorter.sorted(l, key=lambda x: -x, reverse=True))) - self.assertGreater(sorter._spilled_bytes, last) + self.assertGreater(shuffle.DiskBytesSpilled, last) def test_external_sort_in_rdd(self): conf = SparkConf().set("spark.python.worker.memory", "1m") diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 61b8a74d060e8..252176ac65fec 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -23,16 +23,14 @@ import time import socket import traceback -# CloudPickler needs to be imported so that depicklers are registered using the -# copy_reg module. + from pyspark.accumulators import _accumulatorRegistry from pyspark.broadcast import Broadcast, _broadcastRegistry -from pyspark.cloudpickle import CloudPickler from pyspark.files import SparkFiles from pyspark.serializers import write_with_length, write_int, read_long, \ write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \ CompressedSerializer - +from pyspark import shuffle pickleSer = PickleSerializer() utf8_deserializer = UTF8Deserializer() @@ -52,6 +50,11 @@ def main(infile, outfile): if split_index == -1: # for unit tests return + # initialize global state + shuffle.MemoryBytesSpilled = 0 + shuffle.DiskBytesSpilled = 0 + _accumulatorRegistry.clear() + # fetch name of workdir spark_files_dir = utf8_deserializer.loads(infile) SparkFiles._root_directory = spark_files_dir @@ -97,6 +100,9 @@ def main(infile, outfile): exit(-1) finish_time = time.time() report_times(outfile, boot_time, init_time, finish_time) + write_long(shuffle.MemoryBytesSpilled, outfile) + write_long(shuffle.DiskBytesSpilled, outfile) + # Mark the beginning of the accumulators section of the output write_int(SpecialLengths.END_OF_DATA_SECTION, outfile) write_int(len(_accumulatorRegistry), outfile) From c243b21a8ba2610266702e00d7d4b5443cb1f687 Mon Sep 17 00:00:00 2001 From: Bertrand Bossy Date: Sun, 14 Sep 2014 21:10:17 -0700 Subject: [PATCH 23/36] SPARK-3039: Allow spark to be built using avro-mapred for hadoop2 SPARK-3039: Adds the maven property "avro.mapred.classifier" to build spark-assembly with avro-mapred with support for the new Hadoop API. Sets this property to hadoop2 for Hadoop 2 profiles. I am not very familiar with maven, nor do I know whether this potentially breaks something in the hive part of spark. There might be a more elegant way of doing this. Author: Bertrand Bossy Closes #1945 from bbossy/SPARK-3039 and squashes the following commits: c32ce59 [Bertrand Bossy] SPARK-3039: Allow spark to be built using avro-mapred for hadoop2 --- pom.xml | 5 +++++ sql/hive/pom.xml | 9 +++++++++ 2 files changed, 14 insertions(+) diff --git a/pom.xml b/pom.xml index 28763476f8313..520aed3806937 100644 --- a/pom.xml +++ b/pom.xml @@ -134,6 +134,7 @@ 0.3.6 3.0.0 1.7.6 + 0.7.1 1.8.3 1.1.0 @@ -621,6 +622,7 @@ org.apache.avro avro-mapred ${avro.version} + ${avro.mapred.classifier} io.netty @@ -1108,6 +1110,7 @@ 2.2.0 2.5.0 + hadoop2 @@ -1117,6 +1120,7 @@ 2.3.0 2.5.0 0.9.0 + hadoop2 @@ -1126,6 +1130,7 @@ 2.4.0 2.5.0 0.9.0 + hadoop2 diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 45a4c6dc98da0..9d7a02bf7b0b7 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -95,6 +95,15 @@ org.apache.avro avro + ${avro.version} + + + + org.apache.avro + avro-mapred + ${avro.version} + ${avro.mapred.classifier} org.scalatest From f493f7982b50e3c99e78b649e7c6c5b4313c5ffa Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Sun, 14 Sep 2014 21:17:29 -0700 Subject: [PATCH 24/36] [SPARK-3452] Maven build should skip publishing artifacts people shouldn... ...'t depend on Publish local in maven term is `install` and publish otherwise is `deploy` So disabled both for following projects. Author: Prashant Sharma Closes #2329 from ScrapCodes/SPARK-3452/maven-skip-install and squashes the following commits: 257b79a [Prashant Sharma] [SPARK-3452] Maven build should skip publishing artifacts people shouldn't depend on --- assembly/pom.xml | 14 ++++++++++++++ examples/pom.xml | 14 ++++++++++++++ extras/java8-tests/pom.xml | 14 ++++++++++++++ repl/pom.xml | 14 ++++++++++++++ tools/pom.xml | 14 ++++++++++++++ yarn/pom.xml | 14 ++++++++++++++ 6 files changed, 84 insertions(+) diff --git a/assembly/pom.xml b/assembly/pom.xml index 4146168fc804b..604b1ab3de6a8 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -88,6 +88,20 @@ + + org.apache.maven.plugins + maven-deploy-plugin + + true + + + + org.apache.maven.plugins + maven-install-plugin + + true + + org.apache.maven.plugins diff --git a/examples/pom.xml b/examples/pom.xml index 3f46c40464d3b..2b561857f9f33 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -203,6 +203,20 @@ target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes + + org.apache.maven.plugins + maven-deploy-plugin + + true + + + + org.apache.maven.plugins + maven-install-plugin + + true + + org.apache.maven.plugins maven-shade-plugin diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml index 8658ecf5abfab..7e478bed62da7 100644 --- a/extras/java8-tests/pom.xml +++ b/extras/java8-tests/pom.xml @@ -74,6 +74,20 @@ + + org.apache.maven.plugins + maven-deploy-plugin + + true + + + + org.apache.maven.plugins + maven-install-plugin + + true + + org.apache.maven.plugins maven-surefire-plugin diff --git a/repl/pom.xml b/repl/pom.xml index fcc5f90d870e8..af528c8914335 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -99,6 +99,20 @@ target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes + + org.apache.maven.plugins + maven-deploy-plugin + + true + + + + org.apache.maven.plugins + maven-install-plugin + + true + + org.scalatest scalatest-maven-plugin diff --git a/tools/pom.xml b/tools/pom.xml index f36674476770c..b90eb0ca250c5 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -63,6 +63,20 @@ target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes + + org.apache.maven.plugins + maven-deploy-plugin + + true + + + + org.apache.maven.plugins + maven-install-plugin + + true + + org.apache.maven.plugins maven-source-plugin diff --git a/yarn/pom.xml b/yarn/pom.xml index 7fcd7ee0d4547..815a736c2e8fd 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -88,6 +88,20 @@ + + org.apache.maven.plugins + maven-deploy-plugin + + true + + + + org.apache.maven.plugins + maven-install-plugin + + true + + org.codehaus.mojo build-helper-maven-plugin From cc14644460872efb344e8d895859d70213a40840 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Mon, 15 Sep 2014 08:53:58 -0500 Subject: [PATCH 25/36] [SPARK-3410] The priority of shutdownhook for ApplicationMaster should not be integer literal I think, it need to keep the priority of shutdown hook for ApplicationMaster than the priority of shutdown hook for o.a.h.FileSystem depending on changing the priority for FileSystem. Author: Kousuke Saruta Closes #2283 from sarutak/SPARK-3410 and squashes the following commits: 1d44fef [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-3410 bd6cc53 [Kousuke Saruta] Modified style ee6f1aa [Kousuke Saruta] Added constant "SHUTDOWN_HOOK_PRIORITY" to ApplicationMaster 54eb68f [Kousuke Saruta] Changed Shutdown hook priority to 20 2f0aee3 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-3410 4c5cb93 [Kousuke Saruta] Modified the priority for AM's shutdown hook 217d1a4 [Kousuke Saruta] Removed unused import statements 717aba2 [Kousuke Saruta] Modified ApplicationMaster to make to keep the priority of shutdown hook for ApplicationMaster higher than the priority of shutdown hook for HDFS --- .../spark/deploy/yarn/ApplicationMaster.scala | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 735d7723b0ce6..cde5fff637a39 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -21,12 +21,8 @@ import java.io.IOException import java.net.Socket import java.util.concurrent.atomic.AtomicReference -import scala.collection.JavaConversions._ -import scala.util.Try - import akka.actor._ import akka.remote._ -import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.util.ShutdownHookManager import org.apache.hadoop.yarn.api._ @@ -107,8 +103,11 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, } } } - // Use priority 30 as it's higher than HDFS. It's the same priority MapReduce is using. - ShutdownHookManager.get().addShutdownHook(cleanupHook, 30) + + // Use higher priority than FileSystem. + assert(ApplicationMaster.SHUTDOWN_HOOK_PRIORITY > FileSystem.SHUTDOWN_HOOK_PRIORITY) + ShutdownHookManager + .get().addShutdownHook(cleanupHook, ApplicationMaster.SHUTDOWN_HOOK_PRIORITY) // Call this to force generation of secret so it gets populated into the // Hadoop UGI. This has to happen before the startUserClass which does a @@ -407,6 +406,8 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, object ApplicationMaster extends Logging { + val SHUTDOWN_HOOK_PRIORITY: Int = 30 + private var master: ApplicationMaster = _ def main(args: Array[String]) = { From fe2b1d6a209db9fe96b1c6630677955b94bd48c9 Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Mon, 15 Sep 2014 10:57:53 -0700 Subject: [PATCH 26/36] [SPARK-3425] do not set MaxPermSize for OpenJDK 1.8 Closes #2387 Author: Matthew Farrellee Closes #2301 from mattf/SPARK-3425 and squashes the following commits: 20f3c09 [Matthew Farrellee] [SPARK-3425] do not set MaxPermSize for OpenJDK 1.8 --- bin/spark-class | 2 +- dev/merge_spark_pr.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/bin/spark-class b/bin/spark-class index 5f5f9ea74888d..613dc9c4566f2 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -105,7 +105,7 @@ else exit 1 fi fi -JAVA_VERSION=$("$RUNNER" -version 2>&1 | sed 's/java version "\(.*\)\.\(.*\)\..*"/\1\2/; 1q') +JAVA_VERSION=$("$RUNNER" -version 2>&1 | sed 's/.* version "\(.*\)\.\(.*\)\..*"/\1\2/; 1q') # Set JAVA_OPTS to be able to load native libraries and to set heap size if [ "$JAVA_VERSION" -ge 18 ]; then diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index d48c8bde12905..a8e92e36fe0d8 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -44,9 +44,9 @@ # Remote name which points to Apache git PUSH_REMOTE_NAME = os.environ.get("PUSH_REMOTE_NAME", "apache") # ASF JIRA username -JIRA_USERNAME = os.environ.get("JIRA_USERNAME", "") +JIRA_USERNAME = os.environ.get("JIRA_USERNAME", "pwendell") # ASF JIRA password -JIRA_PASSWORD = os.environ.get("JIRA_PASSWORD", "") +JIRA_PASSWORD = os.environ.get("JIRA_PASSWORD", "35500") GITHUB_BASE = "https://github.com/apache/spark/pull" GITHUB_API_BASE = "https://api.github.com/repos/apache/spark" From e59fac1f97c3fbeeb6defd12625a49763a353156 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Mon, 15 Sep 2014 16:11:41 -0700 Subject: [PATCH 27/36] [SPARK-3518] Remove wasted statement in JsonProtocol Author: Kousuke Saruta Closes #2380 from sarutak/SPARK-3518 and squashes the following commits: 8a1464e [Kousuke Saruta] Replaced a variable with simple field reference c660fbc [Kousuke Saruta] Removed useless statement in JsonProtocol.scala --- core/src/main/scala/org/apache/spark/util/JsonProtocol.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index b0754e3ce10db..c4dddb2d1037e 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -205,7 +205,6 @@ private[spark] object JsonProtocol { } def taskInfoToJson(taskInfo: TaskInfo): JValue = { - val accumUpdateMap = taskInfo.accumulables ("Task ID" -> taskInfo.taskId) ~ ("Index" -> taskInfo.index) ~ ("Attempt" -> taskInfo.attempt) ~ From 37d925280cdfdda8f6f7174c67a614056eea5d69 Mon Sep 17 00:00:00 2001 From: yantangzhai Date: Mon, 15 Sep 2014 16:57:38 -0700 Subject: [PATCH 28/36] [SPARK-2714] DAGScheduler logs jobid when runJob finishes DAGScheduler logs jobid when runJob finishes Author: yantangzhai Closes #1617 from YanTangZhai/SPARK-2714 and squashes the following commits: 0a0243f [yantangzhai] [SPARK-2714] DAGScheduler logs jobid when runJob finishes fbb1150 [yantangzhai] [SPARK-2714] DAGScheduler logs jobid when runJob finishes 7aec2a9 [yantangzhai] [SPARK-2714] DAGScheduler logs jobid when runJob finishes fb42f0f [yantangzhai] [SPARK-2714] DAGScheduler logs jobid when runJob finishes 090d908 [yantangzhai] [SPARK-2714] DAGScheduler logs jobid when runJob finishes --- core/src/main/scala/org/apache/spark/SparkContext.scala | 3 --- .../scala/org/apache/spark/scheduler/DAGScheduler.scala | 9 +++++++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 218b353dd9d49..428f019b02a23 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1072,11 +1072,8 @@ class SparkContext(config: SparkConf) extends Logging { val callSite = getCallSite val cleanedFunc = clean(func) logInfo("Starting job: " + callSite.shortForm) - val start = System.nanoTime dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal, resultHandler, localProperties.get) - logInfo( - "Job finished: " + callSite.shortForm + ", took " + (System.nanoTime - start) / 1e9 + " s") rdd.doCheckpoint() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 6fcf9e31543ed..b2774dfc47553 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -507,11 +507,16 @@ class DAGScheduler( resultHandler: (Int, U) => Unit, properties: Properties = null) { + val start = System.nanoTime val waiter = submitJob(rdd, func, partitions, callSite, allowLocal, resultHandler, properties) waiter.awaitResult() match { - case JobSucceeded => {} + case JobSucceeded => { + logInfo("Job %d finished: %s, took %f s".format + (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9)) + } case JobFailed(exception: Exception) => - logInfo("Failed to run " + callSite.shortForm) + logInfo("Job %d failed: %s, took %f s".format + (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9)) throw exception } } From 3b93128139e8d303f1d7bfd04e9a99a11a5b6404 Mon Sep 17 00:00:00 2001 From: Christoph Sawade Date: Mon, 15 Sep 2014 17:39:31 -0700 Subject: [PATCH 29/36] [SPARK-3396][MLLIB] Use SquaredL2Updater in LogisticRegressionWithSGD SimpleUpdater ignores the regularizer, which leads to an unregularized LogReg. To enable the common L2 regularizer (and the corresponding regularization parameter) for logistic regression the SquaredL2Updater has to be used in SGD (see, e.g., [SVMWithSGD]) Author: Christoph Sawade Closes #2398 from BigCrunsh/fix-regparam-logreg and squashes the following commits: 0820c04 [Christoph Sawade] Use SquaredL2Updater in LogisticRegressionWithSGD --- .../classification/LogisticRegression.scala | 2 +- .../LogisticRegressionSuite.scala | 44 +++++++++++++++++-- 2 files changed, 42 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 486bdbfa9cb47..84d3c7cebd7c8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -84,7 +84,7 @@ class LogisticRegressionWithSGD private ( extends GeneralizedLinearAlgorithm[LogisticRegressionModel] with Serializable { private val gradient = new LogisticGradient() - private val updater = new SimpleUpdater() + private val updater = new SquaredL2Updater() override val optimizer = new GradientDescent(gradient, updater) .setStepSize(stepSize) .setNumIterations(numIterations) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index 862178694a50e..e954baaf7d91e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -43,7 +43,7 @@ object LogisticRegressionSuite { offset: Double, scale: Double, nPoints: Int, - seed: Int): Seq[LabeledPoint] = { + seed: Int): Seq[LabeledPoint] = { val rnd = new Random(seed) val x1 = Array.fill[Double](nPoints)(rnd.nextGaussian()) @@ -58,12 +58,15 @@ object LogisticRegressionSuite { } class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Matchers { - def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { + def validatePrediction( + predictions: Seq[Double], + input: Seq[LabeledPoint], + expectedAcc: Double = 0.83) { val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => prediction != expected.label } // At least 83% of the predictions should be on. - ((input.length - numOffPredictions).toDouble / input.length) should be > 0.83 + ((input.length - numOffPredictions).toDouble / input.length) should be > expectedAcc } // Test if we can correctly learn A, B where Y = logistic(A + B*X) @@ -155,6 +158,41 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match validatePrediction(validationData.map(row => model.predict(row.features)), validationData) } + test("logistic regression with initial weights and non-default regularization parameter") { + val nPoints = 10000 + val A = 2.0 + val B = -1.5 + + val testData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 42) + + val initialB = -1.0 + val initialWeights = Vectors.dense(initialB) + + val testRDD = sc.parallelize(testData, 2) + testRDD.cache() + + // Use half as many iterations as the previous test. + val lr = new LogisticRegressionWithSGD().setIntercept(true) + lr.optimizer. + setStepSize(10.0). + setNumIterations(10). + setRegParam(1.0) + + val model = lr.run(testRDD, initialWeights) + + // Test the weights + assert(model.weights(0) ~== -430000.0 relTol 20000.0) + assert(model.intercept ~== 370000.0 relTol 20000.0) + + val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) + val validationRDD = sc.parallelize(validationData, 2) + // Test prediction on RDD. + validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData, 0.8) + + // Test prediction on Array. + validatePrediction(validationData.map(row => model.predict(row.features)), validationData, 0.8) + } + test("logistic regression with initial weights with LBFGS") { val nPoints = 10000 val A = 2.0 From 983d6a9c48b69c5f0542922aa8b133f69eb1034d Mon Sep 17 00:00:00 2001 From: Reza Zadeh Date: Mon, 15 Sep 2014 17:41:15 -0700 Subject: [PATCH 30/36] [MLlib] Update SVD documentation in IndexedRowMatrix Updating this to reflect the newest SVD via ARPACK Author: Reza Zadeh Closes #2389 from rezazadeh/irmdocs and squashes the following commits: 7fa1313 [Reza Zadeh] Update svd docs 715da25 [Reza Zadeh] Updated computeSVD documentation IndexedRowMatrix --- .../mllib/linalg/distributed/IndexedRowMatrix.scala | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala index ac6eaea3f43ad..5c1acca0ec532 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala @@ -76,16 +76,12 @@ class IndexedRowMatrix( } /** - * Computes the singular value decomposition of this matrix. + * Computes the singular value decomposition of this IndexedRowMatrix. * Denote this matrix by A (m x n), this will compute matrices U, S, V such that A = U * S * V'. * - * There is no restriction on m, but we require `n^2` doubles to fit in memory. - * Further, n should be less than m. - - * The decomposition is computed by first computing A'A = V S^2 V', - * computing svd locally on that (since n x n is small), from which we recover S and V. - * Then we compute U via easy matrix multiplication as U = A * (V * S^-1). - * Note that this approach requires `O(n^3)` time on the master node. + * The cost and implementation of this method is identical to that in + * [[org.apache.spark.mllib.linalg.distributed.RowMatrix]] + * With the addition of indices. * * At most k largest non-zero singular values and associated vectors are returned. * If there are k such values, then the dimensions of the return will be: From fdb302f49c021227026909bdcdade7496059013f Mon Sep 17 00:00:00 2001 From: "qiping.lqp" Date: Mon, 15 Sep 2014 17:43:26 -0700 Subject: [PATCH 31/36] [SPARK-3516] [mllib] DecisionTree: Add minInstancesPerNode, minInfoGain params to example and Python API Added minInstancesPerNode, minInfoGain params to: * DecisionTreeRunner.scala example * Python API (tree.py) Also: * Fixed typo in tree suite test "do not choose split that does not satisfy min instance per node requirements" * small style fixes CC: mengxr Author: qiping.lqp Author: Joseph K. Bradley Author: chouqin Closes #2349 from jkbradley/chouqin-dt-preprune and squashes the following commits: 61b2e72 [Joseph K. Bradley] Added max of 10GB for maxMemoryInMB in Strategy. a95e7c8 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into chouqin-dt-preprune 95c479d [Joseph K. Bradley] * Fixed typo in tree suite test "do not choose split that does not satisfy min instance per node requirements" * small style fixes e2628b6 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into chouqin-dt-preprune 19b01af [Joseph K. Bradley] Merge remote-tracking branch 'chouqin/dt-preprune' into chouqin-dt-preprune f1d11d1 [chouqin] fix typo c7ebaf1 [chouqin] fix typo 39f9b60 [chouqin] change edge `minInstancesPerNode` to 2 and add one more test c6e2dfc [Joseph K. Bradley] Added minInstancesPerNode and minInfoGain parameters to DecisionTreeRunner.scala and to Python API in tree.py 0278a11 [chouqin] remove `noSplit` and set `Predict` private to tree d593ec7 [chouqin] fix docs and change minInstancesPerNode to 1 efcc736 [qiping.lqp] fix bug 10b8012 [qiping.lqp] fix style 6728fad [qiping.lqp] minor fix: remove empty lines bb465ca [qiping.lqp] Merge branch 'master' of https://github.com/apache/spark into dt-preprune cadd569 [qiping.lqp] add api docs 46b891f [qiping.lqp] fix bug e72c7e4 [qiping.lqp] add comments 845c6fa [qiping.lqp] fix style f195e83 [qiping.lqp] fix style 987cbf4 [qiping.lqp] fix bug ff34845 [qiping.lqp] separate calculation of predict of node from calculation of info gain ac42378 [qiping.lqp] add min info gain and min instances per node parameters in decision tree --- .../examples/mllib/DecisionTreeRunner.scala | 13 ++++++++++++- .../spark/mllib/api/python/PythonMLLibAPI.scala | 8 ++++++-- .../apache/spark/mllib/tree/DecisionTree.scala | 4 ++-- .../mllib/tree/configuration/Strategy.scala | 2 ++ .../apache/spark/mllib/tree/model/Predict.scala | 6 +----- .../spark/mllib/tree/DecisionTreeSuite.scala | 4 ++-- python/pyspark/mllib/tree.py | 16 ++++++++++++---- 7 files changed, 37 insertions(+), 16 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 72c3ab475b61f..4683e6eb966be 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -55,6 +55,8 @@ object DecisionTreeRunner { maxDepth: Int = 5, impurity: ImpurityType = Gini, maxBins: Int = 32, + minInstancesPerNode: Int = 1, + minInfoGain: Double = 0.0, fracTest: Double = 0.2) def main(args: Array[String]) { @@ -75,6 +77,13 @@ object DecisionTreeRunner { opt[Int]("maxBins") .text(s"max number of bins, default: ${defaultParams.maxBins}") .action((x, c) => c.copy(maxBins = x)) + opt[Int]("minInstancesPerNode") + .text(s"min number of instances required at child nodes to create the parent split," + + s" default: ${defaultParams.minInstancesPerNode}") + .action((x, c) => c.copy(minInstancesPerNode = x)) + opt[Double]("minInfoGain") + .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}") + .action((x, c) => c.copy(minInfoGain = x)) opt[Double]("fracTest") .text(s"fraction of data to hold out for testing, default: ${defaultParams.fracTest}") .action((x, c) => c.copy(fracTest = x)) @@ -179,7 +188,9 @@ object DecisionTreeRunner { impurity = impurityCalculator, maxDepth = params.maxDepth, maxBins = params.maxBins, - numClassesForClassification = numClasses) + numClassesForClassification = numClasses, + minInstancesPerNode = params.minInstancesPerNode, + minInfoGain = params.minInfoGain) val model = DecisionTree.train(training, strategy) println(model) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 4343124f102a0..fa0fa69f38634 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -303,7 +303,9 @@ class PythonMLLibAPI extends Serializable { categoricalFeaturesInfoJMap: java.util.Map[Int, Int], impurityStr: String, maxDepth: Int, - maxBins: Int): DecisionTreeModel = { + maxBins: Int, + minInstancesPerNode: Int, + minInfoGain: Double): DecisionTreeModel = { val data = dataBytesJRDD.rdd.map(SerDe.deserializeLabeledPoint) @@ -316,7 +318,9 @@ class PythonMLLibAPI extends Serializable { maxDepth = maxDepth, numClassesForClassification = numClasses, maxBins = maxBins, - categoricalFeaturesInfo = categoricalFeaturesInfoJMap.asScala.toMap) + categoricalFeaturesInfo = categoricalFeaturesInfoJMap.asScala.toMap, + minInstancesPerNode = minInstancesPerNode, + minInfoGain = minInfoGain) DecisionTree.train(data, strategy) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 56bb8812100a7..c7f2576c822b1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -389,7 +389,7 @@ object DecisionTree extends Serializable with Logging { var groupIndex = 0 var doneTraining = true while (groupIndex < numGroups) { - val (tmpRoot, doneTrainingGroup) = findBestSplitsPerGroup(input, metadata, level, + val (_, doneTrainingGroup) = findBestSplitsPerGroup(input, metadata, level, topNode, splits, bins, timer, numGroups, groupIndex) doneTraining = doneTraining && doneTrainingGroup groupIndex += 1 @@ -898,7 +898,7 @@ object DecisionTree extends Serializable with Logging { } }.maxBy(_._2.gain) - require(predict.isDefined, "must calculate predict for each node") + assert(predict.isDefined, "must calculate predict for each node") (bestSplit, bestSplitStats, predict.get) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 31d1e8ac30eea..caaccbfb8ad16 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -77,6 +77,8 @@ class Strategy ( } require(minInstancesPerNode >= 1, s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode") + require(maxMemoryInMB <= 10240, + s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB") val isMulticlassClassification = algo == Classification && numClassesForClassification > 2 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala index 6fac2be2797bc..d8476b5cd7bc7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala @@ -17,18 +17,14 @@ package org.apache.spark.mllib.tree.model -import org.apache.spark.annotation.DeveloperApi - /** - * :: DeveloperApi :: * Predicted value for a node * @param predict predicted value * @param prob probability of the label (classification only) */ -@DeveloperApi private[tree] class Predict( val predict: Double, - val prob: Double = 0.0) extends Serializable{ + val prob: Double = 0.0) extends Serializable { override def toString = { "predict = %f, prob = %f".format(predict, prob) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 1bd7ea05c46c8..2b2e579b992f6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -714,8 +714,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(gain == InformationGainStats.invalidInformationGainStats) } - test("don't choose split that doesn't satisfy min instance per node requirements") { - // if a split doesn't satisfy min instances per node requirements, + test("do not choose split that does not satisfy min instance per node requirements") { + // if a split does not satisfy min instances per node requirements, // this split is invalid, even though the information gain of split is large. val arr = new Array[LabeledPoint](4) arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0, 1.0)) diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index ccc000ac70ba6..5b13ab682bbfc 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -138,7 +138,8 @@ class DecisionTree(object): @staticmethod def trainClassifier(data, numClasses, categoricalFeaturesInfo, - impurity="gini", maxDepth=5, maxBins=32): + impurity="gini", maxDepth=5, maxBins=32, minInstancesPerNode=1, + minInfoGain=0.0): """ Train a DecisionTreeModel for classification. @@ -154,6 +155,9 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo, E.g., depth 0 means 1 leaf node. Depth 1 means 1 internal node + 2 leaf nodes. :param maxBins: Number of bins used for finding splits at each node. + :param minInstancesPerNode: Min number of instances required at child nodes to create + the parent split + :param minInfoGain: Min info gain required to create a split :return: DecisionTreeModel """ sc = data.context @@ -164,13 +168,14 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo, model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel( dataBytes._jrdd, "classification", numClasses, categoricalFeaturesInfoJMap, - impurity, maxDepth, maxBins) + impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) dataBytes.unpersist() return DecisionTreeModel(sc, model) @staticmethod def trainRegressor(data, categoricalFeaturesInfo, - impurity="variance", maxDepth=5, maxBins=32): + impurity="variance", maxDepth=5, maxBins=32, minInstancesPerNode=1, + minInfoGain=0.0): """ Train a DecisionTreeModel for regression. @@ -185,6 +190,9 @@ def trainRegressor(data, categoricalFeaturesInfo, E.g., depth 0 means 1 leaf node. Depth 1 means 1 internal node + 2 leaf nodes. :param maxBins: Number of bins used for finding splits at each node. + :param minInstancesPerNode: Min number of instances required at child nodes to create + the parent split + :param minInfoGain: Min info gain required to create a split :return: DecisionTreeModel """ sc = data.context @@ -195,7 +203,7 @@ def trainRegressor(data, categoricalFeaturesInfo, model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel( dataBytes._jrdd, "regression", 0, categoricalFeaturesInfoJMap, - impurity, maxDepth, maxBins) + impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) dataBytes.unpersist() return DecisionTreeModel(sc, model) From da33acb8b681eca5e787d546fe922af76a151398 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 15 Sep 2014 18:57:25 -0700 Subject: [PATCH 32/36] [SPARK-2951] [PySpark] support unpickle array.array for Python 2.6 Pyrolite can not unpickle array.array which pickled by Python 2.6, this patch fix it by extend Pyrolite. There is a bug in Pyrolite when unpickle array of float/double, this patch workaround it by reverse the endianness for float/double. This workaround should be removed after Pyrolite have a new release to fix this issue. I had send an PR to Pyrolite to fix it: https://github.com/irmen/Pyrolite/pull/11 Author: Davies Liu Closes #2365 from davies/pickle and squashes the following commits: f44f771 [Davies Liu] enable tests about array 3908f5c [Davies Liu] Merge branch 'master' into pickle c77c87b [Davies Liu] cleanup debugging code 60e4e2f [Davies Liu] support unpickle array.array for Python 2.6 --- .../apache/spark/api/python/SerDeUtil.scala | 51 +++++++++++++++++++ python/pyspark/context.py | 1 + python/pyspark/tests.py | 2 - 3 files changed, 52 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala index efc9009c088a8..6668797f5f8be 100644 --- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala @@ -17,6 +17,8 @@ package org.apache.spark.api.python +import java.nio.ByteOrder + import scala.collection.JavaConversions._ import scala.util.Failure import scala.util.Try @@ -28,6 +30,55 @@ import org.apache.spark.rdd.RDD /** Utilities for serialization / deserialization between Python and Java, using Pickle. */ private[python] object SerDeUtil extends Logging { + // Unpickle array.array generated by Python 2.6 + class ArrayConstructor extends net.razorvine.pickle.objects.ArrayConstructor { + // /* Description of types */ + // static struct arraydescr descriptors[] = { + // {'c', sizeof(char), c_getitem, c_setitem}, + // {'b', sizeof(char), b_getitem, b_setitem}, + // {'B', sizeof(char), BB_getitem, BB_setitem}, + // #ifdef Py_USING_UNICODE + // {'u', sizeof(Py_UNICODE), u_getitem, u_setitem}, + // #endif + // {'h', sizeof(short), h_getitem, h_setitem}, + // {'H', sizeof(short), HH_getitem, HH_setitem}, + // {'i', sizeof(int), i_getitem, i_setitem}, + // {'I', sizeof(int), II_getitem, II_setitem}, + // {'l', sizeof(long), l_getitem, l_setitem}, + // {'L', sizeof(long), LL_getitem, LL_setitem}, + // {'f', sizeof(float), f_getitem, f_setitem}, + // {'d', sizeof(double), d_getitem, d_setitem}, + // {'\0', 0, 0, 0} /* Sentinel */ + // }; + // TODO: support Py_UNICODE with 2 bytes + // FIXME: unpickle array of float is wrong in Pyrolite, so we reverse the + // machine code for float/double here to workaround it. + // we should fix this after Pyrolite fix them + val machineCodes: Map[Char, Int] = if (ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN)) { + Map('c' -> 1, 'B' -> 0, 'b' -> 1, 'H' -> 3, 'h' -> 5, 'I' -> 7, 'i' -> 9, + 'L' -> 11, 'l' -> 13, 'f' -> 14, 'd' -> 16, 'u' -> 21 + ) + } else { + Map('c' -> 1, 'B' -> 0, 'b' -> 1, 'H' -> 2, 'h' -> 4, 'I' -> 6, 'i' -> 8, + 'L' -> 10, 'l' -> 12, 'f' -> 15, 'd' -> 17, 'u' -> 20 + ) + } + override def construct(args: Array[Object]): Object = { + if (args.length == 1) { + construct(args ++ Array("")) + } else if (args.length == 2 && args(1).isInstanceOf[String]) { + val typecode = args(0).asInstanceOf[String].charAt(0) + val data: String = args(1).asInstanceOf[String] + construct(typecode, machineCodes(typecode), data.getBytes("ISO-8859-1")) + } else { + super.construct(args) + } + } + } + + def initialize() = { + Unpickler.registerConstructor("array", "array", new ArrayConstructor()) + } private def checkPickle(t: (Any, Any)): (Boolean, Boolean) = { val pickle = new Pickler diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 3ab98e262df31..ea28e8cd8c89f 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -214,6 +214,7 @@ def _ensure_initialized(cls, instance=None, gateway=None): SparkContext._gateway = gateway or launch_gateway() SparkContext._jvm = SparkContext._gateway.jvm SparkContext._writeToFile = SparkContext._jvm.PythonRDD.writeToFile + SparkContext._jvm.SerDeUtil.initialize() if instance: if (SparkContext._active_spark_context and diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index f3309a20fcffb..f255b44359fec 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -956,8 +956,6 @@ def test_newhadoop(self): conf=input_conf).collect()) self.assertEqual(new_dataset, data) - @unittest.skipIf(sys.version_info[:2] <= (2, 6) or python_implementation() == "PyPy", - "Skipped on 2.6 and PyPy until SPARK-2951 is fixed") def test_newhadoop_with_array(self): basepath = self.tempdir.name # use custom ArrayWritable types and converters to handle arrays From 60050f42885582a699fc7a6fa0529964162bb8a3 Mon Sep 17 00:00:00 2001 From: Aaron Staple Date: Mon, 15 Sep 2014 19:28:17 -0700 Subject: [PATCH 33/36] [SPARK-1087] Move python traceback utilities into new traceback_utils.py file. Also made some cosmetic cleanups. Author: Aaron Staple Closes #2385 from staple/SPARK-1087 and squashes the following commits: 7b3bb13 [Aaron Staple] Address review comments, cosmetic cleanups. 10ba6e1 [Aaron Staple] [SPARK-1087] Move python traceback utilities into new traceback_utils.py file. --- python/pyspark/context.py | 8 +--- python/pyspark/rdd.py | 58 ++--------------------- python/pyspark/traceback_utils.py | 78 +++++++++++++++++++++++++++++++ 3 files changed, 83 insertions(+), 61 deletions(-) create mode 100644 python/pyspark/traceback_utils.py diff --git a/python/pyspark/context.py b/python/pyspark/context.py index ea28e8cd8c89f..a33aae87f65e8 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -20,7 +20,6 @@ import sys from threading import Lock from tempfile import NamedTemporaryFile -from collections import namedtuple from pyspark import accumulators from pyspark.accumulators import Accumulator @@ -33,6 +32,7 @@ from pyspark.storagelevel import StorageLevel from pyspark import rdd from pyspark.rdd import RDD +from pyspark.traceback_utils import CallSite, first_spark_call from py4j.java_collections import ListConverter @@ -99,11 +99,7 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, ... ValueError:... """ - if rdd._extract_concise_traceback() is not None: - self._callsite = rdd._extract_concise_traceback() - else: - tempNamedTuple = namedtuple("Callsite", "function file linenum") - self._callsite = tempNamedTuple(function=None, file=None, linenum=None) + self._callsite = first_spark_call() or CallSite(None, None, None) SparkContext._ensure_initialized(self, gateway=gateway) try: self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer, diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 6ad5ab2a2d1ae..21f182b0ff137 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -18,13 +18,11 @@ from base64 import standard_b64encode as b64enc import copy from collections import defaultdict -from collections import namedtuple from itertools import chain, ifilter, imap import operator import os import sys import shlex -import traceback from subprocess import Popen, PIPE from tempfile import NamedTemporaryFile from threading import Thread @@ -45,6 +43,7 @@ from pyspark.resultiterable import ResultIterable from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \ get_used_memory, ExternalSorter +from pyspark.traceback_utils import SCCallSiteSync from py4j.java_collections import ListConverter, MapConverter @@ -81,57 +80,6 @@ def portable_hash(x): return hash(x) -def _extract_concise_traceback(): - """ - This function returns the traceback info for a callsite, returns a dict - with function name, file name and line number - """ - tb = traceback.extract_stack() - callsite = namedtuple("Callsite", "function file linenum") - if len(tb) == 0: - return None - file, line, module, what = tb[len(tb) - 1] - sparkpath = os.path.dirname(file) - first_spark_frame = len(tb) - 1 - for i in range(0, len(tb)): - file, line, fun, what = tb[i] - if file.startswith(sparkpath): - first_spark_frame = i - break - if first_spark_frame == 0: - file, line, fun, what = tb[0] - return callsite(function=fun, file=file, linenum=line) - sfile, sline, sfun, swhat = tb[first_spark_frame] - ufile, uline, ufun, uwhat = tb[first_spark_frame - 1] - return callsite(function=sfun, file=ufile, linenum=uline) - -_spark_stack_depth = 0 - - -class _JavaStackTrace(object): - - def __init__(self, sc): - tb = _extract_concise_traceback() - if tb is not None: - self._traceback = "%s at %s:%s" % ( - tb.function, tb.file, tb.linenum) - else: - self._traceback = "Error! Could not extract traceback info" - self._context = sc - - def __enter__(self): - global _spark_stack_depth - if _spark_stack_depth == 0: - self._context._jsc.setCallSite(self._traceback) - _spark_stack_depth += 1 - - def __exit__(self, type, value, tb): - global _spark_stack_depth - _spark_stack_depth -= 1 - if _spark_stack_depth == 0: - self._context._jsc.setCallSite(None) - - class BoundedFloat(float): """ Bounded value is generated by approximate job, with confidence and low @@ -704,7 +652,7 @@ def collect(self): """ Return a list that contains all of the elements in this RDD. """ - with _JavaStackTrace(self.context) as st: + with SCCallSiteSync(self.context) as css: bytesInJava = self._jrdd.collect().iterator() return list(self._collect_iterator_through_file(bytesInJava)) @@ -1515,7 +1463,7 @@ def add_shuffle_key(split, iterator): keyed = self.mapPartitionsWithIndex(add_shuffle_key) keyed._bypass_serializer = True - with _JavaStackTrace(self.context) as st: + with SCCallSiteSync(self.context) as css: pairRDD = self.ctx._jvm.PairwiseRDD( keyed._jrdd.rdd()).asJavaPairRDD() partitioner = self.ctx._jvm.PythonPartitioner(numPartitions, diff --git a/python/pyspark/traceback_utils.py b/python/pyspark/traceback_utils.py new file mode 100644 index 0000000000000..bb8646df2b0bf --- /dev/null +++ b/python/pyspark/traceback_utils.py @@ -0,0 +1,78 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import namedtuple +import os +import traceback + + +CallSite = namedtuple("CallSite", "function file linenum") + + +def first_spark_call(): + """ + Return a CallSite representing the first Spark call in the current call stack. + """ + tb = traceback.extract_stack() + if len(tb) == 0: + return None + file, line, module, what = tb[len(tb) - 1] + sparkpath = os.path.dirname(file) + first_spark_frame = len(tb) - 1 + for i in range(0, len(tb)): + file, line, fun, what = tb[i] + if file.startswith(sparkpath): + first_spark_frame = i + break + if first_spark_frame == 0: + file, line, fun, what = tb[0] + return CallSite(function=fun, file=file, linenum=line) + sfile, sline, sfun, swhat = tb[first_spark_frame] + ufile, uline, ufun, uwhat = tb[first_spark_frame - 1] + return CallSite(function=sfun, file=ufile, linenum=uline) + + +class SCCallSiteSync(object): + """ + Helper for setting the spark context call site. + + Example usage: + from pyspark.context import SCCallSiteSync + with SCCallSiteSync() as css: + + """ + + _spark_stack_depth = 0 + + def __init__(self, sc): + call_site = first_spark_call() + if call_site is not None: + self._call_site = "%s at %s:%s" % ( + call_site.function, call_site.file, call_site.linenum) + else: + self._call_site = "Error! Could not extract traceback info" + self._context = sc + + def __enter__(self): + if SCCallSiteSync._spark_stack_depth == 0: + self._context._jsc.setCallSite(self._call_site) + SCCallSiteSync._spark_stack_depth += 1 + + def __exit__(self, type, value, tb): + SCCallSiteSync._spark_stack_depth -= 1 + if SCCallSiteSync._spark_stack_depth == 0: + self._context._jsc.setCallSite(None) From d428ac6a221d2dce19c43442abf197f2ade6658f Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 15 Sep 2014 21:09:58 -0700 Subject: [PATCH 34/36] [SPARK-3540] Add reboot-slaves functionality to the ec2 script Tested on a real cluster. Author: Reynold Xin Closes #2404 from rxin/ec2-reboot-slaves and squashes the following commits: 00a2dbd [Reynold Xin] Allow rebooting slaves. --- ec2/spark_ec2.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index bfd07593b92ed..5682e96aa8770 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -52,7 +52,7 @@ class UsageError(Exception): def parse_args(): parser = OptionParser( usage="spark-ec2 [options] " - + "\n\n can be: launch, destroy, login, stop, start, get-master", + + "\n\n can be: launch, destroy, login, stop, start, get-master, reboot-slaves", add_help_option=False) parser.add_option( "-h", "--help", action="help", @@ -950,6 +950,20 @@ def real_main(): subprocess.check_call( ssh_command(opts) + proxy_opt + ['-t', '-t', "%s@%s" % (opts.user, master)]) + elif action == "reboot-slaves": + response = raw_input( + "Are you sure you want to reboot the cluster " + + cluster_name + " slaves?\n" + + "Reboot cluster slaves " + cluster_name + " (y/N): ") + if response == "y": + (master_nodes, slave_nodes) = get_existing_cluster( + conn, opts, cluster_name, die_on_error=False) + print "Rebooting slaves..." + for inst in slave_nodes: + if inst.state not in ["shutting-down", "terminated"]: + print "Rebooting " + inst.id + inst.reboot() + elif action == "get-master": (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) print master_nodes[0].public_dns_name From ecf0c02935815f0d4018c0e30ec4c784e60a5db0 Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Mon, 15 Sep 2014 21:14:00 -0700 Subject: [PATCH 35/36] [SPARK-3433][BUILD] Fix for Mima false-positives with @DeveloperAPI and @Experimental annotations. Actually false positive reported was due to mima generator not picking up the new jars in presence of old jars(theoretically this should not have happened.). So as a workaround, ran them both separately and just append them together. Author: Prashant Sharma Author: Prashant Sharma Closes #2285 from ScrapCodes/mima-fix and squashes the following commits: 093c76f [Prashant Sharma] Update mima 59012a8 [Prashant Sharma] Update mima 35b6c71 [Prashant Sharma] SPARK-3433 Fix for Mima false-positives with @DeveloperAPI and @Experimental annotations. --- dev/mima | 8 ++++++++ project/MimaBuild.scala | 6 ++++++ project/MimaExcludes.scala | 8 +------- project/SparkBuild.scala | 2 +- .../apache/spark/tools/GenerateMIMAIgnore.scala | 14 ++++++++++---- 5 files changed, 26 insertions(+), 12 deletions(-) diff --git a/dev/mima b/dev/mima index f9b9b03538f15..40603166c21ae 100755 --- a/dev/mima +++ b/dev/mima @@ -25,11 +25,19 @@ FWDIR="$(cd "`dirname "$0"`"/..; pwd)" cd "$FWDIR" echo -e "q\n" | sbt/sbt oldDeps/update +rm -f .generated-mima* + +# Generate Mima Ignore is called twice, first with latest built jars +# on the classpath and then again with previous version jars on the classpath. +# Because of a bug in GenerateMIMAIgnore that when old jars are ahead on classpath +# it did not process the new classes (which are in assembly jar). +./bin/spark-class org.apache.spark.tools.GenerateMIMAIgnore export SPARK_CLASSPATH="`find lib_managed \( -name '*spark*jar' -a -type f \) | tr "\\n" ":"`" echo "SPARK_CLASSPATH=$SPARK_CLASSPATH" ./bin/spark-class org.apache.spark.tools.GenerateMIMAIgnore + echo -e "q\n" | sbt/sbt mima-report-binary-issues | grep -v -e "info.*Resolving" ret_val=$? diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index 0f5d71afcf616..39f8ba4745737 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -30,6 +30,12 @@ object MimaBuild { def excludeMember(fullName: String) = Seq( ProblemFilters.exclude[MissingMethodProblem](fullName), + // Sometimes excluded methods have default arguments and + // they are translated into public methods/fields($default$) in generated + // bytecode. It is not possible to exhustively list everything. + // But this should be okay. + ProblemFilters.exclude[MissingMethodProblem](fullName+"$default$2"), + ProblemFilters.exclude[MissingMethodProblem](fullName+"$default$1"), ProblemFilters.exclude[MissingFieldProblem](fullName), ProblemFilters.exclude[IncompatibleResultTypeProblem](fullName), ProblemFilters.exclude[IncompatibleMethTypeProblem](fullName), diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 46b78bd5c7061..2f1e05dfcc7b1 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -37,14 +37,8 @@ object MimaExcludes { Seq( MimaBuild.excludeSparkPackage("deploy"), MimaBuild.excludeSparkPackage("graphx") - ) ++ - // This is @DeveloperAPI, but Mima still gives false-positives: - MimaBuild.excludeSparkClass("scheduler.SparkListenerApplicationStart") ++ - Seq( - // This is @Experimental, but Mima still gives false-positives: - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.foreachAsync") ) + case v if v.startsWith("1.1") => Seq( MimaBuild.excludeSparkPackage("deploy"), diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index c07ea313f1228..ab9f8ba120e83 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -187,7 +187,7 @@ object OldDeps { Some("org.apache.spark" % fullId % "1.1.0") } - def oldDepsSettings() = Defaults.defaultSettings ++ Seq( + def oldDepsSettings() = Defaults.coreDefaultSettings ++ Seq( name := "old-deps", scalaVersion := "2.10.4", retrieveManaged := true, diff --git a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala index bcf6d43ab34eb..595ded6ae67fa 100644 --- a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala +++ b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala @@ -24,6 +24,7 @@ import scala.collection.mutable import scala.collection.JavaConversions._ import scala.reflect.runtime.universe.runtimeMirror import scala.reflect.runtime.{universe => unv} +import scala.util.Try /** * A tool for generating classes to be excluded during binary checking with MIMA. It is expected @@ -121,12 +122,17 @@ object GenerateMIMAIgnore { } def main(args: Array[String]) { + import scala.tools.nsc.io.File val (privateClasses, privateMembers) = privateWithin("org.apache.spark") - scala.tools.nsc.io.File(".generated-mima-class-excludes"). - writeAll(privateClasses.mkString("\n")) + val previousContents = Try(File(".generated-mima-class-excludes").lines()). + getOrElse(Iterator.empty).mkString("\n") + File(".generated-mima-class-excludes") + .writeAll(previousContents + privateClasses.mkString("\n")) println("Created : .generated-mima-class-excludes in current directory.") - scala.tools.nsc.io.File(".generated-mima-member-excludes"). - writeAll(privateMembers.mkString("\n")) + val previousMembersContents = Try(File(".generated-mima-member-excludes").lines) + .getOrElse(Iterator.empty).mkString("\n") + File(".generated-mima-member-excludes").writeAll(previousMembersContents + + privateMembers.mkString("\n")) println("Created : .generated-mima-member-excludes in current directory.") } From febafefa5aaee3b3eda5e1b45a75bc6d8e7fb13f Mon Sep 17 00:00:00 2001 From: Ye Xianjin Date: Mon, 15 Sep 2014 21:53:38 -0700 Subject: [PATCH 36/36] [SPARK-3040] pick up a more proper local ip address for Utils.findLocalIpAddress method Short version: NetworkInterface.getNetworkInterfaces returns ifs in reverse order compared to ifconfig output. It may pick up ip address associated with tun0 or virtual network interface. See [SPARK_3040](https://issues.apache.org/jira/browse/SPARK-3040) for more detail Author: Ye Xianjin Closes #1946 from advancedxy/SPARK-3040 and squashes the following commits: f33f6b2 [Ye Xianjin] add windows support 087a785 [Ye Xianjin] reverse the Networkinterface.getNetworkInterfaces output order to get a more proper local ip address. --- core/src/main/scala/org/apache/spark/util/Utils.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 79943766d0f0f..c76b7af18481d 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -530,7 +530,12 @@ private[spark] object Utils extends Logging { if (address.isLoopbackAddress) { // Address resolves to something like 127.0.1.1, which happens on Debian; try to find // a better address using the local network interfaces - for (ni <- NetworkInterface.getNetworkInterfaces) { + // getNetworkInterfaces returns ifs in reverse order compared to ifconfig output order + // on unix-like system. On windows, it returns in index order. + // It's more proper to pick ip address following system output order. + val activeNetworkIFs = NetworkInterface.getNetworkInterfaces.toList + val reOrderedNetworkIFs = if (isWindows) activeNetworkIFs else activeNetworkIFs.reverse + for (ni <- reOrderedNetworkIFs) { for (addr <- ni.getInetAddresses if !addr.isLinkLocalAddress && !addr.isLoopbackAddress && addr.isInstanceOf[Inet4Address]) { // We've found an address that looks reasonable!