From 65038973a17904e0e04d453799ec108af240fbab Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Sat, 1 Aug 2015 01:09:38 -0700 Subject: [PATCH] [SPARK-7446] [MLLIB] Add inverse transform for string indexer It is useful to convert the encoded indices back to their string representation for result inspection. We can add a function which creates an inverse transformation. Author: Holden Karau Closes #6339 from holdenk/SPARK-7446-inverse-transform-for-string-indexer and squashes the following commits: 7cdf915 [Holden Karau] scala style comment fix b9cffb6 [Holden Karau] Update the labels param to have the metadata note 6a38edb [Holden Karau] Setting the default needs to come after the value gets defined 9e241d8 [Holden Karau] use Array.empty 21c8cfa [Holden Karau] Merge branch 'master' into SPARK-7446-inverse-transform-for-string-indexer 64dd3a3 [Holden Karau] Merge branch 'master' into SPARK-7446-inverse-transform-for-string-indexer 4f06c59 [Holden Karau] Fix comment styles, use empty array as the default, etc. a60c0e3 [Holden Karau] CR feedback (remove old constructor, add a note about use of setLabels) 1987b95 [Holden Karau] Use default copy 71e8d66 [Holden Karau] Make labels a local param for StringIndexerInverse 8450d0b [Holden Karau] Use the labels param in StringIndexerInverse 7464019 [Holden Karau] Add a labels param 868b1a9 [Holden Karau] Update scaladoc since we don't have labelsCol anymore 5aa38bf [Holden Karau] Add an inverse test using only meta data, pass labels when calling inverse method f3e0c64 [Holden Karau] CR feedback ebed932 [Holden Karau] Add Experimental tag and some scaladocs. Also don't require that the inputCol has the metadata on it, instead have the labelsCol specified when creating the inverse. 
03ebf95 [Holden Karau] Add explicit type for invert function ecc65e0 [Holden Karau] Read the metadata correctly, use the array, pass the test a42d773 [Holden Karau] Fix test to supply cols as per new invert method 16cc3c3 [Holden Karau] Add an invert method d4bcb20 [Holden Karau] Make the inverse string indexer into a transformer (still needs test updates but compiles) e8bf3ad [Holden Karau] Merge branch 'master' into SPARK-7446-inverse-transform-for-string-indexer c3fdee1 [Holden Karau] Some WIP refactoring based on jkbradley's CR feedback. Definite work-in-progress 557bef8 [Holden Karau] Instead of using a private inverse transform, add an invert function so we can use it in a pipeline 88779c1 [Holden Karau] fix long line 78b28c1 [Holden Karau] Finish reverse part and add a test :) bb16a6a [Holden Karau] Some progress --- .../spark/ml/feature/StringIndexer.scala | 108 +++++++++++++++++- .../spark/ml/feature/StringIndexerSuite.scala | 13 +++ 2 files changed, 118 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index bf7be363b8224..ebfa972532358 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -20,13 +20,14 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkException import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.ml.attribute.NominalAttribute +import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ -import 
org.apache.spark.sql.types.{NumericType, StringType, StructType}
+import org.apache.spark.sql.types.{DoubleType, NumericType, StringType, StructField, StructType}
 import org.apache.spark.util.collection.OpenHashMap
 
 /**
@@ -151,4 +152,105 @@ class StringIndexerModel private[ml] (
     val copied = new StringIndexerModel(uid, labels)
     copyValues(copied, extra)
   }
+
+  /**
+   * Return a model to perform the inverse transformation.
+   * Note: By default we keep the original columns during this transformation, so the inverse
+   * should only be used on new columns such as predicted labels.
+   */
+  def invert(inputCol: String, outputCol: String): StringIndexerInverse = {
+    new StringIndexerInverse()
+      .setInputCol(inputCol)
+      .setOutputCol(outputCol)
+      .setLabels(labels)
+  }
+}
+
+/**
+ * :: Experimental ::
+ * Transform a provided column back to the original input types using either the metadata
+ * on the input column, or if provided using the labels supplied by the user.
+ * Note: By default we keep the original columns during this transformation,
+ * so the inverse should only be used on new columns such as predicted labels.
+ */
+@Experimental
+class StringIndexerInverse private[ml] (
+    override val uid: String) extends Transformer
+    with HasInputCol with HasOutputCol {
+
+  def this() =
+    this(Identifiable.randomUID("strIdxInv"))
+
+  /** @group setParam */
+  def setInputCol(value: String): this.type = set(inputCol, value)
+
+  /** @group setParam */
+  def setOutputCol(value: String): this.type = set(outputCol, value)
+
+  /**
+   * Optional labels to be provided by the user, if not supplied column
+   * metadata is read for labels. The default value is an empty array,
+   * but the empty array is ignored and column metadata used instead.
+   * @group setParam
+   */
+  def setLabels(value: Array[String]): this.type = set(labels, value)
+
+  /**
+   * Param for array of labels.
+   * Optional labels to be provided by the user, if not supplied column
+   * metadata is read for labels.
+   * @group param
+   */
+  final val labels: StringArrayParam = new StringArrayParam(this, "labels",
+    "array of labels, if not provided metadata from inputCol is used instead.")
+  setDefault(labels, Array.empty[String])
+
+  /**
+   * Optional labels to be provided by the user, if not supplied column
+   * metadata is read for labels.
+   * @group getParam
+   */
+  final def getLabels: Array[String] = $(labels)
+
+  /** Checks that inputCol is numeric and outputCol is new, then appends a string outputCol. */
+  override def transformSchema(schema: StructType): StructType = {
+    val inputColName = $(inputCol)
+    val inputDataType = schema(inputColName).dataType
+    require(inputDataType.isInstanceOf[NumericType],
+      s"The input column $inputColName must be a numeric type, " +
+        s"but got $inputDataType.")
+    val inputFields = schema.fields
+    val outputColName = $(outputCol)
+    require(inputFields.forall(_.name != outputColName),
+      s"Output column $outputColName already exists.")
+    // transform() emits the label strings, so the appended column must be declared as StringType.
+    val outputFields = inputFields :+ StructField(outputColName, StringType)
+    StructType(outputFields)
+  }
+
+  override def transform(dataset: DataFrame): DataFrame = {
+    val inputColSchema = dataset.schema($(inputCol))
+    // If the labels array is empty use column metadata
+    val values = if ($(labels).isEmpty) {
+      Attribute.fromStructField(inputColSchema)
+        .asInstanceOf[NominalAttribute].values.get
+    } else {
+      $(labels)
+    }
+    val indexer = udf { index: Double =>
+      val idx = index.toInt
+      if (0 <= idx && idx < values.size) {
+        values(idx)
+      } else {
+        throw new SparkException(s"Unseen index: $index ??")
+      }
+    }
+    val outputColName = $(outputCol)
+    dataset.select(col("*"),
+      indexer(dataset($(inputCol)).cast(DoubleType)).as(outputColName))
+  }
+
+  override def copy(extra: ParamMap): StringIndexerInverse = {
+    defaultCopy(extra)
+  }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 
99f82bea42688..d0295a0fe2fc1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -47,6 +47,19 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { // a -> 0, b -> 2, c -> 1 val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) assert(output === expected) + // convert reverse our transform + val reversed = indexer.invert("labelIndex", "label2") + .transform(transformed) + .select("id", "label2") + assert(df.collect().map(r => (r.getInt(0), r.getString(1))).toSet === + reversed.collect().map(r => (r.getInt(0), r.getString(1))).toSet) + // Check invert using only metadata + val inverse2 = new StringIndexerInverse() + .setInputCol("labelIndex") + .setOutputCol("label2") + val reversed2 = inverse2.transform(transformed).select("id", "label2") + assert(df.collect().map(r => (r.getInt(0), r.getString(1))).toSet === + reversed2.collect().map(r => (r.getInt(0), r.getString(1))).toSet) } test("StringIndexer with a numeric input column") {