85 changes: 82 additions & 3 deletions core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -301,9 +301,88 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
}

/**
* Merge the values for each key using an associative and commutative reduce function. This will
* also perform the merging locally on each mapper before sending results to a reducer, similarly
* to a "combiner" in MapReduce.
* ::Experimental::
* Return random, non-overlapping splits of this RDD sampled by key (via stratified sampling),
* with each split containing approximately math.ceil(numItems * samplingRate) items per
* stratum, and exactly that many when `exact` is true.
*
* This method differs from [[sampleByKey]] and [[sampleByKeyExact]] in that it returns random
* splits (and their complements) instead of a single subsample of the data. To do so it assigns
* each element a uniform random value and segments the unit interval into per-key
* [lower, upper) ranges, rather than a single high/low bisection of the entire dataset.
*
* @param weights array of maps of (key -> samplingRate) pairs for each split, normalized per key
* @param exact boolean specifying whether to use exact subsampling
* @param seed seed for the random number generator
* @return array of tuples containing the subsample and complement RDDs for each split
*/
@Experimental
def randomSplitByKey(
weights: Array[Map[K, Double]],
exact: Boolean = false,
seed: Long = Utils.random.nextLong): Array[(RDD[(K, V)], RDD[(K, V)])] = self.withScope {

require(weights.flatMap(_.values).forall(_ >= 0.0), "Sampling rates must be non-negative.")
if (weights.length > 1) {
require(weights.map(m => m.keys.toSet).sliding(2).forall(t => t(0) == t(1)),
"Inconsistent keys between splits.")
}

// maps of sampling threshold boundaries at 0.0 and 1.0
val leftBoundary = weights(0).map(x => (x._1, 0.0))
val rightBoundary = weights(0).map(x => (x._1, 1.0))

// normalize and cumulative sum
val cumWeightsByKey = weights.scanLeft(leftBoundary) { case (accMap, iterMap) =>
accMap.map { case (k, v) => (k, v + iterMap(k)) }
}.drop(1)

val weightSumsByKey = cumWeightsByKey.last
val normedCumWeightsByKey = cumWeightsByKey.dropRight(1).map(_.map { case (key, threshold) =>
val keyWeightSum = weightSumsByKey(key)
val norm = if (keyWeightSum > 0.0) keyWeightSum else 1.0
(key, threshold / norm)
})
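// Illustrative example (hypothetical numbers, not used by the code): for split weights
// 0.3 / 0.2 / 0.5 on key "a", the cumulative sums are 0.3, 0.5, 1.0; after dropping the final
// bound and normalizing, the interior split points for "a" are 0.3 and 0.5, so the three splits
// draw from [0.0, 0.3), [0.3, 0.5) and [0.5, 1.0) of that key's uniform random values.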

// compute exact thresholds for each stratum if required
val splitPoints = if (exact) {
normedCumWeightsByKey.map { w =>
val finalResult = StratifiedSamplingUtils.getAcceptanceResults(self, false, w, None, seed)
StratifiedSamplingUtils.computeThresholdByKey(finalResult, w)
}
} else {
normedCumWeightsByKey
}

val splitPointsAndBounds = leftBoundary +: splitPoints :+ rightBoundary
splitPointsAndBounds.sliding(2).map { x =>
(randomSampleByKeyWithRange(x(0), x(1), seed),
randomSampleByKeyWithRange(x(0), x(1), seed, complement = true))
}.toArray
}
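For illustration, a minimal usage sketch (not part of this patch; the keys, weights, and seed are made up, and an active SparkContext sc is assumed):

// Split an RDD of (label, value) pairs 70/30 per label.
val data = sc.parallelize(Seq(("a", 1), ("a", 2), ("a", 3), ("a", 4), ("b", 5), ("b", 6)))
val weights: Array[scala.collection.Map[String, Double]] =
  Array(Map("a" -> 0.7, "b" -> 0.7), Map("a" -> 0.3, "b" -> 0.3))
val Array((train, trainRest), (test, testRest)) =
  data.randomSplitByKey(weights, exact = true, seed = 42L)
// With two splits, the complement of train (trainRest) contains the same elements as test.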

/**
* Internal method exposed for Stratified Random Splits in DataFrames. Samples an RDD given
* probability bounds for each stratum.
*
* @param lb map of lower bound for each key to use for the Bernoulli cell sampler
* @param ub map of upper bound for each key to use for the Bernoulli cell sampler
* @param seed the seed for the random number generator
* @param complement boolean specifying whether to return the subsample or its complement
* @return a random, stratified subsample of the RDD (or its complement), without replacement
*/
private[spark] def randomSampleByKeyWithRange(lb: Map[K, Double],
ub: Map[K, Double],
seed: Long,
complement: Boolean = false): RDD[(K, V)] = {
val samplingFunc = StratifiedSamplingUtils.getBernoulliCellSamplingFunction(self,
lb, ub, seed, complement)
self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true)
}

/**
* Merge the values for each key using an associative reduce function. This will also perform
* the merging locally on each mapper before sending results to a reducer, similarly to a
* "combiner" in MapReduce.
*/
def reduceByKey(partitioner: Partitioner, func: (V, V) => V): RDD[(K, V)] = self.withScope {
combineByKeyWithClassTag[V]((v: V) => v, func, func, partitioner)
@@ -215,6 +215,35 @@ private[spark] object StratifiedSamplingUtils extends Logging {
}
}

/**
* Return the per-partition sampling function used for partitioning a dataset without
* replacement. An element with key k is kept if its uniform draw falls in [lb(k), ub(k)),
* or in the complement of that range when `complement` is true.
*
* The sampling function has a unique seed per partition.
*/
def getBernoulliCellSamplingFunction[K, V](rdd: RDD[(K, V)],
lb: Map[K, Double],
ub: Map[K, Double],
seed: Long,
complement: Boolean = false): (Int, Iterator[(K, V)]) => Iterator[(K, V)] = {
(idx: Int, iter: Iterator[(K, V)]) => {
val rng = new RandomDataGenerator()
rng.reSeed(seed + idx)

if (complement) {
iter.filter { case (k, _) =>
val x = rng.nextUniform()
(x < lb(k)) || (x >= ub(k))
}
} else {
iter.filter { case (k, _) =>
val x = rng.nextUniform()
(x >= lb(k)) && (x < ub(k))
}
}
}
}
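// Illustration (hypothetical bounds): with lb("a") = 0.25 and ub("a") = 0.5, an element keyed
// "a" is kept when its uniform draw x satisfies 0.25 <= x < 0.5; with complement = true it is
// kept only when x < 0.25 or x >= 0.5. Since the draw for a given element is reproducible from
// seed + partition index, a range and its complement partition the stratum exactly.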

/**
* Return the per partition sampling function used for sampling with replacement.
*
183 changes: 183 additions & 0 deletions core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
@@ -168,6 +168,118 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext {
}
}

test("randomSplitByKey exact") {
val defaultSeed = 1L

// vary RDD size
for (n <- List(100, 1000, 10000)) {
val data = sc.parallelize(1 to n, 2)
val fractionPositive = 0.3
val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive))
val keys = stratifiedData.keys.distinct().collect()
val splitWeights = Array(0.3, 0.2, 0.5)
val weights: Array[scala.collection.Map[String, Double]] =
splitWeights.map(w => keys.map(k => (k, w)).toMap)
StratifiedAuxiliary.testSplits(stratifiedData, weights, defaultSeed, n, true)
}

// vary fractionPositive
for (fractionPositive <- List(0.1, 0.3, 0.5, 0.7, 0.9)) {
val n = 100
val data = sc.parallelize(1 to n, 2)
val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive))
val keys = stratifiedData.keys.distinct().collect()
val splitWeights = Array(0.3, 0.2, 0.5)
val weights: Array[scala.collection.Map[String, Double]] =
splitWeights.map(w => keys.map(k => (k, w)).toMap)
StratifiedAuxiliary.testSplits(stratifiedData, weights, defaultSeed, n, true)
}

// use same data for remaining tests
val n = 100
val fractionPositive = 0.3
val data = sc.parallelize(1 to n, 2)
val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive))
val keys = stratifiedData.keys.distinct().collect()

// use different weights for each key in the split
val unevenWeights: Array[scala.collection.Map[String, Double]] =
Array(Map("0" -> 0.2, "1" -> 0.3), Map("0" -> 0.1, "1" -> 0.4), Map("0" -> 0.7, "1" -> 0.3))
StratifiedAuxiliary.testSplits(stratifiedData, unevenWeights, defaultSeed, n, true)

// vary the seed
val splitWeights = Array(0.3, 0.2, 0.5)
val weights: Array[scala.collection.Map[String, Double]] =
splitWeights.map(w => keys.map(k => (k, w)).toMap)
for (seed <- defaultSeed to defaultSeed + 3L) {
StratifiedAuxiliary.testSplits(stratifiedData, weights, seed, n, true)
}

// vary the number of splits
for (numSplits <- 1 to 3) {
val splitWeights = Array.fill(numSplits)(1.0) // all splits weighted equally; checks normalization too
val weights: Array[scala.collection.Map[String, Double]] =
splitWeights.map(w => keys.map(k => (k, w)).toMap)
StratifiedAuxiliary.testSplits(stratifiedData, weights, defaultSeed, n, true)
}
}

test("randomSplitByKey") {
val defaultSeed = 1L

// vary RDD size
for (n <- List(100, 1000, 10000)) {
val data = sc.parallelize(1 to n, 2)
val fractionPositive = 0.3
val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive))
val keys = stratifiedData.keys.distinct().collect()
val splitWeights = Array(0.3, 0.2, 0.5)
val weights: Array[scala.collection.Map[String, Double]] =
splitWeights.map(w => keys.map(k => (k, w)).toMap)
StratifiedAuxiliary.testSplits(stratifiedData, weights, defaultSeed, n, false)
}

// vary fractionPositive
for (fractionPositive <- List(0.1, 0.3, 0.5, 0.7, 0.9)) {
val n = 100
val data = sc.parallelize(1 to n, 2)
val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive))
val keys = stratifiedData.keys.distinct().collect()
val splitWeights = Array(0.3, 0.2, 0.5)
val weights: Array[scala.collection.Map[String, Double]] =
splitWeights.map(w => keys.map(k => (k, w)).toMap)
StratifiedAuxiliary.testSplits(stratifiedData, weights, defaultSeed, n, false)
}

// use same data for remaining tests
val n = 100
val fractionPositive = 0.3
val data = sc.parallelize(1 to n, 2)
val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive))
val keys = stratifiedData.keys.distinct().collect()

// use different weights for each key in the split
val unevenWeights: Array[scala.collection.Map[String, Double]] =
Array(Map("0" -> 0.2, "1" -> 0.3), Map("0" -> 0.1, "1" -> 0.4), Map("0" -> 0.7, "1" -> 0.3))
StratifiedAuxiliary.testSplits(stratifiedData, unevenWeights, defaultSeed, n, false)

// vary the seed
val splitWeights = Array(0.3, 0.2, 0.5)
val weights: Array[scala.collection.Map[String, Double]] =
splitWeights.map(w => keys.map(k => (k, w)).toMap)
for (seed <- defaultSeed to defaultSeed + 5L) {
StratifiedAuxiliary.testSplits(stratifiedData, weights, seed, n, false)
}

// vary the number of splits
for (numSplits <- 1 to 5) {
val splitWeights = Array.fill(numSplits)(1.0) // all splits weighted equally; checks normalization too
val weights: Array[scala.collection.Map[String, Double]] =
splitWeights.map(w => keys.map(k => (k, w)).toMap)
StratifiedAuxiliary.testSplits(stratifiedData, weights, defaultSeed, n, false)
}
}

test("reduceByKey") {
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
val sums = pairs.reduceByKey(_ + _).collect()
@@ -646,6 +758,19 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext {
}
}

def checkSplitSize(exact: Boolean,
expected: Long,
actual: Long,
p: Double): Boolean = {
if (exact) {
// even "exact" splits can be off by one, since per-key thresholds are rounded independently
math.abs(expected - actual) <= 1
} else {
val stdev = math.sqrt(expected * p * (1 - p))
// very forgiving margin since we're dealing with very small sample sizes most of the time
math.abs(actual - expected) <= 6 * stdev
}
}
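// Worked example of the non-exact tolerance (illustrative numbers): for a stratum of 100 items
// split at p = 0.3, expected = 30 and stdev = sqrt(30 * 0.3 * 0.7), about 2.5, so any observed
// count within 30 +/- 15 passes.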

def testSampleExact(stratifiedData: RDD[(String, Int)],
samplingRate: Double,
seed: Long,
@@ -662,6 +787,64 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext {
testPoisson(stratifiedData, false, samplingRate, seed, n)
}

def testSplits(stratifiedData: RDD[(String, Int)],
weights: Array[scala.collection.Map[String, Double]],
seed: Long,
n: Int,
exact: Boolean): Unit = {
val baseFold = weights(0).map(x => (x._1, 0.0))
val totalWeightByKey = weights.foldLeft(baseFold) { case (accMap, iterMap) =>
accMap.map { case (k, v) => (k, v + iterMap(k)) }
}
val normedWeights = weights.map(m => m.map { case(k, v) => (k, v / totalWeightByKey(k))})

val splits = stratifiedData.randomSplitByKey(weights, exact, seed)
val stratCounts = stratifiedData.countByKey()


val expectedSampleSizes = normedWeights.map { m =>
stratCounts.map { case (key, count) =>
(key, math.ceil(count * m(key)).toLong)
}.toMap
}
val expectedComplementSizes = normedWeights.map { m =>
stratCounts.map { case (key, count) =>
(key, math.ceil(count * (1 - m(key))).toLong)
}.toMap
}
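// e.g. (hypothetical counts) a stratum with 30 items and a normalized split weight of 0.3
// expects ceil(30 * 0.3) = 9 items in that split and ceil(30 * 0.7) = 21 in its complement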

val samples = splits.map { case (subsample, _) => subsample.collect() }
val complements = splits.map { case (_, complement) => complement.collect() }

// check for the correct sample size for each split by key
(samples.map(_.groupBy(_._1).map(x => (x._1, x._2.length))) zip expectedSampleSizes)
.zipWithIndex.foreach { case ((actual, expected), idx) =>
actual.foreach { case (k, v) =>
assert(checkSplitSize(exact, expected(k), v, normedWeights(idx)(k)))
}
}
(complements.map(_.groupBy(_._1).map(x => (x._1, x._2.length))) zip expectedComplementSizes)
.zipWithIndex.foreach { case ((actual, expected), idx) =>
actual.foreach { case (k, v) =>
assert(checkSplitSize(exact, expected(k), v, normedWeights(idx)(k)))
}
}

// make sure samples ++ complements equals the original set
(samples zip complements).foreach { case (sample, complement) =>
assert((sample ++ complement).sortBy(_._2).toList == stratifiedData.collect().sortBy(_._2).toList)
}

// make sure the elements are members of the original set
samples.foreach(sample => sample.foreach(x => assert(x._2 >= 1 && x._2 <= n)))

// make sure there are no duplicates in each sample
samples.foreach(sample => assert(sample.length == sample.toSet.size))

// make sure that the union of all samples equals the original set
assert(samples.flatten.sortBy(_._2).toList == stratifiedData.collect().sortBy(_._2).toList)
}

// Without replacement validation
def testBernoulli(stratifiedData: RDD[(String, Int)],
exact: Boolean,
@@ -50,10 +50,23 @@ private[ml] trait CrossValidatorParams extends ValidatorParams with HasSeed {
val numFolds: IntParam = new IntParam(this, "numFolds",
"number of folds for cross validation (>= 2)", ParamValidators.gtEq(2))

/**
* Param for the name of the column to stratify by when generating the folds.
* Default: "None" (no stratification)
* @group param
*/
val stratifiedCol: Param[String] = new Param[String](this, "stratifiedCol",
"stratified column name")

/** @group getParam */
def getStratifiedCol: String = $(stratifiedCol)

/** @group getParam */
def getNumFolds: Int = $(numFolds)

setDefault(numFolds -> 3)
setDefault(stratifiedCol -> "None")

}

/**
@@ -91,6 +104,10 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
@Since("2.0.0")
def setSeed(value: Long): this.type = set(seed, value)

/** @group setParam */
@Since("2.0.0")
def setStratifiedCol(value: String): this.type = set(stratifiedCol, value)
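A usage sketch for the new parameter (pipeline, paramGrid, dataset, and the column name "label" are placeholders, not part of this patch):

val cv = new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(new BinaryClassificationEvaluator())
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(5)
  .setStratifiedCol("label") // folds are drawn per label value instead of uniformly at random
val cvModel = cv.fit(dataset)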

@Since("1.4.0")
override def fit(dataset: DataFrame): CrossValidatorModel = {
val schema = dataset.schema
@@ -101,7 +118,19 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
val epm = $(estimatorParamMaps)
val numModels = epm.length
val metrics = new Array[Double](epm.length)
val splits = MLUtils.kFold(dataset.rdd, $(numFolds), $(seed))

val splits = if (isSet(stratifiedCol) && schema.fieldNames.contains($(stratifiedCol))) {
val stratifiedColIndex = schema.fieldNames.indexOf($(stratifiedCol))
val pairData = dataset.rdd.map(row => (row(stratifiedColIndex), row))
val splitsWithKeys = MLUtils.kFoldStratified(pairData, $(numFolds), $(seed))
splitsWithKeys.map { case (training, validation) => (training.values, validation.values) }
} else {
if (isSet(stratifiedCol)) {
logWarning(s"Stratified column '${$(stratifiedCol)}' does not exist in the dataset schema. " +
"Performing regular k-fold splitting.")
}
// regular kFold
MLUtils.kFold(dataset.rdd, $(numFolds), $(seed))
}

splits.zipWithIndex.foreach { case ((training, validation), splitIndex) =>
val trainingDataset = sqlCtx.createDataFrame(training, schema).cache()
val validationDataset = sqlCtx.createDataFrame(validation, schema).cache()