Skip to content

Commit

Permalink
merged from upstream; resurrected accidentally deleted Utils class
Browse files Browse the repository at this point in the history
  • Loading branch information
reggert committed Nov 10, 2015
2 parents 93b3065 + c4e19b3 commit dcd2883
Show file tree
Hide file tree
Showing 120 changed files with 5,787 additions and 1,234 deletions.
6 changes: 3 additions & 3 deletions R/pkg/R/DataFrame.R
Original file line number Diff line number Diff line change
Expand Up @@ -1944,9 +1944,9 @@ setMethod("describe",
#' @rdname summary
#' @name summary
setMethod("summary",
signature(x = "DataFrame"),
function(x) {
describe(x)
signature(object = "DataFrame"),
function(object, ...) {
describe(object)
})


Expand Down
2 changes: 1 addition & 1 deletion R/pkg/R/generics.R
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,7 @@ setGeneric("summarize", function(x,...) { standardGeneric("summarize") })

#' @rdname summary
#' @export
setGeneric("summary", function(x, ...) { standardGeneric("summary") })
setGeneric("summary", function(object, ...) { standardGeneric("summary") })

# @rdname tojson
# @export
Expand Down
30 changes: 22 additions & 8 deletions R/pkg/R/mllib.R
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,28 @@ setMethod("predict", signature(object = "PipelineModel"),
#' model <- glm(y ~ x, trainingData)
#' summary(model)
#'}
setMethod("summary", signature(x = "PipelineModel"),
function(x, ...) {
setMethod("summary", signature(object = "PipelineModel"),
function(object, ...) {
modelName <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
"getModelName", object@model)
features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
"getModelFeatures", x@model)
"getModelFeatures", object@model)
coefficients <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
"getModelCoefficients", x@model)
coefficients <- as.matrix(unlist(coefficients))
colnames(coefficients) <- c("Estimate")
rownames(coefficients) <- unlist(features)
return(list(coefficients = coefficients))
"getModelCoefficients", object@model)
if (modelName == "LinearRegressionModel") {
devianceResiduals <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
"getModelDevianceResiduals", object@model)
devianceResiduals <- matrix(devianceResiduals, nrow = 1)
colnames(devianceResiduals) <- c("Min", "Max")
rownames(devianceResiduals) <- rep("", times = 1)
coefficients <- matrix(coefficients, ncol = 4)
colnames(coefficients) <- c("Estimate", "Std. Error", "t value", "Pr(>|t|)")
rownames(coefficients) <- unlist(features)
return(list(DevianceResiduals = devianceResiduals, Coefficients = coefficients))
} else {
coefficients <- as.matrix(unlist(coefficients))
colnames(coefficients) <- c("Estimate")
rownames(coefficients) <- unlist(features)
return(list(coefficients = coefficients))
}
})
37 changes: 30 additions & 7 deletions R/pkg/inst/tests/test_mllib.R
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,23 @@ test_that("feature interaction vs native glm", {

test_that("summary coefficients match with native glm", {
training <- createDataFrame(sqlContext, iris)
stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training, solver = "l-bfgs"))
coefs <- as.vector(stats$coefficients)
stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training, solver = "normal"))
coefs <- unlist(stats$Coefficients)
devianceResiduals <- unlist(stats$DevianceResiduals)

rCoefs <- as.vector(coef(glm(Sepal.Width ~ Sepal.Length + Species, data = iris)))
expect_true(all(abs(rCoefs - coefs) < 1e-6))
rStdError <- c(0.23536, 0.04630, 0.07207, 0.09331)
rTValue <- c(7.123, 7.557, -13.644, -10.798)
rPValue <- c(0.0, 0.0, 0.0, 0.0)
rDevianceResiduals <- c(-0.95096, 0.72918)

expect_true(all(abs(rCoefs - coefs[1:4]) < 1e-6))
expect_true(all(abs(rStdError - coefs[5:8]) < 1e-5))
expect_true(all(abs(rTValue - coefs[9:12]) < 1e-3))
expect_true(all(abs(rPValue - coefs[13:16]) < 1e-6))
expect_true(all(abs(rDevianceResiduals - devianceResiduals) < 1e-5))
expect_true(all(
as.character(stats$features) ==
rownames(stats$Coefficients) ==
c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica")))
})

