Skip to content

Commit

Permalink
[SPARK-7446] [MLLIB] Add inverse transform for string indexer
Browse files Browse the repository at this point in the history
It is useful to convert the encoded indices back to their string representation for result inspection. We can add a function which creates an inverse transformation.

Author: Holden Karau <holden@pigscanfly.ca>

Closes #6339 from holdenk/SPARK-7446-inverse-transform-for-string-indexer and squashes the following commits:

7cdf915 [Holden Karau] scala style comment fix
b9cffb6 [Holden Karau] Update the labels param to have the metadata note
6a38edb [Holden Karau] Setting the default needs to come after the value gets defined
9e241d8 [Holden Karau] use Array.empty
21c8cfa [Holden Karau] Merge branch 'master' into SPARK-7446-inverse-transform-for-string-indexer
64dd3a3 [Holden Karau] Merge branch 'master' into SPARK-7446-inverse-transform-for-string-indexer
4f06c59 [Holden Karau] Fix comment styles, use empty array as the default, etc.
a60c0e3 [Holden Karau] CR feedback (remove old constructor, add a note about use of setLabels)
1987b95 [Holden Karau] Use default copy
71e8d66 [Holden Karau] Make labels a local param for StringIndexerInverse
8450d0b [Holden Karau] Use the labels param in StringIndexerInverse
7464019 [Holden Karau] Add a labels param
868b1a9 [Holden Karau] Update scaladoc since we don't have labelsCol anymore
5aa38bf [Holden Karau] Add an inverse test using only meta data, pass labels when calling inverse method
f3e0c64 [Holden Karau] CR feedback
ebed932 [Holden Karau] Add Experimental tag and some scaladocs. Also don't require that the inputCol has the metadata on it, instead have the labelsCol specified when creating the inverse.
03ebf95 [Holden Karau] Add explicit type for invert function
ecc65e0 [Holden Karau] Read the metadata correctly, use the array, pass the test
a42d773 [Holden Karau] Fix test to supply cols as per new invert method
16cc3c3 [Holden Karau] Add an invert method
d4bcb20 [Holden Karau] Make the inverse string indexer into a transformer (still needs test updates but compiles)
e8bf3ad [Holden Karau] Merge branch 'master' into SPARK-7446-inverse-transform-for-string-indexer
c3fdee1 [Holden Karau] Some WIP refactoring based on jkbradley's CR feedback. Definite work-in-progress
557bef8 [Holden Karau] Instead of using a private inverse transform, add an invert function so we can use it in a pipeline
88779c1 [Holden Karau] fix long line
78b28c1 [Holden Karau] Finish reverse part and add a test :)
bb16a6a [Holden Karau] Some progress
  • Loading branch information
holdenk authored and jkbradley committed Aug 1, 2015
1 parent 60ea7ab commit 6503897
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 3 deletions.
108 changes: 105 additions & 3 deletions mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkException
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{NumericType, StringType, StructType}
import org.apache.spark.sql.types.{DoubleType, NumericType, StringType, StructType}
import org.apache.spark.util.collection.OpenHashMap

/**
Expand Down Expand Up @@ -151,4 +152,105 @@ class StringIndexerModel private[ml] (
val copied = new StringIndexerModel(uid, labels)
copyValues(copied, extra)
}

/**
* Return a model to perform the inverse transformation.
* Note: By default we keep the original columns during this transformation, so the inverse
* should only be used on new columns such as predicted labels.
*/
def invert(inputCol: String, outputCol: String): StringIndexerInverse = {
  // Build the inverse transformer first, then hand it this model's labels so that it
  // maps indices back to strings using `labels` rather than reading column metadata.
  val inverse = new StringIndexerInverse()
  inverse.setInputCol(inputCol)
  inverse.setOutputCol(outputCol)
  inverse.setLabels(labels)
}
}

/**
 * :: Experimental ::
 * Transformer that maps a numeric column of label indices back to string labels.
 * The labels come from the `labels` param when it is non-empty; otherwise they are read
 * from the nominal attribute metadata attached to the input column.
 * Note: By default we keep the original columns during this transformation,
 * so the inverse should only be used on new columns such as predicted labels.
 */
