Skip to content

Commit

Permalink
Made std dev optional in standard scaler.
Browse files Browse the repository at this point in the history
  • Loading branch information
tomerk committed May 8, 2015
1 parent 5ffb84c commit d28ff3d
Showing 1 changed file with 17 additions and 14 deletions.
31 changes: 17 additions & 14 deletions src/main/scala/nodes/misc/StandardScaler.scala
Expand Up @@ -13,7 +13,7 @@ import utils.MLlibUtils
* @param mean column mean values
* @param std column standard deviation values
*/
class StandardScalerModel(val mean: DenseVector[Double], val std: DenseVector[Double])
class StandardScalerModel(val mean: DenseVector[Double], val std: Option[DenseVector[Double]] = None)
extends Transformer[DenseVector[Double], DenseVector[Double]] {
/**
* Applies standardization transformation on a vector.
Expand All @@ -23,22 +23,19 @@ class StandardScalerModel(val mean: DenseVector[Double], val std: DenseVector[Do
* for the column with zero std.
*/
override def apply(in: DenseVector[Double]): DenseVector[Double] = {
val values = in.copy
val size = values.length
var i = 0
while (i < size) {
values(i) = if (std(i) != 0.0) (values(i) - mean(i)) * (1.0 / std(i)) else 0.0
i += 1
}
values
val out = in - mean
std.foreach(x => {
out :/= x
})
out
}
}

/**
* Standardizes features by removing the mean and scaling to unit std using column summary
* statistics on the samples in the training set.
*/
class StandardScaler(eps: Double = 1E-12) extends Estimator[DenseVector[Double], DenseVector[Double]]{
class StandardScaler(normalizeStdDev: Boolean = true, eps: Double = 1E-12) extends Estimator[DenseVector[Double], DenseVector[Double]]{
/**
* Computes the mean and variance and stores as a model to be used for later scaling.
*
Expand All @@ -49,9 +46,15 @@ class StandardScaler(eps: Double = 1E-12) extends Estimator[DenseVector[Double],
val summary = data.treeAggregate(new MultivariateOnlineSummarizer)(
(aggregator, data) => aggregator.add(MLlibUtils.breezeVectorToMLlib(data)),
(aggregator1, aggregator2) => aggregator1.merge(aggregator2))
new StandardScalerModel(
MLlibUtils.mllibVectorToDenseBreeze(summary.mean),
sqrt(MLlibUtils.mllibVectorToDenseBreeze(summary.variance))
.map(r => if(r.isNaN | r.isInfinite | math.abs(r) < eps) 1.0 else r))
if (normalizeStdDev) {
new StandardScalerModel(
MLlibUtils.mllibVectorToDenseBreeze(summary.mean),
Some(sqrt(MLlibUtils.mllibVectorToDenseBreeze(summary.variance))
.map(r => if (r.isNaN | r.isInfinite | math.abs(r) < eps) 1.0 else r)))
} else {
new StandardScalerModel(
MLlibUtils.mllibVectorToDenseBreeze(summary.mean),
None)
}
}
}

0 comments on commit d28ff3d

Please sign in to comment.