From 022c8367a28ade4529d522e4fffe0896e75336da Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Sat, 6 Feb 2016 16:11:26 +0800 Subject: [PATCH 1/5] add stratifiedSampling and ut --- .../spark/ml/feature/StratifiedSampling.scala | 129 ++++++++++++++++++ .../ml/feature/StratifiedSamplingSuite.scala | 54 ++++++++ 2 files changed, 183 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/StratifiedSampling.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/StratifiedSamplingSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StratifiedSampling.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StratifiedSampling.scala new file mode 100644 index 0000000000000..dfbd4a6fd8762 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StratifiedSampling.scala @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.feature.StratifiedSampling.StratifiedSamplingWriter +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util._ +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.util.Utils + +/** + * :: Experimental :: + * + * Stratified sampling on the DataFrame according to the keys in a specific label column. User + * can set 'fraction' to set different sampling rate for each key. + * + * @param withReplacement can elements be sampled multiple times (replaced when sampled out) + * @param fraction expected size of the sample as a fraction of the items + * without replacement: probability that each element is chosen; fraction must be [0, 1] + * with replacement: expected number of times each element is chosen; fraction must be >= 0 + */ +@Experimental +final class StratifiedSampling private( + override val uid: String, + val withReplacement: Boolean, + val fraction: Map[String, Double]) + extends Transformer with HasLabelCol with HasSeed with DefaultParamsWritable { + + @Since("2.0.0") + def this(withReplacement: Boolean, fraction: Map[String, Double]) = + this(Identifiable.randomUID("stratifiedSampling"), withReplacement, fraction) + + /** @group setParam */ + @Since("2.0.0") + def setSeed(value: Long): this.type = set(seed, value) + + /** @group setParam */ + @Since("2.0.0") + def setLabel(value: String): this.type = set(labelCol, value) + + setDefault(seed -> Utils.random.nextLong) + + @Since("2.0.0") + override def transform(data: DataFrame): DataFrame = { + transformSchema(data.schema) + val schema = data.schema + val colId = schema.fieldIndex($(labelCol)) + val result = data.rdd.map(r => (r.get(colId), r)) + .sampleByKey(withReplacement, fraction.toMap, $(seed)) + .map(_._2) + data.sqlContext.createDataFrame(result, schema) + } + + @Since("2.0.0") + override def transformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnType(schema, $(labelCol), StringType) + schema + } + + @Since("2.0.0") + override def write: MLWriter = new StratifiedSamplingWriter(this) + + @Since("2.0.0") + override def copy(extra: ParamMap): StratifiedSampling = { + val copied = new StratifiedSampling(uid, withReplacement, fraction) + copyValues(copied, extra) + } +} + +@Since("2.0.0") +object StratifiedSampling extends DefaultParamsReadable[StratifiedSampling] { + + private case class Data(withReplacement: Boolean, fraction: Map[String, Double]) + + private[StratifiedSampling] + class StratifiedSamplingWriter(instance: StratifiedSampling) extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = new Data(instance.withReplacement, instance.fraction) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class StratifiedSamplingReader extends MLReader[StratifiedSampling] { + + private val className = classOf[StratifiedSampling].getName + + override def load(path: String): StratifiedSampling = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val Row(withReplacement: Boolean, fraction: Map[String, Double]) = sqlContext.read + .parquet(dataPath) + .select("withReplacement", "fraction") + .head() + val model = new StratifiedSampling(metadata.uid, withReplacement, fraction) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("2.0.0") + override def read: MLReader[StratifiedSampling] = new StratifiedSamplingReader + + @Since("2.0.0") + override def load(path: String): StratifiedSampling = super.load(path) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StratifiedSamplingSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StratifiedSamplingSuite.scala new file mode 100644 index 0000000000000..a8f2be58dbf75 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StratifiedSamplingSuite.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.log4j.{Level, Logger} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class StratifiedSamplingSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + test("params") { + ParamsSuite.checkParams(new Binarizer) + } + + test("StratifiedSampling on String label") { + Logger.getRootLogger.setLevel(Level.WARN) + val df = sqlContext.createDataFrame(Seq( + (0, "0"), + (0, "0"), + (1, "1"), + (1, "1") + )).toDF("label", "str") + val map = Map("1" -> 0.5, "2" -> 0.1) + val trans = new StratifiedSampling(false, map).setLabel("str") + trans.transform(df).schema == df.schema + } + + test("StratifiedSampling read/write") { + val t = new StratifiedSampling(false, Map("1" -> 0.5, "2" -> 0.1)) + .setLabel("myLabel") + val newInstance = testDefaultReadWrite(t) + assert(t.withReplacement == newInstance.withReplacement && + t.fraction == newInstance.fraction) + } +} From 78c80d78afa94c6522f81339ffb6a6bef315eee0 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Sat, 13 Feb 2016 01:14:50 -0800 Subject: [PATCH 2/5] rename to sampler and use sampleby --- ...Sampling.scala => StratifiedSampler.scala} | 56 +++++++++++-------- .../ml/feature/StratifiedSamplingSuite.scala | 4 +- 2 files changed, 34 insertions(+), 26 deletions(-) rename mllib/src/main/scala/org/apache/spark/ml/feature/{StratifiedSampling.scala => StratifiedSampler.scala} (71%) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StratifiedSampling.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StratifiedSampler.scala similarity index 71% rename from mllib/src/main/scala/org/apache/spark/ml/feature/StratifiedSampling.scala rename to mllib/src/main/scala/org/apache/spark/ml/feature/StratifiedSampler.scala index dfbd4a6fd8762..e73cace008cc9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StratifiedSampling.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StratifiedSampler.scala @@ -14,26 +14,25 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.spark.ml.feature +import scala.collection.JavaConverters._ + import org.apache.hadoop.fs.Path import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.Transformer -import org.apache.spark.ml.feature.StratifiedSampling.StratifiedSamplingWriter import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.types.{StringType, StructType} -import org.apache.spark.util.Utils /** * :: Experimental :: * * Stratified sampling on the DataFrame according to the keys in a specific label column. User - * can set 'fraction' to set different sampling rate for each key. + * can set 'fraction' to assign different sampling rate for each key. * * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of the items @@ -41,16 +40,22 @@ import org.apache.spark.util.Utils * with replacement: expected number of times each element is chosen; fraction must be >= 0 */ @Experimental -final class StratifiedSampling private( +final class StratifiedSampler private ( override val uid: String, val withReplacement: Boolean, val fraction: Map[String, Double]) extends Transformer with HasLabelCol with HasSeed with DefaultParamsWritable { + import StratifiedSampler._ + @Since("2.0.0") def this(withReplacement: Boolean, fraction: Map[String, Double]) = this(Identifiable.randomUID("stratifiedSampling"), withReplacement, fraction) + @Since("2.0.0") + def this(withReplacement: Boolean, fraction: java.util.Map[String, Double]) = + this(Identifiable.randomUID("stratifiedSampling"), withReplacement, fraction.asScala.toMap) + /** @group setParam */ @Since("2.0.0") def setSeed(value: Long): this.type = set(seed, value) @@ -59,17 +64,20 @@ final class StratifiedSampling private( @Since("2.0.0") def setLabel(value: String): this.type = set(labelCol, value) - setDefault(seed -> Utils.random.nextLong) - @Since("2.0.0") override def transform(data: DataFrame): DataFrame = { - transformSchema(data.schema) + transformSchema(data.schema, logging = true) val schema = data.schema - val colId = schema.fieldIndex($(labelCol)) - val result = data.rdd.map(r => (r.get(colId), r)) - .sampleByKey(withReplacement, fraction.toMap, $(seed)) - .map(_._2) - data.sqlContext.createDataFrame(result, schema) + if(withReplacement){ + val colId = schema.fieldIndex($(labelCol)) + val result = data.rdd.map(r => (r.get(colId), r)) + .sampleByKey(withReplacement, fraction.toMap, $(seed)) + .map(_._2) + data.sqlContext.createDataFrame(result, schema) + } + else { + data.stat.sampleBy($(labelCol), fraction, $(seed)) + } } @Since("2.0.0") @@ -82,19 +90,19 @@ final class StratifiedSampling private( override def write: MLWriter = new StratifiedSamplingWriter(this) @Since("2.0.0") - override def copy(extra: ParamMap): StratifiedSampling = { - val copied = new StratifiedSampling(uid, withReplacement, fraction) + override def copy(extra: ParamMap): StratifiedSampler = { + val copied = new StratifiedSampler(uid, withReplacement, fraction) copyValues(copied, extra) } } @Since("2.0.0") -object StratifiedSampling extends DefaultParamsReadable[StratifiedSampling] { +object StratifiedSampler extends DefaultParamsReadable[StratifiedSampler] { private case class Data(withReplacement: Boolean, fraction: Map[String, Double]) - private[StratifiedSampling] - class StratifiedSamplingWriter(instance: StratifiedSampling) extends MLWriter { + private[StratifiedSampler] + class StratifiedSamplingWriter(instance: StratifiedSampler) extends MLWriter { override protected def saveImpl(path: String): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sc) @@ -104,26 +112,26 @@ object StratifiedSampling extends DefaultParamsReadable[StratifiedSampling] { } } - private class StratifiedSamplingReader extends MLReader[StratifiedSampling] { + private class StratifiedSamplingReader extends MLReader[StratifiedSampler] { - private val className = classOf[StratifiedSampling].getName + private val className = classOf[StratifiedSampler].getName - override def load(path: String): StratifiedSampling = { + override def load(path: String): StratifiedSampler = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString val Row(withReplacement: Boolean, fraction: Map[String, Double]) = sqlContext.read .parquet(dataPath) .select("withReplacement", "fraction") .head() - val model = new StratifiedSampling(metadata.uid, withReplacement, fraction) + val model = new StratifiedSampler(metadata.uid, withReplacement, fraction) DefaultParamsReader.getAndSetParams(model, metadata) model } } @Since("2.0.0") - override def read: MLReader[StratifiedSampling] = new StratifiedSamplingReader + override def read: MLReader[StratifiedSampler] = new StratifiedSamplingReader @Since("2.0.0") - override def load(path: String): StratifiedSampling = super.load(path) + override def load(path: String): StratifiedSampler = super.load(path) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StratifiedSamplingSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StratifiedSamplingSuite.scala index a8f2be58dbf75..fdfa2c7887db0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StratifiedSamplingSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StratifiedSamplingSuite.scala @@ -40,12 +40,12 @@ class StratifiedSamplingSuite (1, "1") )).toDF("label", "str") val map = Map("1" -> 0.5, "2" -> 0.1) - val trans = new StratifiedSampling(false, map).setLabel("str") + val trans = new StratifiedSampler(false, map).setLabel("str") trans.transform(df).schema == df.schema } test("StratifiedSampling read/write") { - val t = new StratifiedSampling(false, Map("1" -> 0.5, "2" -> 0.1)) + val t = new StratifiedSampler(false, Map("1" -> 0.5, "2" -> 0.1)) .setLabel("myLabel") val newInstance = testDefaultReadWrite(t) assert(t.withReplacement == newInstance.withReplacement && From 853fb967ae845aabb6f70b58b339d2e2152c1906 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Sun, 14 Feb 2016 00:44:49 -0800 Subject: [PATCH 3/5] fix type warning --- .../scala/org/apache/spark/ml/feature/StratifiedSampler.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StratifiedSampler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StratifiedSampler.scala index e73cace008cc9..511696a57e75b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StratifiedSampler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StratifiedSampler.scala @@ -119,10 +119,12 @@ object StratifiedSampler extends DefaultParamsReadable[StratifiedSampler] { override def load(path: String): StratifiedSampler = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val Row(withReplacement: Boolean, fraction: Map[String, Double]) = sqlContext.read + val data = sqlContext.read .parquet(dataPath) .select("withReplacement", "fraction") .head() + val withReplacement = data.getBoolean(0) + val fraction = data.getAs[Map[String, Double]](1) val model = new StratifiedSampler(metadata.uid, withReplacement, fraction) DefaultParamsReader.getAndSetParams(model, metadata) model From 525a80d116ee9bf106118bfbfb1bcd8c67a5d62e Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Sun, 14 Feb 2016 11:49:08 -0800 Subject: [PATCH 4/5] add see and rename suite --- .../scala/org/apache/spark/ml/feature/StratifiedSampler.scala | 4 +++- ...tifiedSamplingSuite.scala => StratifiedSamplerSuite.scala} | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) rename mllib/src/test/scala/org/apache/spark/ml/feature/{StratifiedSamplingSuite.scala => StratifiedSamplerSuite.scala} (98%) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StratifiedSampler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StratifiedSampler.scala index 511696a57e75b..b4abc377b49be 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StratifiedSampler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StratifiedSampler.scala @@ -25,7 +25,7 @@ import org.apache.spark.ml.Transformer import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, DataFrameStatFunctions} import org.apache.spark.sql.types.{StringType, StructType} /** @@ -34,6 +34,8 @@ import org.apache.spark.sql.types.{StringType, StructType} * Stratified sampling on the DataFrame according to the keys in a specific label column. User * can set 'fraction' to assign different sampling rate for each key. * + * @see [[DataFrameStatFunctions#sampleBy(java.lang.String, java.util.Map, long)]] + * * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of the items * without replacement: probability that each element is chosen; fraction must be [0, 1] diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StratifiedSamplingSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StratifiedSamplerSuite.scala similarity index 98% rename from mllib/src/test/scala/org/apache/spark/ml/feature/StratifiedSamplingSuite.scala rename to mllib/src/test/scala/org/apache/spark/ml/feature/StratifiedSamplerSuite.scala index fdfa2c7887db0..870c771da7621 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StratifiedSamplingSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StratifiedSamplerSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext -class StratifiedSamplingSuite +class StratifiedSamplerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { From 8f8e74797c8699881741be2d492560e4c49ebb7f Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Wed, 13 Apr 2016 12:53:11 -0400 Subject: [PATCH 5/5] support any --- .../spark/ml/feature/StratifiedSampler.scala | 134 ++++++++++-------- .../ml/feature/StratifiedSamplerSuite.scala | 28 ++-- 2 files changed, 91 insertions(+), 71 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StratifiedSampler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StratifiedSampler.scala index b4abc377b49be..d7c7fa8255898 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StratifiedSampler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StratifiedSampler.scala @@ -1,62 +1,58 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.ml.feature +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ -import scala.collection.JavaConverters._ +package org.apache.spark.ml.feature import org.apache.hadoop.fs.Path +import scala.collection.JavaConverters._ +import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.Transformer -import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.{BooleanParam, ParamMap} import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ -import org.apache.spark.sql.{DataFrame, DataFrameStatFunctions} -import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.sql.{DataFrame, DataFrameStatFunctions, Dataset} +import org.apache.spark.sql.types.StructType /** * :: Experimental :: * * Stratified sampling on the DataFrame according to the keys in a specific label column. User - * can set 'fraction' to assign different sampling rate for each key. - * - * @see [[DataFrameStatFunctions#sampleBy(java.lang.String, java.util.Map, long)]] + * can set 'fraction' to set different sampling rate for each key. * - * @param withReplacement can elements be sampled multiple times (replaced when sampled out) - * @param fraction expected size of the sample as a fraction of the items - * without replacement: probability that each element is chosen; fraction must be [0, 1] - * with replacement: expected number of times each element is chosen; fraction must be >= 0 + * @param fractions sampling fraction for each stratum. @see [[DataFrameStatFunctions.sampleBy]]. + * Supported stratum types are Int, String and Boolean */ @Experimental final class StratifiedSampler private ( override val uid: String, - val withReplacement: Boolean, - val fraction: Map[String, Double]) + val fractions: Map[_, Double]) extends Transformer with HasLabelCol with HasSeed with DefaultParamsWritable { import StratifiedSampler._ @Since("2.0.0") - def this(withReplacement: Boolean, fraction: Map[String, Double]) = - this(Identifiable.randomUID("stratifiedSampling"), withReplacement, fraction) + def this(fraction: Map[_, Double]) = + this(Identifiable.randomUID("stratifiedSampling"), fraction) @Since("2.0.0") - def this(withReplacement: Boolean, fraction: java.util.Map[String, Double]) = - this(Identifiable.randomUID("stratifiedSampling"), withReplacement, fraction.asScala.toMap) + def this(fraction: java.util.Map[_, Double]) = + this(fraction.asScala.toMap) /** @group setParam */ @Since("2.0.0") @@ -66,25 +62,42 @@ final class StratifiedSampler private ( @Since("2.0.0") def setLabel(value: String): this.type = set(labelCol, value) + /** + * If true, sampling will be skipped and all the records will be returned. + * Used in prediction pipeline + * Default: false + * @group param + */ + val skip: BooleanParam = new BooleanParam(this, "skip", + "If true, sampling will be skipped and all the records will be returned. " + + "Used in prediction pipeline") + + /** @group getParam */ + def getSkip: Boolean = $(skip) + + /** @group setParam */ + def setSkip(value: Boolean): this.type = set(skip, value) + + setDefault(skip -> false) + @Since("2.0.0") - override def transform(data: DataFrame): DataFrame = { + override def transform(data: Dataset[_]): DataFrame = { transformSchema(data.schema, logging = true) - val schema = data.schema - if(withReplacement){ - val colId = schema.fieldIndex($(labelCol)) - val result = data.rdd.map(r => (r.get(colId), r)) - .sampleByKey(withReplacement, fraction.toMap, $(seed)) - .map(_._2) - data.sqlContext.createDataFrame(result, schema) - } - else { - data.stat.sampleBy($(labelCol), fraction, $(seed)) + if (!$(skip)) { + data.stat.sampleBy($(labelCol), fractions, $(seed)) + } else { + data.toDF() } } @Since("2.0.0") override def transformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(labelCol), StringType) + require(fractions.nonEmpty, "fraction should not be empty") + require(fractions.keySet.forall(_.isInstanceOf[String]) + || fractions.keySet.forall(_.isInstanceOf[Int]) + || fractions.keySet.forall(_.isInstanceOf[Boolean]), + s"only support stratum of type String, Int and Boolean") + require(fractions.values.forall(v => v >= 0 && v <= 1), "sampling rate should be in [0, 1]") schema } @@ -93,7 +106,7 @@ final class StratifiedSampler private ( @Since("2.0.0") override def copy(extra: ParamMap): StratifiedSampler = { - val copied = new StratifiedSampler(uid, withReplacement, fraction) + val copied = new StratifiedSampler(uid, fractions) copyValues(copied, extra) } } @@ -101,33 +114,34 @@ final class StratifiedSampler private ( @Since("2.0.0") object StratifiedSampler extends DefaultParamsReadable[StratifiedSampler] { - private case class Data(withReplacement: Boolean, fraction: Map[String, Double]) - private[StratifiedSampler] class StratifiedSamplingWriter(instance: StratifiedSampler) extends MLWriter { override protected def saveImpl(path: String): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sc) - val data = new Data(instance.withReplacement, instance.fraction) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + val map = instance.fractions + val df = map.keys.head match { + case s: String => + sqlContext.createDataFrame(map.asInstanceOf[Map[String, Double]].toSeq) + case i: Int => + sqlContext.createDataFrame(map.asInstanceOf[Map[Int, Double]].toSeq) + case b: Boolean => + sqlContext.createDataFrame(map.asInstanceOf[Map[Boolean, Double]].toSeq) + case _ => throw new SparkException("wrong type") + } + df.toDF("key", "value").repartition(1).write.parquet(dataPath) } } private class StratifiedSamplingReader extends MLReader[StratifiedSampler] { - private val className = classOf[StratifiedSampler].getName - override def load(path: String): StratifiedSampler = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val data = sqlContext.read - .parquet(dataPath) - .select("withReplacement", "fraction") - .head() - val withReplacement = data.getBoolean(0) - val fraction = data.getAs[Map[String, Double]](1) - val model = new StratifiedSampler(metadata.uid, withReplacement, fraction) + val fraction = sqlContext.read.parquet(dataPath).select("key", "value") + .rdd.map(r => (r.get(0), r.getDouble(1))).collectAsMap().toMap + val model = new StratifiedSampler(metadata.uid, fraction.asInstanceOf[Map[_, Double]]) DefaultParamsReader.getAndSetParams(model, metadata) model } @@ -139,3 +153,5 @@ object StratifiedSampler extends DefaultParamsReadable[StratifiedSampler] { @Since("2.0.0") override def load(path: String): StratifiedSampler = super.load(path) } + + diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StratifiedSamplerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StratifiedSamplerSuite.scala index 870c771da7621..8b09824f0e8b9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StratifiedSamplerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StratifiedSamplerSuite.scala @@ -31,24 +31,28 @@ class StratifiedSamplerSuite ParamsSuite.checkParams(new Binarizer) } - test("StratifiedSampling on String label") { + test("StratifiedSampling on String, Int and Boolean label") { Logger.getRootLogger.setLevel(Level.WARN) val df = sqlContext.createDataFrame(Seq( - (0, "0"), - (0, "0"), - (1, "1"), - (1, "1") - )).toDF("label", "str") - val map = Map("1" -> 0.5, "2" -> 0.1) - val trans = new StratifiedSampler(false, map).setLabel("str") - trans.transform(df).schema == df.schema + (0, "0", false), + (0, "0", false), + (1, "1", true), + (1, "1", true) + )).toDF("int", "str", "bool") + val strMap = Map("1" -> 0.5, "0" -> 0.1) + assert(new StratifiedSampler(strMap).setLabel("str").transform(df).schema == df.schema) + + val intMap = Map(1 -> 0.5, 0 -> 0.1) + assert(new StratifiedSampler(intMap).setLabel("str").transform(df).schema == df.schema) + + val boolMap = Map(true -> 0.5, false -> 0.1) + assert(new StratifiedSampler(boolMap).setLabel("str").transform(df).schema == df.schema) } test("StratifiedSampling read/write") { - val t = new StratifiedSampler(false, Map("1" -> 0.5, "2" -> 0.1)) + val t = new StratifiedSampler(Map("1" -> 0.5, "2" -> 0.1)) .setLabel("myLabel") val newInstance = testDefaultReadWrite(t) - assert(t.withReplacement == newInstance.withReplacement && - t.fraction == newInstance.fraction) + assert(t.fractions == newInstance.fractions) } }