Skip to content

Commit 83302c3

Browse files
yinxusenmengxr
authored andcommitted
[SPARK-13036][SPARK-13318][SPARK-13319] Add save/load for feature.py
Add save/load for feature.py. Meanwhile, add save/load for `ElementwiseProduct` in Scala side and fix a bug of missing `setDefault` in `VectorSlicer` and `StopWordsRemover`. In this PR I ignore the `RFormula` and `RFormulaModel` because its Scala implementation is pending in #9884. I'll add them in this PR if #9884 gets merged first. Or add a follow-up JIRA for `RFormula`. Author: Xusen Yin <yinxusen@gmail.com> Closes #11203 from yinxusen/SPARK-13036.
1 parent c8f2545 commit 83302c3

File tree

3 files changed

+341
-48
lines changed

3 files changed

+341
-48
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717

1818
package org.apache.spark.ml.feature
1919

20-
import org.apache.spark.annotation.Experimental
20+
import org.apache.spark.annotation.{Experimental, Since}
2121
import org.apache.spark.ml.UnaryTransformer
2222
import org.apache.spark.ml.param.Param
23-
import org.apache.spark.ml.util.Identifiable
23+
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
2424
import org.apache.spark.mllib.feature
2525
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
2626
import org.apache.spark.sql.types.DataType
@@ -33,7 +33,7 @@ import org.apache.spark.sql.types.DataType
3333
*/
3434
@Experimental
3535
class ElementwiseProduct(override val uid: String)
36-
extends UnaryTransformer[Vector, Vector, ElementwiseProduct] {
36+
extends UnaryTransformer[Vector, Vector, ElementwiseProduct] with DefaultParamsWritable {
3737

3838
def this() = this(Identifiable.randomUID("elemProd"))
3939

@@ -57,3 +57,10 @@ class ElementwiseProduct(override val uid: String)
5757

5858
override protected def outputDataType: DataType = new VectorUDT()
5959
}
60+
61+
@Since("2.0.0")
62+
object ElementwiseProduct extends DefaultParamsReadable[ElementwiseProduct] {
63+
64+
@Since("2.0.0")
65+
override def load(path: String): ElementwiseProduct = super.load(path)
66+
}
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.ml.feature
19+
20+
import org.apache.spark.SparkFunSuite
21+
import org.apache.spark.ml.util.DefaultReadWriteTest
22+
import org.apache.spark.mllib.linalg.Vectors
23+
import org.apache.spark.mllib.util.MLlibTestSparkContext
24+
25+
class ElementwiseProductSuite
26+
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
27+
28+
test("read/write") {
29+
val ep = new ElementwiseProduct()
30+
.setInputCol("myInputCol")
31+
.setOutputCol("myOutputCol")
32+
.setScalingVec(Vectors.dense(0.1, 0.2))
33+
testDefaultReadWrite(ep)
34+
}
35+
}

0 commit comments

Comments
 (0)