Expand All @@ -85,14 +96,26 @@ test_that("summary coefficients match with native glm of family 'binomial'", {
training <- filter(df, df$Species != "setosa")
stats <- summary(glm(Species ~ Sepal_Length + Sepal_Width, data = training,
family = "binomial"))
coefs <- as.vector(stats$coefficients)
coefs <- as.vector(stats$Coefficients)

rTraining <- iris[iris$Species %in% c("versicolor","virginica"),]
rCoefs <- as.vector(coef(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining,
family = binomial(link = "logit"))))
rStdError <- c(3.0974, 0.5169, 0.8628)
rTValue <- c(-4.212, 3.680, 0.469)
rPValue <- c(0.000, 0.000, 0.639)

expect_true(all(abs(rCoefs - coefs) < 1e-4))
expect_true(all(abs(rCoefs - coefs[1:3]) < 1e-4))
expect_true(all(abs(rStdError - coefs[4:6]) < 1e-4))
expect_true(all(abs(rTValue - coefs[7:9]) < 1e-3))
expect_true(all(abs(rPValue - coefs[10:12]) < 1e-3))
expect_true(all(
as.character(stats$features) ==
rownames(stats$Coefficients) ==
c("(Intercept)", "Sepal_Length", "Sepal_Width")))
})

test_that("summary works on base GLM models", {
  # Fit with the base R implementation and summarize it through the generic.
  nativeFit <- stats::glm(Sepal.Width ~ Sepal.Length + Species, data = iris)
  nativeSummary <- summary(nativeFit)
  # The reported deviance must agree with the known reference value.
  devianceGap <- abs(nativeSummary$deviance - 12.19313)
  expect_true(devianceGap < 1e-4)
})
4 changes: 2 additions & 2 deletions R/pkg/inst/tests/test_sparkSQL.R
Original file line number Diff line number Diff line change
Expand Up @@ -647,11 +647,11 @@ test_that("sample on a DataFrame", {
sampled <- sample(df, FALSE, 1.0)
expect_equal(nrow(collect(sampled)), count(df))
expect_is(sampled, "DataFrame")
sampled2 <- sample(df, FALSE, 0.1)
sampled2 <- sample(df, FALSE, 0.1, 0) # set seed for predictable result
expect_true(count(sampled2) < 3)

# Also test sample_frac
sampled3 <- sample_frac(df, FALSE, 0.1)
sampled3 <- sample_frac(df, FALSE, 0.1, 0) # set seed for predictable result
expect_true(count(sampled3) < 3)
})

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.api.java.function;

import java.io.Serializable;
import java.util.Iterator;

