[SPARK-7446][MLLIB] Add inverse transform for string indexer #6339

Closed
Changes from all commits (27)
bb16a6a
Some progress
holdenk May 21, 2015
78b28c1
Finish reverse part and add a test :)
holdenk May 22, 2015
88779c1
fix long line
holdenk May 22, 2015
557bef8
Instead of using a private inverse transform, add an invert function …
holdenk Jul 8, 2015
c3fdee1
Some WIP refactoring based on jkbradley's CR feedback. Definite work-…
holdenk Jul 28, 2015
e8bf3ad
Merge branch 'master' into SPARK-7446-inverse-transform-for-string-in…
holdenk Jul 28, 2015
d4bcb20
Make the inverse string indexer into a transformer (still needs test …
holdenk Jul 29, 2015
16cc3c3
Add an invert method
holdenk Jul 29, 2015
a42d773
Fix test to supply cols as per new invert method
holdenk Jul 29, 2015
ecc65e0
Read the metadata correctly, use the array, pass the test
holdenk Jul 29, 2015
03ebf95
Add explicit type for invert function
holdenk Jul 29, 2015
ebed932
Add Experimental tag and some scaladocs. Also don't require that the …
holdenk Jul 29, 2015
f3e0c64
CR feedback
holdenk Jul 29, 2015
5aa38bf
Add an inverse test using only meta data, pass labels when calling in…
holdenk Jul 29, 2015
868b1a9
Update scaladoc since we don't have labelsCol anymore
holdenk Jul 29, 2015
7464019
Add a labels param
holdenk Jul 29, 2015
8450d0b
Use the labels param in StringIndexerInverse
holdenk Jul 29, 2015
71e8d66
Make labels a local param for StringIndexerInverse
holdenk Jul 30, 2015
1987b95
Use default copy
holdenk Jul 30, 2015
a60c0e3
CR feedback (remove old constructor, add a note about use of setLabels)
holdenk Jul 30, 2015
4f06c59
Fix comment styles, use empty array as the default, etc.
holdenk Jul 31, 2015
64dd3a3
Merge branch 'master' into SPARK-7446-inverse-transform-for-string-in…
holdenk Jul 31, 2015
21c8cfa
Merge branch 'master' into SPARK-7446-inverse-transform-for-string-in…
holdenk Jul 31, 2015
9e241d8
use Array.empty
holdenk Jul 31, 2015
6a38edb
Setting the default needs to come after the value gets defined
holdenk Jul 31, 2015
b9cffb6
Update the labels param to have the metadata note
holdenk Jul 31, 2015
7cdf915
scala style comment fix
holdenk Jul 31, 2015
108 changes: 105 additions & 3 deletions mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -20,13 +20,14 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkException
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.{Estimator, Model}
-import org.apache.spark.ml.attribute.NominalAttribute
+import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.Transformer
+import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{NumericType, StringType, StructType}
+import org.apache.spark.sql.types.{DoubleType, NumericType, StringType, StructType}
import org.apache.spark.util.collection.OpenHashMap

/**
@@ -151,4 +152,105 @@ class StringIndexerModel private[ml] (
val copied = new StringIndexerModel(uid, labels)
copyValues(copied, extra)
}

/**
* Return a model to perform the inverse transformation.
* Note: By default we keep the original columns during this transformation, so the inverse
* should only be used on new columns such as predicted labels.
*/
def invert(inputCol: String, outputCol: String): StringIndexerInverse = {
new StringIndexerInverse()
.setInputCol(inputCol)
.setOutputCol(outputCol)
.setLabels(labels)
}
}
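
For illustration, a minimal usage sketch of the new invert method. The sqlContext, DataFrame, and column names below are assumptions for the example, not part of this diff:

    // Index a string column, then map the indices back to the original strings.
    val df = sqlContext.createDataFrame(
      Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"))).toDF("id", "label")
    val indexer = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("labelIndex")
    val model = indexer.fit(df)
    val indexed = model.transform(df)
    // Invert only the newly created column, since the original columns are kept.
    val restored = model.invert("labelIndex", "labelRestored").transform(indexed)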

/**
* :: Experimental ::
* Transform a provided column back to the original input types using either the metadata
* on the input column, or if provided using the labels supplied by the user.
* Note: By default we keep the original columns during this transformation,
* so the inverse should only be used on new columns such as predicted labels.
*/
@Experimental
class StringIndexerInverse private[ml] (
override val uid: String) extends Transformer
with HasInputCol with HasOutputCol {

def this() =
this(Identifiable.randomUID("strIdxInv"))

/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)

/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)

/**
* Optional labels to be provided by the user, if not supplied column
* metadata is read for labels. The default value is an empty array,
* but the empty array is ignored and column metadata used instead.
* @group setParam
*/
def setLabels(value: Array[String]): this.type = set(labels, value)

/**
* Param for array of labels.

[Review comment, Member]: Copy doc here explaining that, if labels are not given, they will be taken from the inputCol metadata. Same for built-in Param doc.

* Optional labels to be provided by the user, if not supplied column
* metadata is read for labels.
* @group param
*/
final val labels: StringArrayParam = new StringArrayParam(this, "labels",
"array of labels, if not provided metadata from inputCol is used instead.")
setDefault(labels, Array.empty[String])

/**
* Optional labels to be provided by the user, if not supplied column
* metadata is read for labels.
* @group getParam
*/
final def getLabels: Array[String] = $(labels)

/** Transform the schema for the inverse transformation */
override def transformSchema(schema: StructType): StructType = {
val inputColName = $(inputCol)
val inputDataType = schema(inputColName).dataType
require(inputDataType.isInstanceOf[NumericType],
s"The input column $inputColName must be a numeric type, " +
s"but got $inputDataType.")
val inputFields = schema.fields
val outputColName = $(outputCol)
require(inputFields.forall(_.name != outputColName),
s"Output column $outputColName already exists.")
val attr = NominalAttribute.defaultAttr.withName($(outputCol))
val outputFields = inputFields :+ attr.toStructField()
StructType(outputFields)
}

override def transform(dataset: DataFrame): DataFrame = {
val inputColSchema = dataset.schema($(inputCol))
// If the labels array is empty use column metadata
val values = if ($(labels).isEmpty) {
Attribute.fromStructField(inputColSchema)
.asInstanceOf[NominalAttribute].values.get
} else {
$(labels)
}
val indexer = udf { index: Double =>

[Review comment, Member]: Can all NumericTypes be cast to Double like this? Either test this in a unit test, or just switch to supporting DoubleType only for now.

[Reply, Contributor Author]: I don't ever explicitly ask for it as a NumericType, is there something I'm missing here?

[Reply, Member]: The UDF method takes a Double, and I'm wondering if it will work if the input data type is something odd like FractionalType. But now that you have the cast to DoubleType, I bet it's fine.

val idx = index.toInt
if (0 <= idx && idx < values.size) {
values(idx)
} else {
throw new SparkException(s"Unseen index: $index ??")
}
}
val outputColName = $(outputCol)
dataset.select(col("*"),
indexer(dataset($(inputCol)).cast(DoubleType)).as(outputColName))
}

override def copy(extra: ParamMap): StringIndexerInverse = {
defaultCopy(extra)
}
}
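
As a rough sketch (with an assumed predictions DataFrame and assumed column names), StringIndexerInverse can also be used on its own: a non-empty labels param is used directly, otherwise labels are read from the input column's NominalAttribute metadata.

    // Assumed: `predictions` holds a numeric "prediction" column of label indices.
    val inverter = new StringIndexerInverse()
      .setInputCol("prediction")
      .setOutputCol("predictedLabel")
      .setLabels(Array("a", "b", "c")) // omit to fall back to the column metadata
    val withLabels = inverter.transform(predictions)
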
@@ -47,6 +47,19 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
// a -> 0, b -> 2, c -> 1
val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0))
assert(output === expected)
// convert reverse our transform
val reversed = indexer.invert("labelIndex", "label2")
.transform(transformed)
.select("id", "label2")
assert(df.collect().map(r => (r.getInt(0), r.getString(1))).toSet ===

[Review comment, Member]: It'd be nice to check more carefully by zipping the original label column with the newly created one, and checking for equality. (See the sketch after this test.)

reversed.collect().map(r => (r.getInt(0), r.getString(1))).toSet)
// Check invert using only metadata
val inverse2 = new StringIndexerInverse()
.setInputCol("labelIndex")
.setOutputCol("label2")
val reversed2 = inverse2.transform(transformed).select("id", "label2")
assert(df.collect().map(r => (r.getInt(0), r.getString(1))).toSet ===
reversed2.collect().map(r => (r.getInt(0), r.getString(1))).toSet)
}
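
A rough sketch of the stricter check suggested above, comparing the original label column with the inverted one row by row (it assumes `transformed` still carries the original "label" column, as in this suite):

    val roundTrip = indexer.invert("labelIndex", "label2").transform(transformed)
    roundTrip.select("label", "label2").collect().foreach { row =>
      assert(row.getString(0) === row.getString(1))
    }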

test("StringIndexer with a numeric input column") {