@Experimental
class StringIndexerInverse private[ml] (
    override val uid: String) extends Transformer
    with HasInputCol with HasOutputCol {

  def this() =
    this(Identifiable.randomUID("strIdxInv"))

  /** @group setParam */
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  def setOutputCol(value: String): this.type = set(outputCol, value)

  /**
   * Optional labels to be provided by the user, if not supplied column
   * metadata is read for labels. The default value is an empty array,
   * but the empty array is ignored and column metadata used instead.
   * @group setParam
   */
  def setLabels(value: Array[String]): this.type = set(labels, value)

  /**
   * Param for array of labels.
   * Optional labels to be provided by the user, if not supplied column
   * metadata is read for labels.
   * @group param
   */
  final val labels: StringArrayParam = new StringArrayParam(this, "labels",
    "array of labels, if not provided metadata from inputCol is used instead.")
  // The default can only be registered once the param value above has been initialized.
  setDefault(labels, Array.empty[String])

  /**
   * Optional labels to be provided by the user, if not supplied column
   * metadata is read for labels.
   * @group getParam
   */
  final def getLabels: Array[String] = $(labels)

  /**
   * Transform the schema for the inverse transformation.
   *
   * Requires the input column to be numeric and the output column name to be unused.
   * The returned schema is the input schema with the output column appended as a string
   * column, matching what `transform` actually produces.
   */
  override def transformSchema(schema: StructType): StructType = {
    val inputColName = $(inputCol)
    val inputDataType = schema(inputColName).dataType
    require(inputDataType.isInstanceOf[NumericType],
      s"The input column $inputColName must be a numeric type, " +
        s"but got $inputDataType.")
    val inputFields = schema.fields
    val outputColName = $(outputCol)
    require(inputFields.forall(_.name != outputColName),
      s"Output column $outputColName already exists.")
    // The output holds the recovered labels, so it is a plain string column rather than a
    // (double-typed) nominal attribute column.
    val outputField = org.apache.spark.sql.types.StructField(outputColName, StringType)
    StructType(inputFields :+ outputField)
  }

  /**
   * Appends the output column containing the label for each index in the input column.
   *
   * Throws a SparkException up front when no labels were set and the input column metadata
   * does not provide them; the generated UDF throws a SparkException for any index outside
   * the range of known labels.
   */
  override def transform(dataset: DataFrame): DataFrame = {
    val inputColName = $(inputCol)
    val inputColSchema = dataset.schema(inputColName)
    // If the labels array is empty use column metadata
    val values: Array[String] = if ($(labels).isEmpty) {
      Attribute.fromStructField(inputColSchema) match {
        case nominal: NominalAttribute =>
          nominal.values.getOrElse(throw new SparkException(
            s"No labels were set and the nominal attribute of input column $inputColName " +
              "does not define any values."))
        case _ =>
          throw new SparkException(
            s"No labels were set and input column $inputColName does not carry " +
              "nominal attribute metadata to read the labels from.")
      }
    } else {
      $(labels)
    }
    val indexer = udf { index: Double =>
      val idx = index.toInt
      if (0 <= idx && idx < values.length) {
        values(idx)
      } else {
        throw new SparkException(s"Unseen index: $index ??")
      }
    }
    val outputColName = $(outputCol)
    // The input is cast to double so that any numeric input type can be fed to the UDF.
    dataset.select(col("*"),
      indexer(dataset(inputColName).cast(DoubleType)).as(outputColName))
  }

  override def copy(extra: ParamMap): StringIndexerInverse = {
    defaultCopy(extra)
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,19 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
// a -> 0, b -> 2, c -> 1
val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0))
assert(output === expected)
// Invert the transform and check that the original labels are recovered
val reversed = indexer.invert("labelIndex", "label2")
.transform(transformed)
.select("id", "label2")
assert(df.collect().map(r => (r.getInt(0), r.getString(1))).toSet ===
reversed.collect().map(r => (r.getInt(0), r.getString(1))).toSet)
// Check invert using only metadata
val inverse2 = new StringIndexerInverse()
.setInputCol("labelIndex")
.setOutputCol("label2")
val reversed2 = inverse2.transform(transformed).select("id", "label2")
assert(df.collect().map(r => (r.getInt(0), r.getString(1))).toSet ===
reversed2.collect().map(r => (r.getInt(0), r.getString(1))).toSet)
}

test("StringIndexer with a numeric input column") {
Expand Down

0 comments on commit 6503897

Please sign in to comment.