Commit f383250
Infer label names automatically
sryza committed May 5, 2015
1 parent 6e257b9 commit f383250
Showing 2 changed files with 28 additions and 36 deletions.
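
In summary: OneHotEncoder previously required the caller to supply category names through setLabelNames; after this commit it infers them from the ML attribute metadata that StringIndexer already attaches to its output column. A minimal usage sketch of the resulting API (assuming the sc and sqlContext handles from the test suite below; the data mirrors the suite's):

import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}

// The same six-row dataset the suite uses.
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
val df = sqlContext.createDataFrame(data).toDF("id", "label")

// StringIndexer records the category values in the output column's
// NominalAttribute metadata as a side effect of indexing.
val indexed = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("labelIndex")
  .fit(df)
  .transform(df)

// The encoder now reads the categories from that column metadata itself;
// the former setLabelNames(attr) call is no longer needed.
val encoded = new OneHotEncoder()
  .setInputCol("labelIndex")
  .setOutputCol("labelVec")
  .transform(indexed)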
mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala

@@ -17,9 +17,10 @@
 
 package org.apache.spark.ml.feature
 
+import org.apache.spark.SparkException
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.UnaryTransformer
-import org.apache.spark.ml.attribute.NominalAttribute
+import org.apache.spark.ml.attribute.{Attribute, BinaryAttribute, NominalAttribute}
 import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
@@ -47,45 +48,42 @@ class OneHotEncoder extends UnaryTransformer[Double, Vector, OneHotEncoder]
     new BooleanParam(this, "includeFirst", "include first category")
   setDefault(includeFirst -> true)
 
-  /**
-   * The names of the categories. Used to identify them in the attributes of the output column.
-   * This is a required parameter.
-   * @group param
-   */
-  final val labelNames: Param[Array[String]] =
-    new Param[Array[String]](this, "labelNames", "categorical label names")
+  private var categories: Array[String] = _
 
   /** @group setParam */
   def setIncludeFirst(value: Boolean): this.type = set(includeFirst, value)
 
-  /** @group setParam */
-  def setLabelNames(attr: NominalAttribute): this.type = set(labelNames, attr.values.get)
-
   /** @group setParam */
   override def setInputCol(value: String): this.type = set(inputCol, value)
 
   /** @group setParam */
   override def setOutputCol(value: String): this.type = set(outputCol, value)
 
-  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
-    val map = extractParamMap(paramMap)
-    SchemaUtils.checkColumnType(schema, map(inputCol), DoubleType)
+  override def transformSchema(schema: StructType): StructType = {
+    SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
     val inputFields = schema.fields
-    val outputColName = map(outputCol)
-    require(inputFields.forall(_.name != outputColName),
-      s"Output column $outputColName already exists.")
-    require(map.contains(labelNames), "OneHotEncoder missing category names")
-    val categories = map(labelNames)
-    val attrValues = (if (map(includeFirst)) categories else categories.drop(1)).toArray
+    val outputColName = $(outputCol)
+    require(inputFields.forall(_.name != $(outputCol)),
+      s"Output column ${$(outputCol)} already exists.")
+
+    val inputColAttr = Attribute.fromStructField(schema($(inputCol)))
+    categories = inputColAttr match {
+      case nominal: NominalAttribute =>
+        nominal.values.getOrElse((0 until nominal.numValues.get).map(_.toString).toArray)
+      case binary: BinaryAttribute => binary.values.getOrElse(Array("0", "1"))
+      case _ =>
+        throw new SparkException(s"OneHotEncoder input column ${$(inputCol)} is not nominal")
+    }
+
+    val attrValues = (if ($(includeFirst)) categories else categories.drop(1)).toArray
     val attr = NominalAttribute.defaultAttr.withName(outputColName).withValues(attrValues)
     val outputFields = inputFields :+ attr.toStructField()
    StructType(outputFields)
   }
 
-  protected def createTransformFunc(paramMap: ParamMap): (Double) => Vector = {
-    val map = extractParamMap(paramMap)
-    val first = map(includeFirst)
-    val vecLen = if (first) map(labelNames).length else map(labelNames).length - 1
+  protected override def createTransformFunc(): (Double) => Vector = {
+    val first = $(includeFirst)
+    val vecLen = if (first) categories.length else categories.length - 1
     val oneValue = Array(1.0)
     val emptyValues = Array[Double]()
     val emptyIndices = Array[Int]()
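The mechanism the new transformSchema relies on is the ML attribute metadata carried on a column's StructField. A standalone sketch of that round trip, using only the attribute APIs that appear in the diff above (the column name and values are hypothetical):

import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.sql.types.StructField

// Build a StructField carrying nominal metadata, as StringIndexer does
// for its output column.
val field: StructField = NominalAttribute.defaultAttr
  .withName("labelIndex")
  .withValues(Array("a", "c", "b"))
  .toStructField()

// Recover the category names the way the new transformSchema does.
val categories: Array[String] = Attribute.fromStructField(field) match {
  case nominal: NominalAttribute =>
    nominal.values.getOrElse((0 until nominal.numValues.get).map(_.toString).toArray)
  case other =>
    sys.error(s"column is not nominal: $other")
}
// categories: Array[String] = Array(a, c, b)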
mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala

@@ -17,13 +17,12 @@
 
 package org.apache.spark.ml.feature
 
-import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
+import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 
 import org.apache.spark.sql.{DataFrame, SQLContext}
 
-import org.scalatest.FunSuite
 
 class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext {
   private var sqlContext: SQLContext = _
@@ -33,23 +32,19 @@ class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext {
     sqlContext = new SQLContext(sc)
   }
 
-  def stringIndexed(): (DataFrame, NominalAttribute) = {
+  def stringIndexed(): DataFrame = {
     val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
     val df = sqlContext.createDataFrame(data).toDF("id", "label")
     val indexer = new StringIndexer()
       .setInputCol("label")
       .setOutputCol("labelIndex")
       .fit(df)
-    val transformed = indexer.transform(df)
-    val attr = Attribute.fromStructField(transformed.schema("labelIndex"))
-      .asInstanceOf[NominalAttribute]
-    (transformed, attr)
+    indexer.transform(df)
   }
 
   test("OneHotEncoder includeFirst = true") {
-    val (transformed, attr) = stringIndexed()
+    val transformed = stringIndexed()
     val encoder = new OneHotEncoder()
-      .setLabelNames(attr)
       .setInputCol("labelIndex")
       .setOutputCol("labelVec")
     val encoded = encoder.transform(transformed)
@@ -65,10 +60,9 @@ class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext {
   }
 
   test("OneHotEncoder includeFirst = false") {
-    val (transformed, attr) = stringIndexed()
+    val transformed = stringIndexed()
     val encoder = new OneHotEncoder()
       .setIncludeFirst(false)
-      .setLabelNames(attr)
       .setInputCol("labelIndex")
       .setOutputCol("labelVec")
     val encoded = encoder.transform(transformed)
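For orientation, a worked illustration of the two test cases (my arithmetic, not the assertions from the hidden diff lines): StringIndexer assigns indices by descending frequency, so the dataset above indexes "a" -> 0.0, "c" -> 1.0, "b" -> 2.0.

// includeFirst = true  -> 3-element vectors, one bit per category:
//   "a" => [1.0, 0.0, 0.0]   "c" => [0.0, 1.0, 0.0]   "b" => [0.0, 0.0, 1.0]
// includeFirst = false -> first category dropped, 2-element vectors:
//   "a" => [0.0, 0.0]        "c" => [1.0, 0.0]        "b" => [0.0, 1.0]
encoded.select("labelVec").collect().foreach(println)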
