[SPARK-8664] [ML] Add PCA transformer #7065

Closed · wants to merge 2 commits
130 changes: 130 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
@@ -0,0 +1,130 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.feature

import org.apache.spark.annotation.Experimental
import org.apache.spark.ml._
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{StructField, StructType}

/**
* Params for [[PCA]] and [[PCAModel]].
*/
private[feature] trait PCAParams extends Params with HasInputCol with HasOutputCol {

/**
* The number of principal components.
* @group param
*/
final val k: IntParam = new IntParam(this, "k", "the number of principal components")

/** @group getParam */
def getK: Int = $(k)

}

/**
* :: Experimental ::
 * PCA trains a model to project vectors to a lower-dimensional space using principal component analysis.
*/
@Experimental
class PCA(override val uid: String) extends Estimator[PCAModel] with PCAParams {

def this() = this(Identifiable.randomUID("pca"))

/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)

/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)

/** @group setParam */
def setK(value: Int): this.type = set(k, value)

/**
* Computes a [[PCAModel]] that contains the principal components of the input vectors.
*/
override def fit(dataset: DataFrame): PCAModel = {
transformSchema(dataset.schema, logging = true)
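    // Extract the input vector column and delegate the actual fitting to
    // mllib's feature.PCA, which computes the top k principal components.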
    val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v }
val pca = new feature.PCA(k = $(k))
val pcaModel = pca.fit(input)
copyValues(new PCAModel(uid, pcaModel).setParent(this))
}

override def transformSchema(schema: StructType): StructType = {
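    // Check that the input column holds vectors and that the output column is
    // not already taken, then append the new vector field to the schema.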
val inputType = schema($(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
s"Input column ${$(inputCol)} must be a vector column")
require(!schema.fieldNames.contains($(outputCol)),
s"Output column ${$(outputCol)} already exists.")
val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
StructType(outputFields)
}

override def copy(extra: ParamMap): PCA = defaultCopy(extra)
}

/**
* :: Experimental ::
* Model fitted by [[PCA]].
*/
@Experimental
class PCAModel private[ml] (
override val uid: String,
pcaModel: feature.PCAModel)
extends Model[PCAModel] with PCAParams {

/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)

/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)

/**
   * Transform a vector by the computed principal components.
   * NOTE: Vectors to be transformed must be the same length as
   * the source vectors given to [[PCA.fit()]].
*/
override def transform(dataset: DataFrame): DataFrame = {
transformSchema(dataset.schema, logging = true)
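    // Wrap the mllib model's per-vector transform in a UDF and append the
    // projected vectors as the output column.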
val pcaOp = udf { pcaModel.transform _ }
dataset.withColumn($(outputCol), pcaOp(col($(inputCol))))
}

override def transformSchema(schema: StructType): StructType = {
val inputType = schema($(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
s"Input column ${$(inputCol)} must be a vector column")
require(!schema.fieldNames.contains($(outputCol)),
s"Output column ${$(outputCol)} already exists.")
val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
StructType(outputFields)
}

override def copy(extra: ParamMap): PCAModel = {
val copied = new PCAModel(uid, pcaModel)
copyValues(copied, extra)
}
}
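
For context, a minimal usage sketch of the new estimator (assuming a SQLContext named sqlContext is in scope; the data and column names are arbitrary, mirroring the test suite below):

import org.apache.spark.ml.feature.PCA
import org.apache.spark.mllib.linalg.Vectors

val data = Seq(
  Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))),
  Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0),
  Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)
)
val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features")

val pcaModel = new PCA()
  .setInputCol("features")
  .setOutputCol("pca_features")
  .setK(3)
  .fit(df)

pcaModel.transform(df).select("pca_features").show()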
@@ -68,7 +68,7 @@ class PCA(val k: Int) {
* @param k number of principal components.
* @param pc a principal components Matrix. Each column is one principal component.
*/
-class PCAModel private[mllib] (val k: Int, val pc: DenseMatrix) extends VectorTransformer {
+class PCAModel private[spark] (val k: Int, val pc: DenseMatrix) extends VectorTransformer {
Contributor Author:
Because the test case in ml.feature.PCASuite (https://github.com/apache/spark/pull/7065/files#diff-e1593bb9e311c3f2a2ea49cce20ed671R34) uses this constructor, I changed it to spark-private, like Word2VecModel.
The constructors in mllib.feature have inconsistent access modifiers: some are private[spark] while others are public. This is confusing and should be unified in a separate task.
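
To make the visibility point concrete, a small sketch (the object name is hypothetical) of the cross-package construction that private[mllib] would forbid but private[spark] allows:

package org.apache.spark.ml.feature

import org.apache.spark.mllib.feature.{PCAModel => OldPCAModel}
import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices}

object VisibilityCheck {
  val mat = Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)).asInstanceOf[DenseMatrix]
  // Compiles from ml.feature only because the constructor is private[spark];
  // with private[mllib], this call site would sit outside the allowed package.
  val model = new OldPCAModel(2, mat)
}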

/**
* Transform a vector by computed Principal Components.
*
64 changes: 64 additions & 0 deletions mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
@@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.feature

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.feature.{PCAModel => OldPCAModel}
import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Vector, Vectors}
import org.apache.spark.mllib.linalg.distributed.RowMatrix
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.Row

class PCASuite extends SparkFunSuite with MLlibTestSparkContext {

test("params") {
ParamsSuite.checkParams(new PCA)
val mat = Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)).asInstanceOf[DenseMatrix]
val model = new PCAModel("pca", new OldPCAModel(2, mat))
ParamsSuite.checkParams(model)
}

test("pca") {
val data = Array(
Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))),
Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0),
Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)
)

val dataRDD = sc.parallelize(data, 2)

val mat = new RowMatrix(dataRDD)
val pc = mat.computePrincipalComponents(3)
val expected = mat.multiply(pc).rows

val df = sqlContext.createDataFrame(dataRDD.zip(expected)).toDF("features", "expected")

val pca = new PCA()
.setInputCol("features")
.setOutputCol("pca_features")
.setK(3)
.fit(df)

pca.transform(df).select("pca_features", "expected").collect().foreach {
case Row(x: Vector, y: Vector) =>
        assert(x ~== y absTol 1e-5, "Transformed vector differs from the expected vector.")
}
}
}
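
For reference, the equivalence the test relies on: RowMatrix.multiply(pc) projects all rows at once, while the fitted model transforms one vector at a time; both compute the product pc^T x for each row x. A tiny worked sketch with hypothetical values:

import org.apache.spark.mllib.linalg.{DenseVector, Matrices}

val pc = Matrices.dense(2, 1, Array(0.6, 0.8)) // 2 features -> 1 component
val x = new DenseVector(Array(1.0, 2.0))
// pc.transpose is 1 x 2, so the product is the one-dimensional projection:
// 0.6 * 1.0 + 0.8 * 2.0 = 2.2
val projected = pc.transpose.multiply(x)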