Skip to content

Commit

Permalink
[SPARK-13036][SPARK-13318][SPARK-13319] Add save/load for feature.py
Browse files Browse the repository at this point in the history
Add save/load for feature.py. Also add save/load for `ElementwiseProduct` on the Scala side, and fix the missing `setDefault` calls in `VectorSlicer` and `StopWordsRemover`.

This PR leaves out `RFormula` and `RFormulaModel` because their Scala implementation is still pending in #9884. If #9884 gets merged first, I'll add them to this PR; otherwise I'll open a follow-up JIRA for `RFormula`.

Author: Xusen Yin <yinxusen@gmail.com>

Closes #11203 from yinxusen/SPARK-13036.
  • Loading branch information
yinxusen authored and mengxr committed Mar 4, 2016
1 parent c8f2545 commit 83302c3
Show file tree
Hide file tree
Showing 3 changed files with 341 additions and 48 deletions.
Expand Up @@ -17,10 +17,10 @@

package org.apache.spark.ml.feature

import org.apache.spark.annotation.Experimental
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param.Param
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql.types.DataType
Expand All @@ -33,7 +33,7 @@ import org.apache.spark.sql.types.DataType
*/
@Experimental
class ElementwiseProduct(override val uid: String)
extends UnaryTransformer[Vector, Vector, ElementwiseProduct] {
extends UnaryTransformer[Vector, Vector, ElementwiseProduct] with DefaultParamsWritable {

def this() = this(Identifiable.randomUID("elemProd"))

Expand All @@ -57,3 +57,10 @@ class ElementwiseProduct(override val uid: String)

override protected def outputDataType: DataType = new VectorUDT()
}

/**
 * Companion object of [[ElementwiseProduct]]; mixes in `DefaultParamsReadable`
 * for that type.
 */
@Since("2.0.0")
object ElementwiseProduct extends DefaultParamsReadable[ElementwiseProduct] {

  /**
   * Loads an [[ElementwiseProduct]] from `path`.
   *
   * Forwards directly to the inherited `load`; this override only spells out
   * the return type as `ElementwiseProduct`.
   *
   * @param path location passed through unchanged to the inherited `load`
   * @return whatever the inherited `load` produces for `path`
   */
  @Since("2.0.0")
  override def load(path: String): ElementwiseProduct = {
    super.load(path)
  }
}
@@ -0,0 +1,35 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.feature

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext

/**
 * Test suite for [[ElementwiseProduct]].
 */
class ElementwiseProductSuite
  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

  // Configures every param explicitly, then hands the transformer to
  // `testDefaultReadWrite` (provided by `DefaultReadWriteTest`).
  test("read/write") {
    val transformer = new ElementwiseProduct()
    transformer.setInputCol("myInputCol")
    transformer.setOutputCol("myOutputCol")
    transformer.setScalingVec(Vectors.dense(0.1, 0.2))

    testDefaultReadWrite(transformer)
  }
}

0 comments on commit 83302c3

Please sign in to comment.