Skip to content

Commit

Permalink
SPARK-5888. [MLLIB]. Add OneHotEncoder as a Transformer
Browse files Browse the repository at this point in the history
  • Loading branch information
sryza committed May 4, 2015
1 parent f32e69e commit 1c182dd
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.feature

import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.param._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{StringType, StructType}

@AlphaComponent
class OneHotEncoder(labelNames: Seq[String], includeFirst: Boolean = true) extends Transformer
with HasInputCol {

/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)

private def outputColName(index: Int): String = {
s"${get(inputCol)}_${labelNames(index)}"
}

override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
val map = this.paramMap ++ paramMap

val startIndex = if (includeFirst) 0 else 1
val cols = (startIndex until labelNames.length).map { index =>
val colEncoder = udf { label: Double => if (index == label) 1.0 else 0.0 }
colEncoder(dataset(map(inputCol))).as(outputColName(index))
}

dataset.select(Array(col("*")) ++ cols: _*)
}

override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
val map = this.paramMap ++ paramMap
checkInputColumn(schema, map(inputCol), StringType)
val inputFields = schema.fields
val startIndex = if (includeFirst) 0 else 1
val fields = (startIndex until labelNames.length).map { index =>
val colName = outputColName(index)
require(inputFields.forall(_.name != colName),
s"Output column $colName already exists.")
NominalAttribute.defaultAttr.withName(colName).toStructField()
}

val outputFields = inputFields ++ fields
StructType(outputFields)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.feature

import org.apache.spark.mllib.util.MLlibTestSparkContext

import org.scalatest.FunSuite
import org.apache.spark.sql.SQLContext
import org.apache.spark.ml.attribute.{NominalAttribute, Attribute}

class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext {
private var sqlContext: SQLContext = _

override def beforeAll(): Unit = {
super.beforeAll()
sqlContext = new SQLContext(sc)
}

test("OneHotEncoder") {
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
val df = sqlContext.createDataFrame(data).toDF("id", "label")
val indexer = new StringIndexer()
.setInputCol("label")
.setOutputCol("labelIndex")
.fit(df)
val transformed = indexer.transform(df)
val attr = Attribute.fromStructField(transformed.schema("labelIndex"))
.asInstanceOf[NominalAttribute]
assert(attr.values.get === Array("a", "c", "b"))

val encoder = new OneHotEncoder(attr.values.get)
.setInputCol("labelIndex")
val encoded = encoder.transform(transformed)

val output = encoded.select("id", "labelIndex_a", "labelIndex_c", "labelIndex_b").map { r =>
(r.getInt(0), r.getDouble(1), r.getDouble(2), r.getDouble(3))
}.collect().toSet
// a -> 0, b -> 2, c -> 1
val expected = Set((0, 1.0, 0.0, 0.0), (1, 0.0, 0.0, 1.0), (2, 0.0, 1.0, 0.0),
(3, 1.0, 0.0, 0.0), (4, 1.0, 0.0, 0.0), (5, 0.0, 1.0, 0.0))
assert(output === expected)
}

}

0 comments on commit 1c182dd

Please sign in to comment.