mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
@@ -17,13 +17,14 @@

package org.apache.spark.ml.feature

import org.apache.spark.annotation.Experimental
import org.apache.hadoop.fs.Path

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml._
import org.apache.spark.ml.attribute.{AttributeGroup, _}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.ml.util._
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.mllib.regression.LabeledPoint
@@ -60,7 +61,7 @@ private[feature] trait ChiSqSelectorParams extends Params
*/
@Experimental
final class ChiSqSelector(override val uid: String)
extends Estimator[ChiSqSelectorModel] with ChiSqSelectorParams {
extends Estimator[ChiSqSelectorModel] with ChiSqSelectorParams with DefaultParamsWritable {

def this() = this(Identifiable.randomUID("chiSqSelector"))

@@ -95,6 +96,13 @@ final class ChiSqSelector(override val uid: String)
override def copy(extra: ParamMap): ChiSqSelector = defaultCopy(extra)
}

@Since("1.6.0")
object ChiSqSelector extends DefaultParamsReadable[ChiSqSelector] {

@Since("1.6.0")
override def load(path: String): ChiSqSelector = super.load(path)
}

/**
* :: Experimental ::
* Model fitted by [[ChiSqSelector]].
@@ -103,7 +111,12 @@ final class ChiSqSelector(override val uid: String)
final class ChiSqSelectorModel private[ml] (
override val uid: String,
private val chiSqSelector: feature.ChiSqSelectorModel)
extends Model[ChiSqSelectorModel] with ChiSqSelectorParams {
extends Model[ChiSqSelectorModel] with ChiSqSelectorParams with MLWritable {

import ChiSqSelectorModel._

/** List of indices to select (filter). Must be ordered ascending. */
val selectedFeatures: Array[Int] = chiSqSelector.selectedFeatures

/** @group setParam */
def setFeaturesCol(value: String): this.type = set(featuresCol, value)
@@ -147,4 +160,46 @@ final class ChiSqSelectorModel private[ml] (
val copied = new ChiSqSelectorModel(uid, chiSqSelector)
copyValues(copied, extra).setParent(parent)
}

@Since("1.6.0")
override def write: MLWriter = new ChiSqSelectorModelWriter(this)
}

@Since("1.6.0")
object ChiSqSelectorModel extends MLReadable[ChiSqSelectorModel] {

private[ChiSqSelectorModel]
class ChiSqSelectorModelWriter(instance: ChiSqSelectorModel) extends MLWriter {

private case class Data(selectedFeatures: Seq[Int])

override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sc)
val data = Data(instance.selectedFeatures.toSeq)
val dataPath = new Path(path, "data").toString
sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}

private class ChiSqSelectorModelReader extends MLReader[ChiSqSelectorModel] {

private val className = classOf[ChiSqSelectorModel].getName

override def load(path: String): ChiSqSelectorModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
val data = sqlContext.read.parquet(dataPath).select("selectedFeatures").head()
val selectedFeatures = data.getAs[Seq[Int]](0).toArray
val oldModel = new feature.ChiSqSelectorModel(selectedFeatures)
val model = new ChiSqSelectorModel(metadata.uid, oldModel)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
}

@Since("1.6.0")
override def read: MLReader[ChiSqSelectorModel] = new ChiSqSelectorModelReader

@Since("1.6.0")
override def load(path: String): ChiSqSelectorModel = super.load(path)
}
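
With ChiSqSelector mixing in DefaultParamsWritable and ChiSqSelectorModel mixing in MLWritable, both the estimator and the fitted model can now be persisted through the standard ML reader/writer API. A minimal round-trip sketch, assuming an existing sqlContext and a training DataFrame df with "features" and "label" columns (df and the save paths are hypothetical):

import org.apache.spark.ml.feature.{ChiSqSelector, ChiSqSelectorModel}

// Configure the estimator; df is assumed to provide "features" and "label" columns.
val selector = new ChiSqSelector()
  .setNumTopFeatures(50)
  .setFeaturesCol("features")
  .setLabelCol("label")
  .setOutputCol("selectedFeatures")

// Save the unfitted estimator and read it back through the new companion object.
selector.write.overwrite().save("/tmp/spark/chiSqSelector")
val sameSelector = ChiSqSelector.load("/tmp/spark/chiSqSelector")

// Fit, save the model, and reload it; selectedFeatures survives the round trip.
val model = selector.fit(df)
model.write.overwrite().save("/tmp/spark/chiSqSelectorModel")
val sameModel = ChiSqSelectorModel.load("/tmp/spark/chiSqSelectorModel")
assert(model.selectedFeatures.sameElements(sameModel.selectedFeatures))
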
mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala (67 changes: 62 additions & 5 deletions)
@@ -17,13 +17,15 @@

package org.apache.spark.ml.feature

import org.apache.spark.annotation.Experimental
import org.apache.hadoop.fs.Path

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml._
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.util._
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.mllib.linalg._
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{StructField, StructType}
@@ -49,7 +51,8 @@ private[feature] trait PCAParams extends Params with HasInputCol with HasOutputCol
* PCA trains a model to project vectors to a low-dimensional space using PCA.
*/
@Experimental
class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams {
class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams
with DefaultParamsWritable {

def this() = this(Identifiable.randomUID("pca"))

@@ -86,6 +89,13 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams
override def copy(extra: ParamMap): PCA = defaultCopy(extra)
}

@Since("1.6.0")
object PCA extends DefaultParamsReadable[PCA] {

@Since("1.6.0")
override def load(path: String): PCA = super.load(path)
}

/**
* :: Experimental ::
* Model fitted by [[PCA]].
@@ -94,7 +104,12 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams
class PCAModel private[ml] (
override val uid: String,
pcaModel: feature.PCAModel)
extends Model[PCAModel] with PCAParams {
extends Model[PCAModel] with PCAParams with MLWritable {

import PCAModel._

/** A principal components matrix. Each column is one principal component. */
val pc: DenseMatrix = pcaModel.pc

/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)
@@ -127,4 +142,46 @@ class PCAModel private[ml] (
val copied = new PCAModel(uid, pcaModel)
copyValues(copied, extra).setParent(parent)
}

@Since("1.6.0")
override def write: MLWriter = new PCAModelWriter(this)
}

@Since("1.6.0")
object PCAModel extends MLReadable[PCAModel] {

private[PCAModel] class PCAModelWriter(instance: PCAModel) extends MLWriter {

private case class Data(k: Int, pc: DenseMatrix)

override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sc)
val data = Data(instance.getK, instance.pc)
val dataPath = new Path(path, "data").toString
sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}

private class PCAModelReader extends MLReader[PCAModel] {

private val className = classOf[PCAModel].getName

override def load(path: String): PCAModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
val Row(k: Int, pc: DenseMatrix) = sqlContext.read.parquet(dataPath)
.select("k", "pc")
.head()
val oldModel = new feature.PCAModel(k, pc)
val model = new PCAModel(metadata.uid, oldModel)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
}

@Since("1.6.0")
override def read: MLReader[PCAModel] = new PCAModelReader

@Since("1.6.0")
override def load(path: String): PCAModel = super.load(path)
}
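
The same write/read pattern now applies to PCA; here the payload written to the data file is the fitted principal components matrix. A minimal sketch, assuming a DataFrame df with a "features" vector column (df and the save path are hypothetical):

import org.apache.spark.ml.feature.{PCA, PCAModel}

val pca = new PCA()
  .setInputCol("features")
  .setOutputCol("pcaFeatures")
  .setK(3)

val model = pca.fit(df)
model.write.overwrite().save("/tmp/spark/pcaModel")
val sameModel = PCAModel.load("/tmp/spark/pcaModel")
// pc is the persisted DenseMatrix of principal components; compare by values.
assert(model.pc.toArray.sameElements(sameModel.pc.toArray))
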
mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -22,12 +22,14 @@ import java.util.{Map => JMap}

import scala.collection.JavaConverters._

import org.apache.spark.annotation.Experimental
import org.apache.hadoop.fs.Path

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators, Params}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, VectorUDT}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions.udf
@@ -93,7 +95,7 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOutputCol
*/
@Experimental
class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerModel]
with VectorIndexerParams {
with VectorIndexerParams with DefaultParamsWritable {

def this() = this(Identifiable.randomUID("vecIdx"))

@@ -136,7 +138,11 @@ class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerModel]
override def copy(extra: ParamMap): VectorIndexer = defaultCopy(extra)
}

private object VectorIndexer {
@Since("1.6.0")
object VectorIndexer extends DefaultParamsReadable[VectorIndexer] {

@Since("1.6.0")
override def load(path: String): VectorIndexer = super.load(path)

/**
* Helper class for tracking unique values for each feature.
@@ -146,7 +152,7 @@ private object VectorIndexer {
* @param numFeatures This class fails if it encounters a Vector whose length is not numFeatures.
* @param maxCategories This class caps the number of unique values collected at maxCategories.
*/
class CategoryStats(private val numFeatures: Int, private val maxCategories: Int)
private class CategoryStats(private val numFeatures: Int, private val maxCategories: Int)
extends Serializable {

/** featureValueSets[feature index] = set of unique values */
@@ -252,7 +258,9 @@ class VectorIndexerModel private[ml] (
override val uid: String,
val numFeatures: Int,
val categoryMaps: Map[Int, Map[Double, Int]])
extends Model[VectorIndexerModel] with VectorIndexerParams {
extends Model[VectorIndexerModel] with VectorIndexerParams with MLWritable {

import VectorIndexerModel._

/** Java-friendly version of [[categoryMaps]] */
def javaCategoryMaps: JMap[JInt, JMap[JDouble, JInt]] = {
@@ -408,4 +416,48 @@ class VectorIndexerModel private[ml] (
val copied = new VectorIndexerModel(uid, numFeatures, categoryMaps)
copyValues(copied, extra).setParent(parent)
}

@Since("1.6.0")
override def write: MLWriter = new VectorIndexerModelWriter(this)
}

@Since("1.6.0")
object VectorIndexerModel extends MLReadable[VectorIndexerModel] {

private[VectorIndexerModel]
class VectorIndexerModelWriter(instance: VectorIndexerModel) extends MLWriter {

private case class Data(numFeatures: Int, categoryMaps: Map[Int, Map[Double, Int]])

override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sc)
val data = Data(instance.numFeatures, instance.categoryMaps)
val dataPath = new Path(path, "data").toString
sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}

private class VectorIndexerModelReader extends MLReader[VectorIndexerModel] {

private val className = classOf[VectorIndexerModel].getName

override def load(path: String): VectorIndexerModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
val data = sqlContext.read.parquet(dataPath)
.select("numFeatures", "categoryMaps")
.head()
val numFeatures = data.getAs[Int](0)
val categoryMaps = data.getAs[Map[Int, Map[Double, Int]]](1)
val model = new VectorIndexerModel(metadata.uid, numFeatures, categoryMaps)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
}

@Since("1.6.0")
override def read: MLReader[VectorIndexerModel] = new VectorIndexerModelReader

@Since("1.6.0")
override def load(path: String): VectorIndexerModel = super.load(path)
}
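
VectorIndexerModel follows the same scheme, persisting numFeatures together with the per-feature category maps. A minimal sketch, assuming a DataFrame df with a "features" vector column (df and the save path are hypothetical):

import org.apache.spark.ml.feature.{VectorIndexer, VectorIndexerModel}

val indexer = new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexed")
  .setMaxCategories(10)

val model = indexer.fit(df)
model.write.overwrite().save("/tmp/spark/vectorIndexerModel")
val sameModel = VectorIndexerModel.load("/tmp/spark/vectorIndexerModel")
// numFeatures and categoryMaps (Map[Int, Map[Double, Int]]) round-trip through the data file.
assert(model.numFeatures == sameModel.numFeatures)
assert(model.categoryMaps == sameModel.categoryMaps)
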