Changes from all commits (64 commits)
94b1003  check label data type for numeric type instead of double (BenFradet, Dec 13, 2015)
426b0a0  added some cases to extractLabeledPoints, looking for a better way to… (BenFradet, Dec 13, 2015)
8f1ba6a  Added a method to set the metadata on a dataframe (BenFradet, Dec 13, 2015)
d87740d  unit tests for the decision tree classifier (BenFradet, Dec 13, 2015)
72416b4  used the sqlcontext provided with MLlibTestSparkContext (BenFradet, Dec 17, 2015)
bd11dda  simpler version of extractLabeledPoints (BenFradet, Dec 17, 2015)
f108de7  cleanup imports for the dt classifier suite (BenFradet, Dec 17, 2015)
46b7e88  fixed scalastyle (BenFradet, Dec 17, 2015)
5d2c64d  testing other numeric type for the gbt classifier (BenFradet, Dec 17, 2015)
020ee31  better error message in case of non numeric type (BenFradet, Dec 17, 2015)
a44c02b  added unit test for non numeric type (BenFradet, Dec 17, 2015)
a974fb7  new set metadata method which lets you specify the label column name (BenFradet, Dec 17, 2015)
3725086  more concise unit tests (BenFradet, Dec 17, 2015)
9a54f5e  fixed scalastyle (BenFradet, Dec 18, 2015)
34c7e2c  forgot string interpolation (BenFradet, Dec 18, 2015)
e78ebc7  uts for all numeric types for logistic regression (BenFradet, Dec 18, 2015)
6224524  small refactor (BenFradet, Dec 18, 2015)
708cac1  uts for the multilayer perceptron classifier (BenFradet, Dec 18, 2015)
84a5d54  uts for the naive bayes classifier (BenFradet, Dec 18, 2015)
d7facbf  uts for the one vs rest classifier (BenFradet, Dec 18, 2015)
935b845  uts for the random forest classifier (BenFradet, Dec 18, 2015)
caf398e  uts for the decision tree regressor (BenFradet, Dec 18, 2015)
6a0c053  uts for the gbt regressor (BenFradet, Dec 18, 2015)
8659d0d  uts for linear regression (BenFradet, Dec 18, 2015)
4db8d0b  uts for the random forest regressor (BenFradet, Dec 18, 2015)
19dc889  import order (BenFradet, Dec 18, 2015)
0fd1281  import order - 2 (BenFradet, Dec 18, 2015)
e48205a  import order - 3 (BenFradet, Dec 18, 2015)
4fcfc3f  fixed import order - 2 (BenFradet, Jan 20, 2016)
b127717  made NumericType sql private again (BenFradet, Mar 11, 2016)
65d922f  missed a file when rebasing (BenFradet, Mar 11, 2016)
ec5bd9f  import problem (BenFradet, Mar 11, 2016)
01b76f3  generator util functions for testing numeric and non numeric label co… (BenFradet, Mar 15, 2016)
d946969  used test data generators for the classifiers (BenFradet, Mar 15, 2016)
37bd088  used test data generators for the regressors (BenFradet, Mar 15, 2016)
f923c60  rmd underscore imports for the classifiers (BenFradet, Mar 16, 2016)
cf40601  rmd underscore imports for the regressors (BenFradet, Mar 16, 2016)
f5eb5e5  generalized linear regression now supports non double column (BenFradet, Mar 16, 2016)
4c1c53d  fixed generator for the classification suites (BenFradet, Mar 16, 2016)
495cfef  formatting quirks in generalized linear reg (BenFradet, Mar 16, 2016)
d06da2f  utility method checking if a column is of numeric type (BenFradet, Mar 16, 2016)
2bb7bb0  spec for aft survival regression (BenFradet, Mar 16, 2016)
22171ad  isotonic regression now supports all numeric types (BenFradet, Mar 16, 2016)
dfe57a6  spec for isotonic regression (BenFradet, Mar 16, 2016)
1817fee  scalastyle (BenFradet, Mar 16, 2016)
1fb173d  generated datasets have a censor column to fix the suite on aft survival (BenFradet, Mar 17, 2016)
9a2f60b  utility function to do asserts on dataframe fitted on every numeric type (BenFradet, Mar 23, 2016)
03b5a8a  test on linear regression (BenFradet, Mar 23, 2016)
d4b0616  simplified the checkAcceptAllNumericTypes fun (BenFradet, Mar 24, 2016)
7fdda3d  funs to check that estimators accept all numeric types and reject others (BenFradet, Mar 24, 2016)
d9910fd  refactored the tests for the regressors (BenFradet, Mar 24, 2016)
9b19aeb  refactored the tests for the classifiers (BenFradet, Mar 24, 2016)
3bbcd21  restored brackets (BenFradet, Mar 24, 2016)
e540e2f  fixed scalastyle (BenFradet, Mar 24, 2016)
5777511  fix for aft survival (BenFradet, Mar 25, 2016)
d8dcb23  rmd weird intellij formatting (BenFradet, Mar 25, 2016)
1782e67  single checkNumericTypes method to check whether both estimators and (BenFradet, Mar 29, 2016)
b679f66  classifiers tests now use the checkNumericTypes method (BenFradet, Mar 29, 2016)
9440eb6  regressors tests now use the checkNumericTypes method (BenFradet, Mar 29, 2016)
669c9e2  fixed style (BenFradet, Mar 30, 2016)
0868695  modified predictor params in order to reduce test time (BenFradet, Mar 30, 2016)
3922219  fixed style (BenFradet, Mar 30, 2016)
b975a03  fixed MLTestingUtils (BenFradet, Mar 30, 2016)
718774b  rmd useless import (BenFradet, Apr 1, 2016)
9 changes: 4 additions & 5 deletions mllib/src/main/scala/org/apache/spark/ml/Predictor.scala

@@ -36,6 +36,7 @@ private[ml] trait PredictorParams extends Params
 
   /**
    * Validates and transforms the input schema with the provided param map.
+   *
    * @param schema input schema
    * @param fitting whether this is in fitting
    * @param featuresDataType SQL DataType for FeaturesType.
@@ -49,8 +50,7 @@ private[ml] trait PredictorParams extends Params
     // TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector
     SchemaUtils.checkColumnType(schema, $(featuresCol), featuresDataType)
     if (fitting) {
-      // TODO: Allow other numeric types
-      SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
+      SchemaUtils.checkNumericType(schema, $(labelCol))
     }
     SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType)
   }
@@ -121,9 +121,8 @@ abstract class Predictor[
    * and put it in an RDD with strong types.
    */
   protected def extractLabeledPoints(dataset: DataFrame): RDD[LabeledPoint] = {
-    dataset.select($(labelCol), $(featuresCol)).rdd.map {
-      case Row(label: Double, features: Vector) =>
-        LabeledPoint(label, features)
+    dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
+      case Row(label: Double, features: Vector) => LabeledPoint(label, features)
     }
   }
 }
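For context, the new extraction pattern can be exercised on its own. A minimal sketch, not part of the diff, assuming a Spark 1.6-style sqlContext is in scope and using hypothetical column names:

import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DoubleType

// Toy DataFrame whose label column is IntegerType rather than DoubleType.
val df = sqlContext.createDataFrame(Seq(
  (0, Vectors.dense(1.0, 2.0)),
  (1, Vectors.dense(3.0, 4.0)))).toDF("label", "features")

// Casting up front makes the `label: Double` pattern match total for any
// NumericType label; without the cast it would throw a scala.MatchError here.
val points = df.select(col("label").cast(DoubleType), col("features")).rdd.map {
  case Row(label: Double, features: Vector) => LabeledPoint(label, features)
}
points.collect().foreach(println)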
mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

@@ -38,6 +38,7 @@ import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.types.DoubleType
 import org.apache.spark.storage.StorageLevel
 
 /**
@@ -265,7 +266,7 @@ class LogisticRegression @Since("1.2.0") (
     LogisticRegressionModel = {
     val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol))
     val instances: RDD[Instance] =
-      dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
+      dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
         case Row(label: Double, weight: Double, features: Vector) =>
           Instance(label, weight, features)
       }
@@ -361,7 +362,7 @@
     if (optInitialModel.isDefined && optInitialModel.get.coefficients.size != numFeatures) {
       val vec = optInitialModel.get.coefficients
       logWarning(
-        s"Initial coefficients provided ${vec} did not match the expected size ${numFeatures}")
+        s"Initial coefficients provided $vec did not match the expected size $numFeatures")
     }
 
     if (optInitialModel.isDefined && optInitialModel.get.coefficients.size == numFeatures) {
@@ -522,7 +523,7 @@ class LogisticRegressionModel private[spark] (
       (LogisticRegressionModel, String) = {
     $(probabilityCol) match {
       case "" =>
-        val probabilityColName = "probability_" + java.util.UUID.randomUUID.toString()
+        val probabilityColName = "probability_" + java.util.UUID.randomUUID.toString
         (copy(ParamMap.empty).setProbabilityCol(probabilityColName), probabilityColName)
       case p => (this, p)
     }
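The logistic regression hunk combines the same cast with a default weight. A hedged sketch of the lit(1.0) trick, reusing df and the imports from the sketch above (a plain tuple stands in for the private[ml] Instance class):

import org.apache.spark.sql.functions.lit

// lit(1.0) is a constant Column, so unweighted data flows through the same
// Row(label, weight, features) extractor as data with a real weight column.
val w = lit(1.0)
val instances = df.select(col("label").cast(DoubleType), w, col("features")).rdd.map {
  case Row(label: Double, weight: Double, features: Vector) => (label, weight, features)
}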
mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala

@@ -295,10 +295,12 @@ final class OneVsRest @Since("1.4.0") (
 
   @Since("1.4.0")
   override def fit(dataset: DataFrame): OneVsRestModel = {
+    transformSchema(dataset.schema)
+
     // determine number of classes either from metadata if provided, or via computation.
     val labelSchema = dataset.schema($(labelCol))
     val computeNumClasses: () => Int = () => {
-      val Row(maxLabelIndex: Double) = dataset.agg(max($(labelCol))).head()
+      val Row(maxLabelIndex: Double) = dataset.agg(max(col($(labelCol)).cast(DoubleType))).head()
       // classes are assumed to be numbered from 0,...,maxLabelIndex
       maxLabelIndex.toInt + 1
     }
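The cast inside the aggregate matters for the same reason as in extractLabeledPoints. A small illustration, again assuming the df defined earlier:

import org.apache.spark.sql.functions.max

// For an IntegerType label, df.agg(max(col("label"))).head() returns Row(1: Int),
// so `val Row(maxLabelIndex: Double) = ...` would not match. Casting first
// makes the extractor safe for every numeric label type.
val Row(maxLabelIndex: Double) = df.agg(max(col("label").cast(DoubleType))).head()
val numClasses = maxLabelIndex.toInt + 1 // labels assumed numbered 0..maxLabelIndex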
mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala

@@ -103,7 +103,7 @@ private[regression] trait AFTSurvivalRegressionParams extends Params
     SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
     if (fitting) {
       SchemaUtils.checkColumnType(schema, $(censorCol), DoubleType)
-      SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
+      SchemaUtils.checkNumericType(schema, $(labelCol))
     }
     if (hasQuantilesCol) {
       SchemaUtils.appendColumn(schema, $(quantilesCol), new VectorUDT)
@@ -184,10 +184,11 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
    * and put it in an RDD with strong types.
    */
   protected[ml] def extractAFTPoints(dataset: DataFrame): RDD[AFTPoint] = {
-    dataset.select($(featuresCol), $(labelCol), $(censorCol)).rdd.map {
-      case Row(features: Vector, label: Double, censor: Double) =>
-        AFTPoint(features, label, censor)
-    }
+    dataset.select(col($(featuresCol)), col($(labelCol)).cast(DoubleType), col($(censorCol)))
+      .rdd.map {
+        case Row(features: Vector, label: Double, censor: Double) =>
+          AFTPoint(features, label, censor)
+      }
   }
 
   @Since("1.6.0")
mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala

@@ -33,7 +33,7 @@ import org.apache.spark.mllib.linalg.{BLAS, Vector}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
 
 /**
  * Params for Generalized Linear Regression.
@@ -47,6 +47,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
    * to be used in the model.
    * Supported options: "gaussian", "binomial", "poisson" and "gamma".
    * Default is "gaussian".
+   *
    * @group param
    */
   @Since("2.0.0")
@@ -63,6 +64,7 @@
    * Param for the name of link function which provides the relationship
    * between the linear predictor and the mean of the distribution function.
    * Supported options: "identity", "log", "inverse", "logit", "probit", "cloglog" and "sqrt".
+   *
    * @group param
    */
   @Since("2.0.0")
@@ -210,9 +212,10 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
     }
 
     val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol))
-    val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd
-      .map { case Row(label: Double, weight: Double, features: Vector) =>
-        Instance(label, weight, features)
+    val instances: RDD[Instance] =
+      dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
+        case Row(label: Double, weight: Double, features: Vector) =>
+          Instance(label, weight, features)
       }
 
     if (familyObj == Gaussian && linkObj == Identity) {
@@ -698,7 +701,7 @@ class GeneralizedLinearRegressionModel private[ml] (
       : (GeneralizedLinearRegressionModel, String) = {
     $(predictionCol) match {
       case "" =>
-        val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString()
+        val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString
         (copy(ParamMap.empty).setPredictionCol(predictionColName), predictionColName)
       case p => (this, p)
     }
mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala

@@ -90,7 +90,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
     } else {
       lit(1.0)
     }
-    dataset.select(col($(labelCol)), f, w).rdd.map {
+    dataset.select(col($(labelCol)).cast(DoubleType), f, w).rdd.map {
       case Row(label: Double, feature: Double, weight: Double) =>
         (label, feature, weight)
     }
@@ -106,7 +106,7 @@
       schema: StructType,
       fitting: Boolean): StructType = {
     if (fitting) {
-      SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
+      SchemaUtils.checkNumericType(schema, $(labelCol))
       if (hasWeightCol) {
         SchemaUtils.checkColumnType(schema, $(weightCol), DoubleType)
       } else {
mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

@@ -40,6 +40,7 @@ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.DoubleType
 import org.apache.spark.storage.StorageLevel
 
 /**
@@ -171,7 +172,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
       // For low dimensional data, WeightedLeastSquares is more efficiently since the
       // training algorithm only requires one pass through the data. (SPARK-10668)
       val instances: RDD[Instance] = dataset.select(
-        col($(labelCol)), w, col($(featuresCol))).rdd.map {
+        col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
           case Row(label: Double, weight: Double, features: Vector) =>
             Instance(label, weight, features)
         }
@@ -431,7 +432,7 @@ class LinearRegressionModel private[ml] (
   private[regression] def findSummaryModelAndPredictionCol(): (LinearRegressionModel, String) = {
     $(predictionCol) match {
       case "" =>
-        val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString()
+        val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString
         (copy(ParamMap.empty).setPredictionCol(predictionColName), predictionColName)
       case p => (this, p)
     }
@@ -550,7 +551,7 @@ class LinearRegressionSummary private[regression] (
 
   @transient private val metrics = new RegressionMetrics(
     predictions
-      .select(predictionCol, labelCol)
+      .select(col(predictionCol), col(labelCol).cast(DoubleType))
       .rdd
       .map { case Row(pred: Double, label: Double) => (pred, label) },
     !model.getFitIntercept)
@@ -653,7 +654,7 @@ class LinearRegressionSummary private[regression] (
         col(model.getWeightCol)).as("wse")).agg(sum(col("wse"))).first().getDouble(0)
     }
     val sigma2 = rss / degreesOfFreedom
-    diagInvAtWA.map(_ * sigma2).map(math.sqrt(_))
+    diagInvAtWA.map(_ * sigma2).map(math.sqrt)
   }
 }
 
@@ -826,7 +827,7 @@ private class LeastSquaresAggregator(
     instance match { case Instance(label, weight, features) =>
       require(dim == features.size, s"Dimensions mismatch when adding new sample." +
         s" Expecting $dim but got ${features.size}.")
-      require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0")
+      require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0")
 
       if (weight == 0.0) return this
 
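Aside from the casts, this file picks up small style fixes; replacing math.sqrt(_) with the method value math.sqrt is behavior-preserving eta-expansion. A one-liner to convince yourself:

// `math.sqrt` as a function value is equivalent to the lambda `x => math.sqrt(x)`.
assert(Seq(1.0, 4.0, 9.0).map(math.sqrt) == Seq(1.0, 2.0, 3.0))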
24 changes: 19 additions & 5 deletions mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala

@@ -17,7 +17,7 @@
 
 package org.apache.spark.ml.util
 
-import org.apache.spark.sql.types.{DataType, StructField, StructType}
+import org.apache.spark.sql.types.{DataType, NumericType, StructField, StructType}
 
 
 /**
@@ -44,10 +44,10 @@ private[spark] object SchemaUtils {
   }
 
   /**
-    * Check whether the given schema contains a column of one of the require data types.
-    * @param colName column name
-    * @param dataTypes required column data types
-    */
+   * Check whether the given schema contains a column of one of the require data types.
+   * @param colName column name
+   * @param dataTypes required column data types
+   */
   def checkColumnTypes(
       schema: StructType,
       colName: String,
@@ -60,6 +60,20 @@ private[spark] object SchemaUtils {
       s"${dataTypes.mkString("[", ", ", "]")} but was actually of type $actualDataType.$message")
   }
 
+  /**
+   * Check whether the given schema contains a column of the numeric data type.
+   * @param colName column name
+   */
+  def checkNumericType(
+      schema: StructType,
+      colName: String,
+      msg: String = ""): Unit = {
+    val actualDataType = schema(colName).dataType
+    val message = if (msg != null && msg.trim.length > 0) " " + msg else ""
+    require(actualDataType.isInstanceOf[NumericType], s"Column $colName must be of type " +
+      s"NumericType but was actually of type $actualDataType.$message")
+  }
+
   /**
    * Appends a new column to the input schema. This fails if the given output column already exists.
    * @param schema input schema
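A hedged usage sketch of the new helper (illustrative only: SchemaUtils is private[spark], so outside Spark's own packages this does not compile):

import org.apache.spark.sql.types._

val schema = StructType(Seq(
  StructField("label", IntegerType),
  StructField("name", StringType)))

SchemaUtils.checkNumericType(schema, "label") // ok: IntegerType extends NumericType
// SchemaUtils.checkNumericType(schema, "name") would throw
// IllegalArgumentException: Column name must be of type NumericType but was
// actually of type StringType.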
mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala

@@ -27,8 +27,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{DataFrame, Row}
 
 class DecisionTreeClassifierSuite
   extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -176,7 +175,7 @@ class DecisionTreeClassifierSuite
   }
 
   test("Multiclass classification tree with 10-ary (ordered) categorical features," +
-      " with just enough bins") {
+    " with just enough bins") {
     val rdd = categoricalDataPointsForMulticlassForOrderedFeaturesRDD
     val dt = new DecisionTreeClassifier()
       .setImpurity("Gini")
@@ -273,7 +272,7 @@ class DecisionTreeClassifierSuite
     ))
     val df = TreeTests.setMetadata(data, Map(0 -> 1), 2)
     val dt = new DecisionTreeClassifier().setMaxDepth(3)
-    val model = dt.fit(df)
+    dt.fit(df)
   }
 
   test("Use soft prediction for binary classification with ordered categorical features") {
@@ -335,6 +334,14 @@ class DecisionTreeClassifierSuite
     assert(importances.toArray.forall(_ >= 0.0))
   }
 
+  test("should support all NumericType labels and not support other types") {
+    val dt = new DecisionTreeClassifier().setMaxDepth(1)
+    MLTestingUtils.checkNumericTypes[DecisionTreeClassificationModel, DecisionTreeClassifier](
+      dt, isClassification = true, sqlContext) { (expected, actual) =>
+        TreeTests.checkEqual(expected, actual)
+      }
+  }
+
   /////////////////////////////////////////////////////////////////////////////
   // Tests of model save/load
   /////////////////////////////////////////////////////////////////////////////
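The suites above call MLTestingUtils.checkNumericTypes, which this diff view does not include. A simplified, hypothetical sketch of the contract the call sites imply; the real helper also takes an isClassification flag and attaches label metadata so tree-based classifiers can resolve numClasses, both omitted here:

import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._
import org.scalatest.Assertions.intercept

def checkNumericTypes[M <: Model[M], T <: Estimator[M]](
    estimator: T, sqlContext: SQLContext)(check: (M, M) => Unit): Unit = {
  val baseDf: DataFrame = sqlContext.createDataFrame(Seq(
    (0.0, Vectors.dense(0.0)), (1.0, Vectors.dense(1.0)))).toDF("label", "features")
  // The model fit on a DoubleType label column is the reference.
  val expected = estimator.fit(baseDf)
  // The same data with every other numeric label type must yield an equal model.
  Seq(ByteType, ShortType, IntegerType, LongType, FloatType).foreach { t =>
    val df = baseDf.select(col("label").cast(t), col("features"))
    check(expected, estimator.fit(df))
  }
  // A non-numeric label column must be rejected at schema validation time.
  val bad = baseDf.select(col("label").cast(StringType), col("features"))
  intercept[IllegalArgumentException] { estimator.fit(bad) }
}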
mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala

@@ -31,7 +31,6 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.util.Utils
 
-
 /**
  * Test suite for [[GBTClassifier]].
  */
@@ -102,6 +101,14 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
     Utils.deleteRecursively(tempDir)
   }
 
+  test("should support all NumericType labels and not support other types") {
+    val gbt = new GBTClassifier().setMaxDepth(1)
+    MLTestingUtils.checkNumericTypes[GBTClassificationModel, GBTClassifier](
+      gbt, isClassification = true, sqlContext) { (expected, actual) =>
+        TreeTests.checkEqual(expected, actual)
+      }
+  }
+
   // TODO: Reinstate test once runWithValidation is implemented SPARK-7132
   /*
   test("runWithValidation stops early and performs better on a validation dataset") {
mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala

@@ -103,7 +103,7 @@ class LogisticRegressionSuite
     assert(model.hasSummary)
     // Validate that we re-insert a probability column for evaluation
     val fieldNames = model.summary.predictions.schema.fieldNames
-    assert((dataset.schema.fieldNames.toSet).subsetOf(
+    assert(dataset.schema.fieldNames.toSet.subsetOf(
       fieldNames.toSet))
     assert(fieldNames.exists(s => s.startsWith("probability_")))
   }
@@ -934,6 +934,15 @@ class LogisticRegressionSuite
     testEstimatorAndModelReadWrite(lr, dataset, LogisticRegressionSuite.allParamSettings,
       checkModelData)
   }
+
+  test("should support all NumericType labels and not support other types") {
+    val lr = new LogisticRegression().setMaxIter(1)
+    MLTestingUtils.checkNumericTypes[LogisticRegressionModel, LogisticRegression](
+      lr, isClassification = true, sqlContext) { (expected, actual) =>
+        assert(expected.intercept === actual.intercept)
+        assert(expected.coefficients.toArray === actual.coefficients.toArray)
+      }
+  }
 }
 
 object LogisticRegressionSuite {
mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala

@@ -19,6 +19,7 @@ package org.apache.spark.ml.classification
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.MLTestingUtils
 import org.apache.spark.mllib.classification.LogisticRegressionSuite._
 import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
 import org.apache.spark.mllib.evaluation.MulticlassMetrics
@@ -162,4 +163,15 @@ class MultilayerPerceptronClassifierSuite
     assert(newMlpModel.layers === mlpModel.layers)
     assert(newMlpModel.weights === mlpModel.weights)
   }
+
+  test("should support all NumericType labels and not support other types") {
+    val layers = Array(3, 2)
+    val mpc = new MultilayerPerceptronClassifier().setLayers(layers).setMaxIter(1)
+    MLTestingUtils.checkNumericTypes[
+      MultilayerPerceptronClassificationModel, MultilayerPerceptronClassifier](
+      mpc, isClassification = true, sqlContext) { (expected, actual) =>
+        assert(expected.layers === actual.layers)
+        assert(expected.weights === actual.weights)
+      }
+  }
 }