
Commit

create pr
zhengruifeng committed Nov 5, 2019
1 parent 0f14949 commit a2955d8
Showing 4 changed files with 27 additions and 31 deletions.
MaxAbsScaler.scala
@@ -68,15 +68,13 @@ class MaxAbsScaler @Since("2.0.0") (@Since("2.0.0") override val uid: String)
@Since("2.0.0")
override def fit(dataset: Dataset[_]): MaxAbsScalerModel = {
transformSchema(dataset.schema, logging = true)
val Row(max: Vector, min: Vector) = dataset.select(
Summarizer.metrics("max", "min")
.summary(col($(inputCol)), lit(1.0)).as("summary"))
.select("summary.max", "summary.min").first()

val minVals = min.toArray
val maxVals = max.toArray
val n = minVals.length
val maxAbs = Array.tabulate(n) { i => math.max(math.abs(minVals(i)), math.abs(maxVals(i))) }

val Row(max: Vector, min: Vector) = dataset
.select(Summarizer.metrics("max", "min").summary(col($(inputCol))).as("summary"))
.select("summary.max", "summary.min")
.first()

val maxAbs = Array.tabulate(max.size) { i => math.max(math.abs(min(i)), math.abs(max(i))) }

copyValues(new MaxAbsScalerModel(uid, Vectors.dense(maxAbs).compressed).setParent(this))
}
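Both versions of MaxAbsScaler.fit above read the column-wise max and min from Summarizer; the change drops the constant lit(1.0) weight argument and the minVals/maxVals/n intermediates in favour of indexing the vectors directly. A minimal standalone sketch of the resulting pattern, assuming only a local SparkSession and an illustrative DataFrame named df with a "features" column (the sample values and app name are made up, not from the commit):

    // Sketch: per-feature max-abs scaling factors via one Summarizer aggregation.
    import org.apache.spark.ml.linalg.{Vector, Vectors}
    import org.apache.spark.ml.stat.Summarizer
    import org.apache.spark.sql.{Row, SparkSession}
    import org.apache.spark.sql.functions.col

    val spark = SparkSession.builder().appName("maxabs-sketch").master("local[2]").getOrCreate()
    import spark.implicits._

    val df = Seq(
      Vectors.dense(1.0, -8.0, 2.0),
      Vectors.dense(-3.0, 4.0, 0.0)
    ).map(Tuple1.apply).toDF("features")

    // A single aggregation returns both statistics as one struct column.
    val Row(max: Vector, min: Vector) = df
      .select(Summarizer.metrics("max", "min").summary(col("features")).as("summary"))
      .select("summary.max", "summary.min")
      .first()

    // Per-feature scaling factor: the larger magnitude of the column min and max.
    val maxAbs = Array.tabulate(max.size) { i => math.max(math.abs(min(i)), math.abs(max(i))) }
    // maxAbs == Array(3.0, 8.0, 2.0)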
MinMaxScaler.scala
@@ -24,12 +24,9 @@ import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
import org.apache.spark.ml.param.{DoubleParam, ParamMap, Params}
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.stat.Summarizer
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.mllib.stat.Statistics
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{StructField, StructType}
@@ -117,12 +114,13 @@ class MinMaxScaler @Since("1.5.0") (@Since("1.5.0") override val uid: String)
@Since("2.0.0")
override def fit(dataset: Dataset[_]): MinMaxScalerModel = {
transformSchema(dataset.schema, logging = true)
val input: RDD[OldVector] = dataset.select($(inputCol)).rdd.map {
case Row(v: Vector) => OldVectors.fromML(v)
}
val summary = Statistics.colStats(input)
copyValues(new MinMaxScalerModel(uid, summary.min.compressed,
summary.max.compressed).setParent(this))

val Row(max: Vector, min: Vector) = dataset
.select(Summarizer.metrics("max", "min").summary(col($(inputCol))).as("summary"))
.select("summary.max", "summary.min")
.first()

copyValues(new MinMaxScalerModel(uid, min.compressed, max.compressed).setParent(this))
}

@Since("1.5.0")
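MinMaxScaler.fit drops the detour through RDD[OldVector] and Statistics.colStats and uses the same Summarizer aggregation instead. A small sketch of the equivalence between the two paths, again assuming a local SparkSession and an illustrative df with a "features" column:

    // Sketch: the removed RDD-based path and the new DataFrame aggregation
    // produce the same column-wise min/max vectors.
    import org.apache.spark.ml.linalg.{Vector, Vectors}
    import org.apache.spark.ml.stat.Summarizer
    import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
    import org.apache.spark.mllib.stat.Statistics
    import org.apache.spark.sql.{Row, SparkSession}
    import org.apache.spark.sql.functions.col

    val spark = SparkSession.builder().appName("minmax-sketch").master("local[2]").getOrCreate()
    import spark.implicits._

    val df = Seq(
      Vectors.dense(1.0, 5.0),
      Vectors.dense(3.0, -2.0)
    ).map(Tuple1.apply).toDF("features")

    // Old style: copy the column into an RDD of mllib vectors, then summarize it.
    val oldSummary = Statistics.colStats(
      df.select("features").rdd.map { case Row(v: Vector) => OldVectors.fromML(v) })

    // New style: one SQL aggregation over the original DataFrame.
    val Row(max: Vector, min: Vector) = df
      .select(Summarizer.metrics("max", "min").summary(col("features")).as("summary"))
      .select("summary.max", "summary.min")
      .first()

    assert(min == oldSummary.min.asML)   // [1.0, -2.0]
    assert(max == oldSummary.max.asML)   // [3.0, 5.0]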
StandardScaler.scala
@@ -24,10 +24,8 @@ import org.apache.spark.ml._
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.stat.Summarizer
import org.apache.spark.ml.util._
import org.apache.spark.mllib.feature.{StandardScaler => OldStandardScaler}
import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
@@ -109,13 +107,15 @@ class StandardScaler @Since("1.4.0") (
@Since("2.0.0")
override def fit(dataset: Dataset[_]): StandardScalerModel = {
transformSchema(dataset.schema, logging = true)
val input = dataset.select($(inputCol)).rdd.map {
case Row(v: Vector) => OldVectors.fromML(v)
}
val scaler = new OldStandardScaler(withMean = $(withMean), withStd = $(withStd))
val scalerModel = scaler.fit(input)
copyValues(new StandardScalerModel(uid, scalerModel.std.compressed,
scalerModel.mean.compressed).setParent(this))

val Row(mean: Vector, variance: Vector) = dataset
.select(Summarizer.metrics("mean", "variance").summary(col($(inputCol))).as("summary"))
.select("summary.mean", "summary.variance")
.first()

val std = Vectors.dense(variance.toArray.map(math.sqrt))

copyValues(new StandardScalerModel(uid, std.compressed, mean.compressed).setParent(this))
}

@Since("1.4.0")
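StandardScaler.fit no longer delegates to the old mllib StandardScaler: it reads the per-feature mean and variance from one Summarizer pass and takes the element-wise square root of the variance as the scaling std. A minimal sketch with worked numbers, under the same assumptions (local SparkSession, illustrative df with a "features" column):

    // Sketch: fit statistics for standardization from a single Summarizer pass.
    import org.apache.spark.ml.linalg.{Vector, Vectors}
    import org.apache.spark.ml.stat.Summarizer
    import org.apache.spark.sql.{Row, SparkSession}
    import org.apache.spark.sql.functions.col

    val spark = SparkSession.builder().appName("stdscaler-sketch").master("local[2]").getOrCreate()
    import spark.implicits._

    val df = Seq(
      Vectors.dense(0.0, 10.0),
      Vectors.dense(2.0, 30.0),
      Vectors.dense(4.0, 20.0)
    ).map(Tuple1.apply).toDF("features")

    val Row(mean: Vector, variance: Vector) = df
      .select(Summarizer.metrics("mean", "variance").summary(col("features")).as("summary"))
      .select("summary.mean", "summary.variance")
      .first()

    // "variance" is the unbiased (n - 1) sample variance per feature,
    // so taking its square root gives the scaling std.
    val std = Vectors.dense(variance.toArray.map(math.sqrt))
    // mean == [2.0, 20.0], std == [2.0, 10.0]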
Summarizer.scala
@@ -24,7 +24,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes, UnsafeArrayData}
import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, TypedImperativeAggregate}
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.types._
@@ -65,7 +65,7 @@ sealed abstract class SummaryBuilder {
* val dataframe = ... // Some dataframe containing a feature column and a weight column
* val multiStatsDF = dataframe.select(
* Summarizer.metrics("min", "max", "count").summary($"features", $"weight")
* val Row(Row(minVec, maxVec, count)) = multiStatsDF.first()
* val Row(minVec, maxVec, count) = multiStatsDF.first()
* }}}
*
* If one wants to get a single metric, shortcuts are also available:
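The last hunk touches the multi-metric example in the SummaryBuilder scaladoc. For the single-metric shortcuts that the surrounding comment mentions, a minimal sketch (again with an illustrative df and "features" column, not taken from the commit):

    // Sketch: Summarizer.max aggregates a single statistic directly,
    // without going through metrics(...).summary(...).
    import org.apache.spark.ml.linalg.{Vector, Vectors}
    import org.apache.spark.ml.stat.Summarizer
    import org.apache.spark.sql.{Row, SparkSession}
    import org.apache.spark.sql.functions.col

    val spark = SparkSession.builder().appName("summarizer-sketch").master("local[2]").getOrCreate()
    import spark.implicits._

    val df = Seq(
      Vectors.dense(1.0, -8.0),
      Vectors.dense(-3.0, 4.0)
    ).map(Tuple1.apply).toDF("features")

    val Row(max: Vector) = df.select(Summarizer.max(col("features"))).first()
    // max == [1.0, 4.0]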
