
CrossValidator supports user-specified fold column.
viirya committed Jun 2, 2020
1 parent b9737c3 commit baec279
Showing 4 changed files with 137 additions and 8 deletions.
@@ -30,13 +30,13 @@ import org.apache.spark.annotation.Since
import org.apache.spark.internal.Logging
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.evaluation.Evaluator
-import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators}
+import org.apache.spark.ml.param.{IntParam, Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasCollectSubModels, HasParallelism}
import org.apache.spark.ml.util._
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.{DataFrame, Dataset}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{IntegerType, StructType}
import org.apache.spark.util.ThreadUtils

/**
@@ -56,6 +56,19 @@ private[ml] trait CrossValidatorParams extends ValidatorParams {
def getNumFolds: Int = $(numFolds)

setDefault(numFolds -> 3)

/**
* Param for the column name of user-specified fold numbers. Once this is specified,
* `CrossValidator` won't do a random k-fold split. Note that this column should be
* of integer type with values in the range [0, numFolds), and Spark won't
* sanity-check the user-specified fold numbers.
*/
val foldCol: Param[String] = new Param[String](this, "foldCol",
"the column name of user specified fold number")

def getFoldCol: String = $(foldCol)

setDefault(foldCol, "")
}
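For context, a minimal, hypothetical usage sketch of the new param (the DataFrame `data`, its numeric `id` column, and the fold column name "fold" are assumptions for illustration, not part of this commit):

import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.sql.functions.{col, lit, pmod}

// Build a deterministic fold assignment: IntegerType, values in [0, numFolds).
val withFold = data.withColumn("fold", pmod(col("id"), lit(3)).cast("int"))

val cv = new CrossValidator()
  .setEstimator(new LogisticRegression())
  .setEvaluator(new BinaryClassificationEvaluator())
  .setEstimatorParamMaps(new ParamGridBuilder().build())
  .setNumFolds(3)
  .setFoldCol("fold") // rows with fold == i form validation split i
val cvModel = cv.fit(withFold)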

/**
@@ -94,6 +107,10 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
@Since("2.0.0")
def setSeed(value: Long): this.type = set(seed, value)

/** @group setParam */
@Since("3.1.0")
def setFoldCol(value: String): this.type = set(foldCol, value)

/**
* Set the maximum level of parallelism to evaluate models in parallel.
* Default is 1 for serial evaluation
@@ -132,7 +149,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)

instr.logPipelineStage(this)
instr.logDataset(dataset)
-instr.logParams(this, numFolds, seed, parallelism)
+instr.logParams(this, numFolds, seed, parallelism, foldCol)
logTuningParams(instr)

val collectSubModelsParam = $(collectSubModels)
@@ -142,10 +159,15 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
} else None

// Compute metrics for each model over each split
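// If foldCol is set, rows are routed to splits by the user-supplied integer
// fold value and the fold column is dropped (hence the reduced schema);
// otherwise fall back to the usual random k-fold split.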
-val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed))
+val (splits, schemaWithoutFold) = if ($(foldCol) == "") {
+  (MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed)), schema)
+} else {
+  val filteredSchema = StructType(schema.filter(_.name != $(foldCol)).toArray)
+  (MLUtils.kFold(dataset.toDF, $(numFolds), $(foldCol)), filteredSchema)
+}
val metrics = splits.zipWithIndex.map { case ((training, validation), splitIndex) =>
-val trainingDataset = sparkSession.createDataFrame(training, schema).cache()
-val validationDataset = sparkSession.createDataFrame(validation, schema).cache()
+val trainingDataset = sparkSession.createDataFrame(training, schemaWithoutFold).cache()
+val validationDataset = sparkSession.createDataFrame(validation, schemaWithoutFold).cache()
instr.logDebug(s"Train split $splitIndex with multiple sets of parameters.")

// Fit models in a Future for training in parallel
@@ -183,7 +205,14 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
}

@Since("1.4.0")
-override def transformSchema(schema: StructType): StructType = transformSchemaImpl(schema)
+override def transformSchema(schema: StructType): StructType = {
+  if ($(foldCol) != "") {
+    val foldColDt = schema.apply($(foldCol)).dataType
+    require(foldColDt.isInstanceOf[IntegerType],
+      s"The specified `foldCol` column ${$(foldCol)} must be integer type, but got $foldColDt.")
+  }
+  transformSchemaImpl(schema)
+}

@Since("1.4.0")
override def copy(extra: ParamMap): CrossValidator = {
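Given the `transformSchema` check above, a fold column that arrives as LongType (as exercised by the suite's negative test below) must be cast before fitting; a small sketch, with the DataFrame `df` assumed:

import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.IntegerType

// Cast the (assumed) LongType fold column to IntegerType so the require passes.
val fixed = df.withColumn("fold", col("fold").cast(IntegerType))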
mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala (13 changes: 12 additions & 1 deletion)
@@ -28,7 +28,7 @@ import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.linalg.BLAS.dot
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD}
-import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
+import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.datasources.text.TextFileFormat
import org.apache.spark.sql.functions._
@@ -248,6 +248,17 @@ object MLUtils extends Logging {
}.toArray
}

/**
* Version of `kFold()` taking a fold column name.
*/
@Since("3.1.0")
def kFold(df: DataFrame, numFolds: Int, foldColName: String): Array[(RDD[Row], RDD[Row])] = {
val foldCol = df.col(foldColName)
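// For each fold i, rows whose fold value differs from i form the training RDD
// and rows whose fold value equals i form the validation RDD; the fold column
// itself is dropped from both.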
(0 until numFolds).map { fold =>
(df.filter(foldCol =!= fold).drop(foldCol).rdd, df.filter(foldCol === fold).drop(foldCol).rdd)
}.toArray
}

/**
* Returns a new vector with `1.0` (bias) appended to the input vector.
*/
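A quick sketch of the new `kFold` overload in isolation (the SparkSession `spark` and the modulo-based fold assignment are assumptions for illustration):

import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.functions.col

val df = spark.range(100).withColumn("fold", (col("id") % 3).cast("int"))
val folds = MLUtils.kFold(df, numFolds = 3, foldColName = "fold")
folds.zipWithIndex.foreach { case ((train, validation), i) =>
  // The two RDDs partition the rows of df; the fold column itself is dropped.
  println(s"split $i: train=${train.count()}, validation=${validation.count()}")
}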
@@ -32,6 +32,7 @@ import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.mllib.util.LinearDataGenerator
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.StructType

class CrossValidatorSuite
@@ -40,10 +41,22 @@ class CrossValidatorSuite
import testImplicits._

@transient var dataset: Dataset[_] = _
@transient var datasetWithFold: Dataset[_] = _

override def beforeAll(): Unit = {
super.beforeAll()
dataset = sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2).toDF()
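// Randomly assign each row to fold 0, 1, or 2, so roughly a third of the rows
// land in each fold (the assignment is nondeterministic across runs).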
val foldCol = udf { () =>
val r = Math.random()
if (r < 0.33) {
0
} else if (r < 0.66) {
1
} else {
2
}
}
datasetWithFold = dataset.withColumn("fold", foldCol())
}

test("cross validation with logistic regression") {
@@ -75,6 +88,65 @@
}
}

test("cross validation with logistic regression with fold col") {
val lr = new LogisticRegression
val lrParamMaps = new ParamGridBuilder()
.addGrid(lr.regParam, Array(0.001, 1000.0))
.addGrid(lr.maxIter, Array(0, 10))
.build()
val eval = new BinaryClassificationEvaluator
val cv = new CrossValidator()
.setEstimator(lr)
.setEstimatorParamMaps(lrParamMaps)
.setEvaluator(eval)
.setNumFolds(3)
.setFoldCol("fold")
val cvModel = cv.fit(datasetWithFold)

MLTestingUtils.checkCopyAndUids(cv, cvModel)

val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression]
assert(parent.getRegParam === 0.001)
assert(parent.getMaxIter === 10)
assert(cvModel.avgMetrics.length === lrParamMaps.length)

val result = cvModel.transform(dataset).select("prediction").as[Double].collect()
testTransformerByGlobalCheckFunc[(Double, Vector)](dataset.toDF(), cvModel, "prediction") {
rows =>
val result2 = rows.map(_.getDouble(0))
assert(result === result2)
}
}

test("cross validation with logistic regression with wrong fold col") {
val lr = new LogisticRegression
val lrParamMaps = new ParamGridBuilder()
.addGrid(lr.regParam, Array(0.001, 1000.0))
.addGrid(lr.maxIter, Array(0, 10))
.build()
val eval = new BinaryClassificationEvaluator
val cv = new CrossValidator()
.setEstimator(lr)
.setEstimatorParamMaps(lrParamMaps)
.setEvaluator(eval)
.setNumFolds(3)
.setFoldCol("fold1")
val err1 = intercept[IllegalArgumentException] {
cv.fit(datasetWithFold)
}
assert(err1.getMessage.contains("fold1 does not exist. Available: label, features, fold"))

// Fold column must be integer type.
val foldCol = udf(() => 1L)
val datasetWithWrongFoldType = dataset.withColumn("fold1", foldCol())
val err2 = intercept[IllegalArgumentException] {
cv.fit(datasetWithWrongFoldType)
}
assert(err2
.getMessage
.contains("The specified `foldCol` column fold1 must be integer type, but got LongType."))
}

test("cross validation with linear regression") {
val dataset = sc.parallelize(
LinearDataGenerator.generateLinearInput(
@@ -353,4 +353,21 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext {
convertMatrixColumnsFromML(df, "p._2")
}
}

test("kFold with fold column") {
val data = sc.parallelize(1 to 100, 2).map(x => (x, if (x <= 50) 0 else 1)).toDF("i", "fold")
val collectedData = data.collect().map(_.getInt(0)).sorted
val twoFoldedRdd = kFold(data, 2, "fold")
assert(twoFoldedRdd(0)._1.collect().map(_.getInt(0)).sorted ===
twoFoldedRdd(1)._2.collect().map(_.getInt(0)).sorted)
assert(twoFoldedRdd(0)._2.collect().map(_.getInt(0)).sorted ===
twoFoldedRdd(1)._1.collect().map(_.getInt(0)).sorted)

val result1 = twoFoldedRdd(0)._1.union(twoFoldedRdd(0)._2).collect().map(_.getInt(0)).sorted
assert(result1 === collectedData,
"Each training+validation set combined should contain all of the data.")
val result2 = twoFoldedRdd(1)._1.union(twoFoldedRdd(1)._2).collect().map(_.getInt(0)).sorted
assert(result2 === collectedData,
"Each training+validation set combined should contain all of the data.")
}
}
