@@ -36,20 +36,30 @@ import org.apache.spark.sql.types.{StructField, StructType}
private[feature] trait StandardScalerParams extends Params with HasInputCol with HasOutputCol {

/**
* Centers the data with mean before scaling.
* Whether to center the data with mean before scaling.
* It will build a dense output, so this does not work on sparse input
* and will raise an exception.
* Default: false
* @group param
*/
val withMean: BooleanParam = new BooleanParam(this, "withMean", "Center data with mean")
val withMean: BooleanParam = new BooleanParam(this, "withMean",
"Whether to center data with mean")

/** @group getParam */
def getWithMean: Boolean = $(withMean)

/**
* Scales the data to unit standard deviation.
* Whether to scale the data to unit standard deviation.
* Default: true
* @group param
*/
val withStd: BooleanParam = new BooleanParam(this, "withStd", "Scale to unit standard deviation")
val withStd: BooleanParam = new BooleanParam(this, "withStd",
"Whether to scale the data to unit standard deviation")

/** @group getParam */
def getWithStd: Boolean = $(withStd)

setDefault(withMean -> false, withStd -> true)
}

/**
@@ -63,8 +73,6 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM

def this() = this(Identifiable.randomUID("stdScal"))

setDefault(withMean -> false, withStd -> true)

/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)

@@ -82,7 +90,7 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM
val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v }
val scaler = new feature.StandardScaler(withMean = $(withMean), withStd = $(withStd))
val scalerModel = scaler.fit(input)
copyValues(new StandardScalerModel(uid, scalerModel).setParent(this))
copyValues(new StandardScalerModel(uid, scalerModel.std, scalerModel.mean).setParent(this))
}

override def transformSchema(schema: StructType): StructType = {
@@ -108,29 +116,19 @@ object StandardScaler extends DefaultParamsReadable[StandardScaler] {
/**
* :: Experimental ::
* Model fitted by [[StandardScaler]].
*
* @param std Standard deviation of the StandardScalerModel
* @param mean Mean of the StandardScalerModel
*/
@Experimental
class StandardScalerModel private[ml] (
override val uid: String,
scaler: feature.StandardScalerModel)
val std: Vector,
val mean: Vector)
extends Model[StandardScalerModel] with StandardScalerParams with MLWritable {

import StandardScalerModel._

/** Standard deviation of the StandardScalerModel */
val std: Vector = scaler.std

/** Mean of the StandardScalerModel */
val mean: Vector = scaler.mean

/** Whether to scale to unit standard deviation. */
@Since("1.6.0")
def getWithStd: Boolean = scaler.withStd

/** Whether to center data with mean. */
@Since("1.6.0")
def getWithMean: Boolean = scaler.withMean

/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)

@@ -139,6 +137,7 @@ class StandardScalerModel private[ml] (

override def transform(dataset: DataFrame): DataFrame = {
transformSchema(dataset.schema, logging = true)
val scaler = new feature.StandardScalerModel(std, mean, $(withStd), $(withMean))
val scale = udf { scaler.transform _ }
dataset.withColumn($(outputCol), scale(col($(inputCol))))
}
@@ -154,7 +153,7 @@ class StandardScalerModel private[ml] (
}

override def copy(extra: ParamMap): StandardScalerModel = {
val copied = new StandardScalerModel(uid, scaler)
val copied = new StandardScalerModel(uid, std, mean)
copyValues(copied, extra).setParent(parent)
}

@@ -168,11 +167,11 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] {
private[StandardScalerModel]
class StandardScalerModelWriter(instance: StandardScalerModel) extends MLWriter {

private case class Data(std: Vector, mean: Vector, withStd: Boolean, withMean: Boolean)
private case class Data(std: Vector, mean: Vector)

override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sc)
val data = Data(instance.std, instance.mean, instance.getWithStd, instance.getWithMean)
val data = Data(instance.std, instance.mean)
val dataPath = new Path(path, "data").toString
sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
@@ -185,13 +184,10 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] {
override def load(path: String): StandardScalerModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
val Row(std: Vector, mean: Vector, withStd: Boolean, withMean: Boolean) =
sqlContext.read.parquet(dataPath)
.select("std", "mean", "withStd", "withMean")
.head()
// This is very likely to change in the future because withStd and withMean should be params.
val oldModel = new feature.StandardScalerModel(std, mean, withStd, withMean)
val model = new StandardScalerModel(metadata.uid, oldModel)
val Row(std: Vector, mean: Vector) = sqlContext.read.parquet(dataPath)
.select("std", "mean")
.head()
val model = new StandardScalerModel(metadata.uid, std, mean)
DefaultParamsReader.getAndSetParams(model, metadata)
model
Contributor:

If withMean and withStd are parameters, we should save them in metadata/ only, not under both data/ and metadata/. Can we change the constructor of ml.StandardScalerModel to take only std and mean, and construct the scaler only inside transform? Then scaler is no longer a member variable. We can fix performance issues in 1.7.

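For reference, a usage-level sketch of the shape this comment suggests (and which the diff adopts): std and mean become plain constructor members on the ml model, while withStd/withMean are Params that round-trip through metadata/ rather than data/. The DataFrame contents, save path, and sqlContext setup below are illustrative assumptions, not part of the patch.

```scala
import org.apache.spark.ml.feature.{StandardScaler, StandardScalerModel}
import org.apache.spark.mllib.linalg.Vectors

// Assumes a 1.6-era sqlContext (e.g. in spark-shell); the data is made up.
val df = sqlContext.createDataFrame(Seq(
  (0, Vectors.dense(1.0, 10.0)),
  (1, Vectors.dense(3.0, 30.0))
)).toDF("id", "features")

val scaler = new StandardScaler()
  .setInputCol("features")
  .setOutputCol("scaled")
  .setWithStd(true)
  .setWithMean(false)

val model = scaler.fit(df)

// With this change, std and mean are exposed directly on the ml model ...
println(model.std)
println(model.mean)

// ... and withStd / withMean are persisted under metadata/, not data/.
model.write.overwrite().save("/tmp/std-scaler-model")
val loaded = StandardScalerModel.load("/tmp/std-scaler-model")
assert(loaded.std == model.std && loaded.mean == model.mean)
assert(loaded.getWithStd == model.getWithStd && loaded.getWithMean == model.getWithMean)
```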
}
@@ -70,8 +70,8 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext

test("params") {
ParamsSuite.checkParams(new StandardScaler)
val oldModel = new feature.StandardScalerModel(Vectors.dense(1.0), Vectors.dense(2.0))
ParamsSuite.checkParams(new StandardScalerModel("empty", oldModel))
ParamsSuite.checkParams(new StandardScalerModel("empty",
Vectors.dense(1.0), Vectors.dense(2.0)))
}

test("Standardization with default parameter") {
@@ -126,13 +126,10 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext
}

test("StandardScalerModel read/write") {
val oldModel = new feature.StandardScalerModel(
Vectors.dense(1.0, 2.0), Vectors.dense(3.0, 4.0), false, true)
val instance = new StandardScalerModel("myStandardScalerModel", oldModel)
val instance = new StandardScalerModel("myStandardScalerModel",
Vectors.dense(1.0, 2.0), Vectors.dense(3.0, 4.0))
val newInstance = testDefaultReadWrite(instance)
assert(newInstance.std === instance.std)
assert(newInstance.mean === instance.mean)
assert(newInstance.getWithStd === instance.getWithStd)
assert(newInstance.getWithMean === instance.getWithMean)
}
Contributor Author:

withStd and withMean of StandardScalerModel must be inherited from StandardScaler, so we cannot construct a StandardScalerModel directly by specifying those two variables. Here we combine the original test cases into one with testEstimatorAndModelReadWrite, which tests both the estimator and the model.
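A rough sketch of the combined test this comment describes, as it might appear inside StandardScalerSuite. It assumes the 1.6-era DefaultReadWriteTest helper testEstimatorAndModelReadWrite(estimator, dataset, testParams, checkModelData); the dataset and param values are illustrative.

```scala
// Sketch only: fits the estimator, round-trips both estimator and model,
// and checks the model's std/mean after reload.
test("StandardScaler read/write") {
  val data = sqlContext.createDataFrame(Seq(
    (Vectors.dense(1.0, 10.0), 0.0),
    (Vectors.dense(3.0, 30.0), 1.0)
  )).toDF("features", "label")

  val scaler = new StandardScaler()
  val allParams: Map[String, Any] = Map(
    "inputCol" -> "features",
    "outputCol" -> "scaled",
    "withMean" -> false,
    "withStd" -> true)

  testEstimatorAndModelReadWrite(scaler, data, allParams,
    (model: StandardScalerModel, loaded: StandardScalerModel) => {
      assert(model.std === loaded.std)
      assert(model.mean === loaded.mean)
    })
}
```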

Contributor:

I think this is not an ideal unit test for read/write, because the model-fitting part shouldn't be in it; that is already covered by other tests. Constructing the estimator and model directly can save some test time.
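By contrast, a minimal sketch of the direct-construction approach the reviewer prefers: no fitting, just testDefaultReadWrite on the estimator and on a hand-built model (the model half matches the test in this diff).

```scala
// Estimator read/write without any fitting; param round-trip is
// checked inside testDefaultReadWrite.
test("StandardScaler read/write") {
  val scaler = new StandardScaler().setInputCol("features").setOutputCol("scaled")
  testDefaultReadWrite(scaler)
}

// Model read/write from hand-built std/mean vectors, as in the diff above.
test("StandardScalerModel read/write") {
  val instance = new StandardScalerModel("myStandardScalerModel",
    Vectors.dense(1.0, 2.0), Vectors.dense(3.0, 4.0))
  val newInstance = testDefaultReadWrite(instance)
  assert(newInstance.std === instance.std)
  assert(newInstance.mean === instance.mean)
}
```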

}