Skip to content

Commit

Permalink
create pr
Browse files Browse the repository at this point in the history
  • Loading branch information
zhengruifeng committed Nov 5, 2019
1 parent ef1abf2 commit 0f14949
Showing 1 changed file with 8 additions and 8 deletions.
Expand Up @@ -24,9 +24,8 @@ import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
import org.apache.spark.ml.param.{ParamMap, Params}
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.stat.Summarizer
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
import org.apache.spark.mllib.stat.Statistics
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{StructField, StructType}
Expand Down Expand Up @@ -69,12 +68,13 @@ class MaxAbsScaler @Since("2.0.0") (@Since("2.0.0") override val uid: String)
@Since("2.0.0")
override def fit(dataset: Dataset[_]): MaxAbsScalerModel = {
transformSchema(dataset.schema, logging = true)
val input = dataset.select($(inputCol)).rdd.map {
case Row(v: Vector) => OldVectors.fromML(v)
}
val summary = Statistics.colStats(input)
val minVals = summary.min.toArray
val maxVals = summary.max.toArray
val Row(max: Vector, min: Vector) = dataset.select(
Summarizer.metrics("max", "min")
.summary(col($(inputCol)), lit(1.0)).as("summary"))
.select("summary.max", "summary.min").first()

val minVals = min.toArray
val maxVals = max.toArray
val n = minVals.length
val maxAbs = Array.tabulate(n) { i => math.max(math.abs(minVals(i)), math.abs(maxVals(i))) }

Expand Down

0 comments on commit 0f14949

Please sign in to comment.