/**
* A function that returns zero or more output records from each grouping key and its values from 2
* Datasets.
*
* @param <K> type of the grouping key
* @param <V1> element type of the {@code left} iterator
* @param <V2> element type of the {@code right} iterator
* @param <R> type of the output records
*/
public interface CoGroupFunction<K, V1, V2, R> extends Serializable {
/**
* Produces the output records for a single grouping key.
*
* @param key the grouping key
* @param left values for {@code key} from one of the two Datasets (presumably the first one;
*             confirm against the caller)
* @param right values for {@code key} from the other Dataset
* @return zero or more output records
* @throws Exception declared so that implementations may propagate any checked exception
*/
Iterable<R> call(K key, Iterator<V1> left, Iterator<V2> right) throws Exception;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.api.java.function;

import java.io.Serializable;

/**
* Base interface for a function used in Dataset's filter function.
*
* If the function returns true, the element is included in the returned Dataset; elements for
* which it returns false are left out.
*
* @param <T> type of the elements being tested
*/
public interface FilterFunction<T> extends Serializable {
/**
* Tests a single element.
*
* @param value the element to test
* @return true to keep the element, false to drop it
* @throws Exception declared so that implementations may propagate any checked exception
*/
boolean call(T value) throws Exception;
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,5 @@
* A function that returns zero or more output records from each input record.
*/
public interface FlatMapFunction<T, R> extends Serializable {
public Iterable<R> call(T t) throws Exception;
Iterable<R> call(T t) throws Exception;
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,5 @@
* A function that takes two inputs and returns zero or more output records.
*/
public interface FlatMapFunction2<T1, T2, R> extends Serializable {
public Iterable<R> call(T1 t1, T2 t2) throws Exception;
Iterable<R> call(T1 t1, T2 t2) throws Exception;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.api.java.function;

import java.io.Serializable;
import java.util.Iterator;

/**
* A function that returns zero or more output records from each grouping key and its values.
*
* @param <K> type of the grouping key
* @param <V> type of the values in a group
* @param <R> type of the output records
*/
public interface FlatMapGroupFunction<K, V, R> extends Serializable {
/**
* Produces the output records for a single grouping key.
*
* @param key the grouping key
* @param values iterator over the values associated with {@code key}
* @return zero or more output records
* @throws Exception declared so that implementations may propagate any checked exception
*/
Iterable<R> call(K key, Iterator<V> values) throws Exception;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.api.java.function;

import java.io.Serializable;

/**
* Base interface for a function used in Dataset's foreach function.
*
* Spark will invoke the call function on each element in the input Dataset.
*
* @param <T> type of the elements passed to {@code call}
*/
public interface ForeachFunction<T> extends Serializable {
/**
* Handles a single element; no value is returned.
*
* @param t the element to process
* @throws Exception declared so that implementations may propagate any checked exception
*/
void call(T t) throws Exception;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.api.java.function;

import java.io.Serializable;
import java.util.Iterator;

/**
* Base interface for a function used in Dataset's foreachPartition function.
*
* @param <T> type of the elements produced by the iterator passed to {@code call}
*/
public interface ForeachPartitionFunction<T> extends Serializable {
/**
* Handles a whole iterator of elements at once; no value is returned.
*
* @param t iterator over the elements to process (presumably those of one partition, going by
*          the interface name; confirm against the caller)
* @throws Exception declared so that implementations may propagate any checked exception
*/
void call(Iterator<T> t) throws Exception;
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,5 @@
* A zero-argument function that returns an R.
*/
public interface Function0<R> extends Serializable {
public R call() throws Exception;
R call() throws Exception;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.api.java.function;

import java.io.Serializable;

/**
* Base interface for a map function used in Dataset's map function.
*
* @param <T> type of the input value
* @param <U> type of the result
*/
public interface MapFunction<T, U> extends Serializable {
/**
* Maps a single input value to a single result.
*
* @param value the input value
* @return the mapped result
* @throws Exception declared so that implementations may propagate any checked exception
*/
U call(T value) throws Exception;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.api.java.function;

import java.io.Serializable;
import java.util.Iterator;

/**
* Base interface for a map function used in GroupedDataset's map function.
*
* @param <K> type of the grouping key
* @param <V> type of the values in a group
* @param <R> type of the result produced per group
*/
public interface MapGroupFunction<K, V, R> extends Serializable {
/**
* Maps a key and its values to a single result.
*
* @param key the grouping key
* @param values iterator over the values associated with {@code key}
* @return the result for this key
* @throws Exception declared so that implementations may propagate any checked exception
*/
R call(K key, Iterator<V> values) throws Exception;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.api.java.function;

import java.io.Serializable;
import java.util.Iterator;

/**
* Base interface for a function used in Dataset's mapPartitions.
*
* @param <T> type of the input elements
* @param <U> type of the output elements
*/
public interface MapPartitionsFunction<T, U> extends Serializable {
/**
* Transforms an iterator of input elements into a collection of output elements.
*
* @param input iterator over the input elements (presumably those of one partition, going by
*              the interface name; confirm against the caller)
* @return the output elements
* @throws Exception declared so that implementations may propagate any checked exception
*/
Iterable<U> call(Iterator<T> input) throws Exception;
}
Loading

0 comments on commit dcd2883

Please sign in to comment.