Merge branch 'master' of github.com:apache/spark into unsafe_join
Conflicts:
	sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
Davies Liu committed Jul 21, 2015
2 parents 1a40f02 + df4ddb3 commit 69e38f5
Showing 115 changed files with 3,018 additions and 1,424 deletions.
1 change: 1 addition & 0 deletions R/pkg/DESCRIPTION
@@ -29,6 +29,7 @@ Collate:
'client.R'
'context.R'
'deserialize.R'
'mllib.R'
'serialize.R'
'sparkR.R'
'utils.R'
4 changes: 4 additions & 0 deletions R/pkg/NAMESPACE
@@ -10,6 +10,10 @@ export("sparkR.init")
export("sparkR.stop")
export("print.jobj")

# MLlib integration
exportMethods("glm",
"predict")

# Job group lifecycle management methods
export("setJobGroup",
"clearJobGroup",
4 changes: 4 additions & 0 deletions R/pkg/R/generics.R
@@ -661,3 +661,7 @@ setGeneric("toRadians", function(x) { standardGeneric("toRadians") })
#' @rdname column
#' @export
setGeneric("upper", function(x) { standardGeneric("upper") })

#' @rdname glm
#' @export
setGeneric("glm")
73 changes: 73 additions & 0 deletions R/pkg/R/mllib.R
@@ -0,0 +1,73 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# mllib.R: Provides methods for MLlib integration

#' @title S4 class that represents a PipelineModel
#' @param model A Java object reference to the backing Scala PipelineModel
#' @export
setClass("PipelineModel", representation(model = "jobj"))

#' Fits a generalized linear model
#'
#' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package.
#'
#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
#' operators are supported, including '~' and '+'.
#' @param data DataFrame for training
#' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic regression.
#' @param lambda Regularization parameter
#' @param alpha Elastic-net mixing parameter (see glmnet's documentation for details)
#' @return a fitted MLlib model
#' @rdname glm
#' @export
#' @examples
#'\dontrun{
#' sc <- sparkR.init()
#' sqlContext <- sparkRSQL.init(sc)
#' data(iris)
#' df <- createDataFrame(sqlContext, iris)
#' model <- glm(Sepal_Length ~ Sepal_Width, df)
#'}
setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFrame"),
function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0) {
family <- match.arg(family)
model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
"fitRModelFormula", deparse(formula), data@sdf, family, lambda,
alpha)
return(new("PipelineModel", model = model))
})

#' Make predictions from a model
#'
#' Makes predictions from a model produced by glm(), similarly to R's predict().
#'
#' @param model A fitted MLlib model
#' @param newData DataFrame for testing
#' @return DataFrame containing predicted values
#' @rdname glm
#' @export
#' @examples
#'\dontrun{
#' model <- glm(y ~ x, trainingData)
#' predicted <- predict(model, testData)
#' showDF(predicted)
#'}
setMethod("predict", signature(object = "PipelineModel"),
function(object, newData) {
return(dataFrame(callJMethod(object@model, "transform", newData@sdf)))
})
13 changes: 8 additions & 5 deletions R/pkg/R/schema.R
@@ -69,11 +69,14 @@ structType.structField <- function(x, ...) {
#' @param ... further arguments passed to or from other methods
print.structType <- function(x, ...) {
cat("StructType\n",
sapply(x$fields(), function(field) { paste("|-", "name = \"", field$name(),
"\", type = \"", field$dataType.toString(),
"\", nullable = ", field$nullable(), "\n",
sep = "") })
, sep = "")
sapply(x$fields(),
function(field) {
paste("|-", "name = \"", field$name(),
"\", type = \"", field$dataType.toString(),
"\", nullable = ", field$nullable(), "\n",
sep = "")
}),
sep = "")
}

#' structField
33 changes: 22 additions & 11 deletions R/pkg/R/utils.R
@@ -390,14 +390,17 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) {
for (i in 1:nodeLen) {
processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
}
} else { # if node[[1]] is length of 1, check for some R special functions.
} else {
# if node[[1]] is length of 1, check for some R special functions.
nodeChar <- as.character(node[[1]])
if (nodeChar == "{" || nodeChar == "(") { # Skip start symbol.
if (nodeChar == "{" || nodeChar == "(") {
# Skip start symbol.
for (i in 2:nodeLen) {
processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
}
} else if (nodeChar == "<-" || nodeChar == "=" ||
nodeChar == "<<-") { # Assignment Ops.
nodeChar == "<<-") {
# Assignment Ops.
defVar <- node[[2]]
if (length(defVar) == 1 && typeof(defVar) == "symbol") {
# Add the defined variable name into defVars.
@@ -408,14 +411,16 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) {
for (i in 3:nodeLen) {
processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
}
} else if (nodeChar == "function") { # Function definition.
} else if (nodeChar == "function") {
# Function definition.
# Add parameter names.
newArgs <- names(node[[2]])
lapply(newArgs, function(arg) { addItemToAccumulator(defVars, arg) })
for (i in 3:nodeLen) {
processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
}
} else if (nodeChar == "$") { # Skip the field.
} else if (nodeChar == "$") {
# Skip the field.
processClosure(node[[2]], oldEnv, defVars, checkedFuncs, newEnv)
} else if (nodeChar == "::" || nodeChar == ":::") {
processClosure(node[[3]], oldEnv, defVars, checkedFuncs, newEnv)
@@ -429,7 +434,8 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) {
(typeof(node) == "symbol" || typeof(node) == "language")) {
# Base case: current AST node is a leaf node and a symbol or a function call.
nodeChar <- as.character(node)
if (!nodeChar %in% defVars$data) { # Not a function parameter or local variable.
if (!nodeChar %in% defVars$data) {
# Not a function parameter or local variable.
func.env <- oldEnv
topEnv <- parent.env(.GlobalEnv)
# Search in function environment, and function's enclosing environments
@@ -439,20 +445,24 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) {
while (!identical(func.env, topEnv)) {
# Namespaces other than "SparkR" will not be searched.
if (!isNamespace(func.env) ||
(getNamespaceName(func.env) == "SparkR" &&
!(nodeChar %in% getNamespaceExports("SparkR")))) { # Only include SparkR internals.
(getNamespaceName(func.env) == "SparkR" &&
!(nodeChar %in% getNamespaceExports("SparkR")))) {
# Only include SparkR internals.

# Set parameter 'inherits' to FALSE since we do not need to search in
# attached package environments.
if (tryCatch(exists(nodeChar, envir = func.env, inherits = FALSE),
error = function(e) { FALSE })) {
obj <- get(nodeChar, envir = func.env, inherits = FALSE)
if (is.function(obj)) { # If the node is a function call.
if (is.function(obj)) {
# If the node is a function call.
funcList <- mget(nodeChar, envir = checkedFuncs, inherits = F,
ifnotfound = list(list(NULL)))[[1]]
found <- sapply(funcList, function(func) {
ifelse(identical(func, obj), TRUE, FALSE)
})
if (sum(found) > 0) { # If function has been examined, ignore.
if (sum(found) > 0) {
# If function has been examined, ignore.
break
}
# Function has not been examined, record it and recursively clean its closure.
@@ -495,7 +505,8 @@ cleanClosure <- function(func, checkedFuncs = new.env()) {
# environment. First, function's arguments are added to defVars.
defVars <- initAccumulator()
argNames <- names(as.list(args(func)))
for (i in 1:(length(argNames) - 1)) { # Remove the ending NULL in pairlist.
for (i in 1:(length(argNames) - 1)) {
# Remove the ending NULL in pairlist.
addItemToAccumulator(defVars, argNames[i])
}
# Recursively examine variables in the function body.
42 changes: 42 additions & 0 deletions R/pkg/inst/tests/test_mllib.R
@@ -0,0 +1,42 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

library(testthat)

context("MLlib functions")

# Tests for MLlib functions in SparkR

sc <- sparkR.init()

sqlContext <- sparkRSQL.init(sc)

test_that("glm and predict", {
training <- createDataFrame(sqlContext, iris)
test <- select(training, "Sepal_Length")
model <- glm(Sepal_Width ~ Sepal_Length, training, family = "gaussian")
prediction <- predict(model, test)
expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double")
})

test_that("predictions match with native glm", {
training <- createDataFrame(sqlContext, iris)
model <- glm(Sepal_Width ~ Sepal_Length, data = training)
vals <- collect(select(predict(model, training), "prediction"))
rVals <- predict(glm(Sepal.Width ~ Sepal.Length, data = iris), iris)
expect_true(all(abs(rVals - vals) < 1e-9), rVals - vals)
})
2 changes: 1 addition & 1 deletion core/pom.xml
@@ -261,7 +261,7 @@
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.module</groupId>
<artifactId>jackson-module-scala_2.10</artifactId>
<artifactId>jackson-module-scala_${scala.binary.version}</artifactId>
</dependency>
<dependency>
<groupId>org.apache.derby</groupId>
@@ -23,6 +23,7 @@

import org.apache.spark.annotation.Private;
import org.apache.spark.unsafe.types.UTF8String;
import org.apache.spark.util.Utils;

@Private
public class PrefixComparators {
@@ -82,7 +83,7 @@ public static final class FloatPrefixComparator extends PrefixComparator {
public int compare(long aPrefix, long bPrefix) {
float a = Float.intBitsToFloat((int) aPrefix);
float b = Float.intBitsToFloat((int) bPrefix);
return (a < b) ? -1 : (a > b) ? 1 : 0;
return Utils.nanSafeCompareFloats(a, b);
}

public long computePrefix(float value) {
@@ -97,7 +98,7 @@ public static final class DoublePrefixComparator extends PrefixComparator {
public int compare(long aPrefix, long bPrefix) {
double a = Double.longBitsToDouble(aPrefix);
double b = Double.longBitsToDouble(bPrefix);
return (a < b) ? -1 : (a > b) ? 1 : 0;
return Utils.nanSafeCompareDoubles(a, b);
}

public long computePrefix(double value) {
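
The two comparator fixes above matter because the old expression (a < b) ? -1 : (a > b) ? 1 : 0 returns 0 whenever either operand is NaN, which breaks the comparator contract (NaN compares "equal" to every value while those values still order among themselves) and can corrupt a sort. Below is a minimal, self-contained Scala sketch of the NaN-safe semantics the new code presumably delegates to, assuming the common convention that NaN equals NaN and sorts after every other value; the actual Utils helpers may differ in detail.

object NanSafeCompareSketch {
  // Assumed semantics: NaN == NaN, and NaN sorts after any other value.
  // This keeps the ordering total and consistent even for NaN inputs,
  // unlike the old three-way expression, which reported 0 for any NaN.
  def nanSafeCompareDoubles(x: Double, y: Double): Int = {
    val xIsNan = java.lang.Double.isNaN(x)
    val yIsNan = java.lang.Double.isNaN(y)
    if ((xIsNan && yIsNan) || (x == y)) 0
    else if (xIsNan) 1
    else if (yIsNan) -1
    else if (x > y) 1
    else -1
  }

  def main(args: Array[String]): Unit = {
    println(nanSafeCompareDoubles(Double.NaN, 1.0))        // 1: NaN sorts last
    println(nanSafeCompareDoubles(Double.NaN, Double.NaN)) // 0: NaN equals NaN
    println(nanSafeCompareDoubles(1.0, 2.0))               // -1
  }
}
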
62 changes: 44 additions & 18 deletions core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -21,14 +21,14 @@ import java.io._
import java.util.concurrent.ConcurrentHashMap
import java.util.zip.{GZIPInputStream, GZIPOutputStream}

import scala.collection.mutable.{HashMap, HashSet, Map}
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
import scala.collection.JavaConversions._
import scala.reflect.ClassTag

import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcCallContext, RpcEndpoint}
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.shuffle.MetadataFetchFailedException
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}
import org.apache.spark.util._

private[spark] sealed trait MapOutputTrackerMessage
@@ -124,10 +124,18 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
}

/**
* Called from executors to get the server URIs and output sizes of the map outputs of
* a given shuffle.
* Called from executors to get the server URIs and output sizes for each shuffle block that
* needs to be read from a given reduce task.
*
* @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
* and the second item is a sequence of (shuffle block id, shuffle block size) tuples
* describing the shuffle blocks that are stored at that block manager.
*/
def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int)
: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
logDebug(s"Fetching outputs for shuffle $shuffleId, reduce $reduceId")
val startTime = System.currentTimeMillis

val statuses = mapStatuses.get(shuffleId).orNull
if (statuses == null) {
logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
@@ -167,6 +175,9 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
}
}
}
logDebug(s"Fetching map output location for shuffle $shuffleId, reduce $reduceId took " +
s"${System.currentTimeMillis - startTime} ms")

if (fetchedStatuses != null) {
fetchedStatuses.synchronized {
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
@@ -421,23 +432,38 @@ private[spark] object MapOutputTracker extends Logging {
}
}

// Convert an array of MapStatuses to locations and sizes for a given reduce ID. If
// any of the statuses is null (indicating a missing location due to a failed mapper),
// throw a FetchFailedException.
/**
* Converts an array of MapStatuses for a given reduce ID to a sequence that, for each block
* manager ID, lists the shuffle block ids and corresponding shuffle block sizes stored at that
* block manager.
*
* If any of the statuses is null (indicating a missing location due to a failed mapper),
* throws a FetchFailedException.
*
* @param shuffleId Identifier for the shuffle
* @param reduceId Identifier for the reduce task
* @param statuses List of map statuses, indexed by map ID.
* @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
* and the second item is a sequence of (shuffle block id, shuffle block size) tuples
* describing the shuffle blocks that are stored at that block manager.
*/
private def convertMapStatuses(
shuffleId: Int,
reduceId: Int,
statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = {
statuses: Array[MapStatus]): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
assert (statuses != null)
statuses.map {
status =>
if (status == null) {
logError("Missing an output location for shuffle " + shuffleId)
throw new MetadataFetchFailedException(
shuffleId, reduceId, "Missing an output location for shuffle " + shuffleId)
} else {
(status.location, status.getSizeForBlock(reduceId))
}
val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(BlockId, Long)]]
for ((status, mapId) <- statuses.zipWithIndex) {
if (status == null) {
val errorMessage = s"Missing an output location for shuffle $shuffleId"
logError(errorMessage)
throw new MetadataFetchFailedException(shuffleId, reduceId, errorMessage)
} else {
splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) +=
((ShuffleBlockId(shuffleId, mapId, reduceId), status.getSizeForBlock(reduceId)))
}
}

splitsByAddress.toSeq
}
}
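
To make the new return shape concrete: instead of a flat Array[(BlockManagerId, Long)] with one entry per map output, the tracker now groups (shuffle block id, size) pairs by the block manager that stores them, so a reduce task sees at most one entry per executor location. The following self-contained Scala sketch shows only that grouping pattern, with simplified stand-in types (Location and Block are placeholders for BlockManagerId and ShuffleBlockId, which are not reproduced here); it is an illustration, not the Spark implementation.

import scala.collection.mutable.{ArrayBuffer, HashMap}

object GroupByLocationSketch {
  // Stand-in types for illustration only.
  case class Location(host: String, port: Int)
  case class Block(shuffleId: Int, mapId: Int, reduceId: Int)

  // Same shape as the new convertMapStatuses result: one entry per location,
  // each carrying the (block, size) pairs stored at that location.
  def groupByLocation(outputs: Seq[(Location, Block, Long)]): Seq[(Location, Seq[(Block, Long)])] = {
    val splitsByAddress = new HashMap[Location, ArrayBuffer[(Block, Long)]]
    for ((location, block, size) <- outputs) {
      splitsByAddress.getOrElseUpdate(location, ArrayBuffer()) += ((block, size))
    }
    splitsByAddress.toSeq
  }

  def main(args: Array[String]): Unit = {
    val loc = Location("host-a", 7337)
    val grouped = groupByLocation(Seq(
      (loc, Block(shuffleId = 0, mapId = 0, reduceId = 1), 128L),
      (loc, Block(shuffleId = 0, mapId = 1, reduceId = 1), 256L)))
    println(grouped) // one entry for host-a holding both (block, size) pairs
  }
}
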