From d28ff3da143721d019f42f80b94a4234f14cdc76 Mon Sep 17 00:00:00 2001 From: Tomer Kaftan Date: Thu, 7 May 2015 19:13:24 -0700 Subject: [PATCH] Made std dev optional in standard scaler. --- .../scala/nodes/misc/StandardScaler.scala | 31 ++++++++++--------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/src/main/scala/nodes/misc/StandardScaler.scala b/src/main/scala/nodes/misc/StandardScaler.scala index f9f34ec1..bd98c027 100644 --- a/src/main/scala/nodes/misc/StandardScaler.scala +++ b/src/main/scala/nodes/misc/StandardScaler.scala @@ -13,7 +13,7 @@ import utils.MLlibUtils * @param mean column mean values * @param std column standard deviation values */ -class StandardScalerModel(val mean: DenseVector[Double], val std: DenseVector[Double]) +class StandardScalerModel(val mean: DenseVector[Double], val std: Option[DenseVector[Double]] = None) extends Transformer[DenseVector[Double], DenseVector[Double]] { /** * Applies standardization transformation on a vector. @@ -23,14 +23,11 @@ class StandardScalerModel(val mean: DenseVector[Double], val std: DenseVector[Do * for the column with zero std. */ override def apply(in: DenseVector[Double]): DenseVector[Double] = { - val values = in.copy - val size = values.length - var i = 0 - while (i < size) { - values(i) = if (std(i) != 0.0) (values(i) - mean(i)) * (1.0 / std(i)) else 0.0 - i += 1 - } - values + val out = in - mean + std.foreach(x => { + out :/= x + }) + out } } @@ -38,7 +35,7 @@ class StandardScalerModel(val mean: DenseVector[Double], val std: DenseVector[Do * Standardizes features by removing the mean and scaling to unit std using column summary * statistics on the samples in the training set. */ -class StandardScaler(eps: Double = 1E-12) extends Estimator[DenseVector[Double], DenseVector[Double]]{ +class StandardScaler(normalizeStdDev: Boolean = true, eps: Double = 1E-12) extends Estimator[DenseVector[Double], DenseVector[Double]]{ /** * Computes the mean and variance and stores as a model to be used for later scaling. * @@ -49,9 +46,15 @@ class StandardScaler(eps: Double = 1E-12) extends Estimator[DenseVector[Double], val summary = data.treeAggregate(new MultivariateOnlineSummarizer)( (aggregator, data) => aggregator.add(MLlibUtils.breezeVectorToMLlib(data)), (aggregator1, aggregator2) => aggregator1.merge(aggregator2)) - new StandardScalerModel( - MLlibUtils.mllibVectorToDenseBreeze(summary.mean), - sqrt(MLlibUtils.mllibVectorToDenseBreeze(summary.variance)) - .map(r => if(r.isNaN | r.isInfinite | math.abs(r) < eps) 1.0 else r)) + if (normalizeStdDev) { + new StandardScalerModel( + MLlibUtils.mllibVectorToDenseBreeze(summary.mean), + Some(sqrt(MLlibUtils.mllibVectorToDenseBreeze(summary.variance)) + .map(r => if (r.isNaN | r.isInfinite | math.abs(r) < eps) 1.0 else r))) + } else { + new StandardScalerModel( + MLlibUtils.mllibVectorToDenseBreeze(summary.mean), + None) + } } }