diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 296179b75bc43..6a916a90156c8 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -301,9 +301,88 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } /** - * Merge the values for each key using an associative and commutative reduce function. This will - * also perform the merging locally on each mapper before sending results to a reducer, similarly - * to a "combiner" in MapReduce. + * ::Experimental:: + * Return random, non-overlapping splits of this RDD sampled by key (via stratified sampling) + * with each split containing exactly math.ceil(numItems * samplingRate) for each stratum. + * + * This method differs from [[sampleByKey]] and [[sampleByKeyExact]] in that it provides random + * splits (and their complements) instead of just a subsample of the data. This requires + * segmenting random keys into ranges with upper and lower bounds instead of segmenting the keys + * into a high/low bisection of the entire dataset. + * + * @param weights array of maps of (key -> samplingRate) pairs for each split, normed by key + * @param exact boolean specifying whether to use exact subsampling + * @param seed seed for the random number generator + * @return array of tuples containing the subsample and complement RDDs for each split + */ + @Experimental + def randomSplitByKey( + weights: Array[Map[K, Double]], + exact: Boolean = false, + seed: Long = Utils.random.nextLong): Array[(RDD[(K, V)], RDD[(K, V)])] = self.withScope { + + require(weights.flatMap(_.values).forall(v => v >= 0.0), "Negative sampling rates.") + if (weights.length > 1) { + require(weights.map(m => m.keys.toSet).sliding(2).forall(t => t(0) == t(1)), + "Inconsistent keys between splits.") + } + + // maps of sampling threshold boundaries at 0.0 and 1.0 + val leftBoundary = weights(0).map(x => (x._1, 0.0)) + val rightBoundary = weights(0).map(x => (x._1, 1.0)) + + // normalize and cumulative sum + val cumWeightsByKey = weights.scanLeft(leftBoundary) { case (accMap, iterMap) => + accMap.map { case (k, v) => (k, v + iterMap(k)) } + }.drop(1) + + val weightSumsByKey = cumWeightsByKey.last + val normedCumWeightsByKey = cumWeightsByKey.dropRight(1).map(_.map { case (key, threshold) => + val keyWeightSum = weightSumsByKey(key) + val norm = if (keyWeightSum > 0.0) keyWeightSum else 1.0 + (key, threshold / norm) + }) + + // compute exact thresholds for each stratum if required + val splitPoints = if (exact) { + normedCumWeightsByKey.map { w => + val finalResult = StratifiedSamplingUtils.getAcceptanceResults(self, false, w, None, seed) + StratifiedSamplingUtils.computeThresholdByKey(finalResult, w) + } + } else { + normedCumWeightsByKey + } + + val splitsPointsAndBounds = leftBoundary +: splitPoints :+ rightBoundary + splitsPointsAndBounds.sliding(2).map { x => + (randomSampleByKeyWithRange(x(0), x(1), seed), + randomSampleByKeyWithRange(x(0), x(1), seed, complement = true)) + }.toArray + } + + /** + * Internal method exposed for Stratified Random Splits in DataFrames. Samples an RDD given + * probability bounds for each stratum. 
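+   * A minimal usage sketch (illustrative only: the keys and bounds are hypothetical, every key in
+   * the RDD is assumed to appear in both maps, and reusing the same seed is what makes a sample
+   * and its complement disjoint, assuming the RDD evaluates deterministically):
+   * {{{
+   *   // given an RDD[(String, Int)] named rdd, keep the pairs whose per-key uniform draw
+   *   // falls in [0.25, 0.75), i.e. roughly half of each stratum
+   *   val middle = rdd.randomSampleByKeyWithRange(
+   *     Map("a" -> 0.25, "b" -> 0.25), Map("a" -> 0.75, "b" -> 0.75), seed = 42L)
+   *   // the remainder of each stratum: draws in [0.0, 0.25) or [0.75, 1.0)
+   *   val rest = rdd.randomSampleByKeyWithRange(
+   *     Map("a" -> 0.25, "b" -> 0.25), Map("a" -> 0.75, "b" -> 0.75), seed = 42L, complement = true)
+   * }}}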
+ * + * @param lb map of lower bound for each key to use for the Bernoulli cell sampler + * @param ub map of upper bound for each key to use for the Bernoulli cell sampler + * @param seed the seed for the Random number generator + * @param complement boolean specifying whether to return subsample or its complement + * @return A random, stratified sub-sample of the RDD without replacement. + */ + private[spark] def randomSampleByKeyWithRange(lb: Map[K, Double], + ub: Map[K, Double], + seed: Long, + complement: Boolean = false): RDD[(K, V)] = { + val samplingFunc = StratifiedSamplingUtils.getBernoulliCellSamplingFunction(self, + lb, ub, seed, complement) + self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true) + } + + /** + * Merge the values for each key using an associative reduce function. This will also perform + * the merging locally on each mapper before sending results to a reducer, similarly to a + * "combiner" in MapReduce. */ def reduceByKey(partitioner: Partitioner, func: (V, V) => V): RDD[(K, V)] = self.withScope { combineByKeyWithClassTag[V]((v: V) => v, func, func, partitioner) diff --git a/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala index 67822749112c6..dfba6b2b1481d 100644 --- a/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala @@ -215,6 +215,35 @@ private[spark] object StratifiedSamplingUtils extends Logging { } } + /** + * Return the per partition sampling function used for partitioning a dataset without + * replacement. + * + * The sampling function has a unique seed per partition. + */ + def getBernoulliCellSamplingFunction[K, V](rdd: RDD[(K, V)], + lb: Map[K, Double], + ub: Map[K, Double], + seed: Long, + complement: Boolean = false): (Int, Iterator[(K, V)]) => Iterator[(K, V)] = { + (idx: Int, iter: Iterator[(K, V)]) => { + val rng = new RandomDataGenerator() + rng.reSeed(seed + idx) + + if (complement) { + iter.filter { case(k, _) => + val x = rng.nextUniform() + (x < lb(k)) || (x >= ub(k)) + } + } else { + iter.filter { case(k, _) => + val x = rng.nextUniform() + (x >= lb(k)) && (x < ub(k)) + } + } + } + } + /** * Return the per partition sampling function used for sampling with replacement. 
* diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index b0d69de6e2ef4..4de7168b13adc 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -168,6 +168,118 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { } } + test("randomSplitByKey exact") { + val defaultSeed = 1L + + // vary RDD size + for (n <- List(100, 1000, 10000)) { + val data = sc.parallelize(1 to n, 2) + val fractionPositive = 0.3 + val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) + val keys = stratifiedData.keys.distinct().collect() + val splitWeights = Array(0.3, 0.2, 0.5) + val weights: Array[scala.collection.Map[String, Double]] = + splitWeights.map(w => keys.map(k => (k, w)).toMap) + StratifiedAuxiliary.testSplits(stratifiedData, weights, defaultSeed, n, true) + } + + // vary fractionPositive + for (fractionPositive <- List(0.1, 0.3, 0.5, 0.7, 0.9)) { + val n = 100 + val data = sc.parallelize(1 to n, 2) + val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) + val keys = stratifiedData.keys.distinct().collect() + val splitWeights = Array(0.3, 0.2, 0.5) + val weights: Array[scala.collection.Map[String, Double]] = + splitWeights.map(w => keys.map(k => (k, w)).toMap) + StratifiedAuxiliary.testSplits(stratifiedData, weights, defaultSeed, n, true) + } + + // use same data for remaining tests + val n = 100 + val fractionPositive = 0.3 + val data = sc.parallelize(1 to n, 2) + val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) + val keys = stratifiedData.keys.distinct().collect() + + // use different weights for each key in the split + val unevenWeights: Array[scala.collection.Map[String, Double]] = + Array(Map("0" -> 0.2, "1" -> 0.3), Map("0" -> 0.1, "1" -> 0.4), Map("0" -> 0.7, "1" -> 0.3)) + StratifiedAuxiliary.testSplits(stratifiedData, unevenWeights, defaultSeed, n, true) + + // vary the seed + val splitWeights = Array(0.3, 0.2, 0.5) + val weights: Array[scala.collection.Map[String, Double]] = + splitWeights.map(w => keys.map(k => (k, w)).toMap) + for (seed <- defaultSeed to defaultSeed + 3L) { + StratifiedAuxiliary.testSplits(stratifiedData, weights, seed, n, true) + } + + // vary the number of splits + for (numSplits <- 1 to 3) { + val splitWeights = (1 to numSplits).map(n => 1.toDouble).toArray // check normalization too + val weights: Array[scala.collection.Map[String, Double]] = + splitWeights.map(w => keys.map(k => (k, w)).toMap) + StratifiedAuxiliary.testSplits(stratifiedData, weights, defaultSeed, n, true) + } + } + + test("randomSplitByKey") { + val defaultSeed = 1L + + // vary RDD size + for (n <- List(100, 1000, 10000)) { + val data = sc.parallelize(1 to n, 2) + val fractionPositive = 0.3 + val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) + val keys = stratifiedData.keys.distinct().collect() + val splitWeights = Array(0.3, 0.2, 0.5) + val weights: Array[scala.collection.Map[String, Double]] = + splitWeights.map(w => keys.map(k => (k, w)).toMap) + StratifiedAuxiliary.testSplits(stratifiedData, weights, defaultSeed, n, false) + } + + // vary fractionPositive + for (fractionPositive <- List(0.1, 0.3, 0.5, 0.7, 0.9)) { + val n = 100 + val data = sc.parallelize(1 to n, 2) + val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) + val 
keys = stratifiedData.keys.distinct().collect() + val splitWeights = Array(0.3, 0.2, 0.5) + val weights: Array[scala.collection.Map[String, Double]] = + splitWeights.map(w => keys.map(k => (k, w)).toMap) + StratifiedAuxiliary.testSplits(stratifiedData, weights, defaultSeed, n, false) + } + + // use same data for remaining tests + val n = 100 + val fractionPositive = 0.3 + val data = sc.parallelize(1 to n, 2) + val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) + val keys = stratifiedData.keys.distinct().collect() + + // use different weights for each key in the split + val unevenWeights: Array[scala.collection.Map[String, Double]] = + Array(Map("0" -> 0.2, "1" -> 0.3), Map("0" -> 0.1, "1" -> 0.4), Map("0" -> 0.7, "1" -> 0.3)) + StratifiedAuxiliary.testSplits(stratifiedData, unevenWeights, defaultSeed, n, false) + + // vary the seed + val splitWeights = Array(0.3, 0.2, 0.5) + val weights: Array[scala.collection.Map[String, Double]] = + splitWeights.map(w => keys.map(k => (k, w)).toMap) + for (seed <- defaultSeed to defaultSeed + 5L) { + StratifiedAuxiliary.testSplits(stratifiedData, weights, seed, n, false) + } + + // vary the number of splits + for (numSplits <- 1 to 5) { + val splitWeights = (1 to numSplits).map(n => 1.toDouble).toArray // check normalization too + val weights: Array[scala.collection.Map[String, Double]] = + splitWeights.map(w => keys.map(k => (k, w)).toMap) + StratifiedAuxiliary.testSplits(stratifiedData, weights, defaultSeed, n, false) + } + } + test("reduceByKey") { val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) val sums = pairs.reduceByKey(_ + _).collect() @@ -646,6 +758,19 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { } } + def checkSplitSize(exact: Boolean, + expected: Long, + actual: Long, + p: Double): Boolean = { + if (exact) { + // all splits will not be exact, but must be within 1 of expected size + return math.abs(expected - actual) <= 1 + } + val stdev = math.sqrt(expected * p * (1 - p)) + // Very forgiving margin since we're dealing with very small sample sizes most of the time + math.abs(actual - expected) <= 6 * stdev + } + def testSampleExact(stratifiedData: RDD[(String, Int)], samplingRate: Double, seed: Long, @@ -662,6 +787,64 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { testPoisson(stratifiedData, false, samplingRate, seed, n) } + def testSplits(stratifiedData: RDD[(String, Int)], + weights: Array[scala.collection.Map[String, Double]], + seed: Long, + n: Int, + exact: Boolean): Unit = { + val baseFold = weights(0).map(x => (x._1, 0.0)) + val totalWeightByKey = weights.foldLeft(baseFold) { case (accMap, iterMap) => + accMap.map { case (k, v) => (k, v + iterMap(k)) } + } + val normedWeights = weights.map(m => m.map { case(k, v) => (k, v / totalWeightByKey(k))}) + + val splits = stratifiedData.randomSplitByKey(weights, exact, seed) + val stratCounts = stratifiedData.countByKey() + + + val expectedSampleSizes = normedWeights.map { m => + stratCounts.map { case (key, count) => + (key, math.ceil(count * m(key)).toLong) + }.toMap + } + val expectedComplementSizes = normedWeights.map { m => + stratCounts.map { case (key, count) => + (key, math.ceil(count * (1 - m(key))).toLong) + }.toMap + } + + val samples = splits.map{ case(subsample, complement) => subsample.collect()} + val complements = splits.map{ case(subsample, complement) => complement.collect()} + + // check for the correct sample size for each split by key + 
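+      // (checkSplitSize tolerates a deviation of at most 1 from ceil(count * weight) for exact
+      // splits, and up to 6 * sqrt(expected * p * (1 - p)) for the inexact Bernoulli splits)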
(samples.map(_.groupBy(_._1).map(x => (x._1, x._2.length))) zip expectedSampleSizes) + .zipWithIndex.foreach { case ((actual, expected), idx) => + actual.foreach { case (k, v) => + checkSplitSize(exact, expected(k), v, normedWeights(idx)(k)) + } + } + (complements.map(_.groupBy(_._1).map(x => (x._1, x._2.length))) zip expectedComplementSizes) + .zipWithIndex.foreach { case ((actual, expected), idx) => + actual.foreach{ case (k, v) => + checkSplitSize(exact, expected(k), v, normedWeights(idx)(k)) + } + } + + // make sure samples ++ complements equals the original set + (samples zip complements).foreach { case (sample, complement) => + assert((sample ++ complement).sortBy(_._2).toList == stratifiedData.collect().toList) + } + + // make sure the elements are members of the original set + samples.map(sample => sample.map(x => assert(x._2 >= 1 && x._2 <= n))) + + // make sure no duplicates in each sample + samples.map(sample => assert(sample.length == sample.toSet.size)) + + // make sure that union of all samples equals the original set + assert(samples.flatMap(x => x).sortBy(_._2).toList == stratifiedData.collect().toList) + } + // Without replacement validation def testBernoulli(stratifiedData: RDD[(String, Int)], exact: Boolean, diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 963f81cb3ec39..cf838913b1c83 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -50,10 +50,23 @@ private[ml] trait CrossValidatorParams extends ValidatorParams with HasSeed { val numFolds: IntParam = new IntParam(this, "numFolds", "number of folds for cross validation (>= 2)", ParamValidators.gtEq(2)) + /** + * Param for stratified sampling column name + * Default: "None" + * @group param + */ + val stratifiedCol: Param[String] = new Param[String](this, "stratifiedCol", + "stratified column name") + + /** @group getParam */ + def getStratifiedCol: String = $(stratifiedCol) + /** @group getParam */ def getNumFolds: Int = $(numFolds) setDefault(numFolds -> 3) + setDefault(stratifiedCol -> "None") + } /** @@ -91,6 +104,10 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) @Since("2.0.0") def setSeed(value: Long): this.type = set(seed, value) + /** @group setParam */ + @Since("2.0.0") + def setStratifiedCol(value: String): this.type = set(stratifiedCol, value) + @Since("1.4.0") override def fit(dataset: DataFrame): CrossValidatorModel = { val schema = dataset.schema @@ -101,7 +118,19 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) val epm = $(estimatorParamMaps) val numModels = epm.length val metrics = new Array[Double](epm.length) - val splits = MLUtils.kFold(dataset.rdd, $(numFolds), $(seed)) + + val splits = if (schema.fieldNames.contains($(stratifiedCol)) & isSet(stratifiedCol)) { + val stratifiedColIndex = schema.fieldNames.indexOf($(stratifiedCol)) + val pairData = dataset.rdd.map(row => (row(stratifiedColIndex), row)) + val splitsWithKeys = MLUtils.kFoldStratified(pairData, $(numFolds), 0) + splitsWithKeys.map { case (training, validation) => (training.values, validation.values)} + } else { + if (isSet(stratifiedCol)) logWarning(s"Stratified column does not exist. 
" + + s"Performing regular k-fold subsampling.") + // regular kFold + MLUtils.kFold(dataset.rdd, $(numFolds), $(seed)) + } + splits.zipWithIndex.foreach { case ((training, validation), splitIndex) => val trainingDataset = sqlCtx.createDataFrame(training, schema).cache() val validationDataset = sqlCtx.createDataFrame(validation, schema).cache() diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 70fa5f0234753..36e1981af0523 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -21,6 +21,7 @@ import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.evaluation.Evaluator +import org.apache.spark.ml.param._ import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators} import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.DataFrame @@ -38,10 +39,22 @@ private[ml] trait TrainValidationSplitParams extends ValidatorParams { val trainRatio: DoubleParam = new DoubleParam(this, "trainRatio", "ratio between training set and validation set (>= 0 && <= 1)", ParamValidators.inRange(0, 1)) + /** + * Param for stratified sampling column name + * Default: "None" + * @group param + */ + val stratifiedCol: Param[String] = new Param[String](this, "stratifiedCol", + "stratified column name") + /** @group getParam */ def getTrainRatio: Double = $(trainRatio) + /** @group getParam */ + def getStratifiedCol: String = $(stratifiedCol) + setDefault(trainRatio -> 0.75) + setDefault(stratifiedCol -> "None") } /** @@ -76,6 +89,10 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St @Since("1.5.0") def setTrainRatio(value: Double): this.type = set(trainRatio, value) + /** @group setParam */ + @Since("2.0.0") + def setStratifiedCol(value: String): this.type = set(stratifiedCol, value) + @Since("1.5.0") override def fit(dataset: DataFrame): TrainValidationSplitModel = { val schema = dataset.schema @@ -88,7 +105,21 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St val metrics = new Array[Double](epm.length) val Array(training, validation) = - dataset.rdd.randomSplit(Array($(trainRatio), 1 - $(trainRatio))) + if (schema.fieldNames.contains($(stratifiedCol)) & isSet(stratifiedCol)) { + val stratifiedColIndex = schema.fieldNames.indexOf($(stratifiedCol)) + val pairData = dataset.rdd.map(row => (row(stratifiedColIndex), row)) + val keys = pairData.keys.distinct().collect() + val weights: Array[scala.collection.Map[Any, Double]] = + Array(keys.map((_, $(trainRatio))).toMap, keys.map((_, 1 - $(trainRatio))).toMap) + val splitsWithKeys = pairData.randomSplitByKey(weights, exact = true, 0) + splitsWithKeys.map { case (subsample, complement) => subsample.values } + } else { + if (isSet(stratifiedCol)) { + logWarning("Stratified column not found. 
Using standard split.") + } + dataset.rdd.randomSplit(Array($(trainRatio), 1 - $(trainRatio))) + } + val trainingDataset = sqlCtx.createDataFrame(training, schema).cache() val validationDataset = sqlCtx.createDataFrame(validation, schema).cache() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 74e9271e40329..92a5da37dce4a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -283,6 +283,24 @@ object MLUtils { }.toArray } + /** + * Return a k element array of pairs of RDDs with the first element of each pair + * containing the training data, a complement of the validation data and the second + * element, the validation data, containing a unique 1/kth of the data. Where k=numFolds. + * The training and validation data are stratified by the key of the rdd, and the key + * ratios in the original data are maintained in each stratum of the train and validation + * data. + */ + def kFoldStratified[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)], + numFolds: Int, + seed: Int): Array[(RDD[(K, V)], RDD[(K, V)])] = { + val keys = rdd.keys.distinct().collect() + val weights: Array[scala.collection.Map[K, Double]] = (1 to numFolds).map { + n => keys.map(k => (k, 1 / numFolds.toDouble)).toMap + }.toArray + rdd.randomSplitByKey(weights, exact = true, seed) + } + /** * Returns a new vector with `1.0` (bias) appended to the input vector. */ diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 7af3c6d6ede47..77ad97f9aa15f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -55,6 +55,7 @@ class CrossValidatorSuite .setEstimatorParamMaps(lrParamMaps) .setEvaluator(eval) .setNumFolds(3) + .setStratifiedCol("label") val cvModel = cv.fit(dataset) // copied model must have the same paren. @@ -109,6 +110,8 @@ class CrossValidatorSuite .setEstimator(est) .setEstimatorParamMaps(paramMaps) .setEvaluator(eval) + .setNumFolds(3) + .setStratifiedCol("label") cv.transformSchema(new StructType()) // This should pass. diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index cf8dcefebc3aa..d21ce71256e75 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -45,6 +45,7 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext .setEstimatorParamMaps(lrParamMaps) .setEvaluator(eval) .setTrainRatio(0.5) + .setStratifiedCol("label") val cvModel = cv.fit(dataset) val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression] assert(cv.getTrainRatio === 0.5) @@ -97,6 +98,7 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext .setEstimatorParamMaps(paramMaps) .setEvaluator(eval) .setTrainRatio(0.5) + .setStratifiedCol("label") cv.transformSchema(new StructType()) // This should pass. 
val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "") diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala index e542f21a1802c..f4c248c55c534 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala @@ -208,6 +208,37 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("kFoldStratified") { + /** + * Most of the functionality of [[kFoldStratified]] is tested in the PairRDD function + * `randomSplitByKey`. All that needs to be checked here is that the folds are even + * splits for each key. + */ + val defaultSeed = 1 + val n = 100 + val data = sc.parallelize(1 to n, 2) + val fractionPositive = 0.3 + val keys = Array("0", "1") + val stratifiedData = data.map { x => + if (x > n*fractionPositive) ("0", x) else ("1", x) + } + val counts = stratifiedData.countByKey() + for (numFolds <- 1 to 3) { + val folds = kFoldStratified(stratifiedData, numFolds, defaultSeed) + val expectedSize = keys.map(k => (k, counts(k) / numFolds.toDouble)).toMap + for ((sample, complement) <- folds) { + val sampleCounts = sample.countByKey() + val complementCounts = complement.countByKey() + sampleCounts.foreach { case(key, count) => + assert(math.abs(count - expectedSize(key)) <= 1) + } + complementCounts.foreach { case(key, count) => + assert(math.abs(count - (counts(key) - expectedSize(key))) <= 1) + } + } + } + } + test("loadVectors") { val vectors = sc.parallelize(Seq( Vectors.dense(1.0, 2.0),