From b9d8da42d0a00e6b34de340b7223f9140d7f04f6 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 12 Aug 2015 16:39:17 -0700 Subject: [PATCH 1/3] rename StringIndexerInverse to IndexToString --- .../apache/spark/ml/feature/StringIndexer.scala | 17 +++++++---------- .../spark/ml/feature/StringIndexerSuite.scala | 6 ++++-- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index e4485eb038409..49c8ad775dd6f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -172,14 +172,11 @@ class StringIndexerModel private[ml] ( } /** - * Return a model to perform the inverse transformation. - * Note: By default we keep the original columns during this transformation, so the inverse - * should only be used on new columns such as predicted labels. + * Return a [[IndexToString]] instance to perform the inverse transformation, mapping indices back + * to their string values. */ - def invert(inputCol: String, outputCol: String): StringIndexerInverse = { - new StringIndexerInverse() - .setInputCol(inputCol) - .setOutputCol(outputCol) + def inverse: IndexToString = { + new IndexToString() .setLabels(labels) } } @@ -192,12 +189,12 @@ class StringIndexerModel private[ml] ( * so the inverse should only be used on new columns such as predicted labels. 
*/ @Experimental -class StringIndexerInverse private[ml] ( +class IndexToString private[ml] ( override val uid: String) extends Transformer with HasInputCol with HasOutputCol { def this() = - this(Identifiable.randomUID("strIdxInv")) + this(Identifiable.randomUID("idxToStr")) /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -268,7 +265,7 @@ class StringIndexerInverse private[ml] ( indexer(dataset($(inputCol)).cast(DoubleType)).as(outputColName)) } - override def copy(extra: ParamMap): StringIndexerInverse = { + override def copy(extra: ParamMap): IndexToString = { defaultCopy(extra) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index b111036087e6a..b157b68c45656 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -49,13 +49,15 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) assert(output === expected) // convert reverse our transform - val reversed = indexer.invert("labelIndex", "label2") + val reversed = indexer.inverse + .setInputCol("labelIndex") + .setOutputCol("label2") .transform(transformed) .select("id", "label2") assert(df.collect().map(r => (r.getInt(0), r.getString(1))).toSet === reversed.collect().map(r => (r.getInt(0), r.getString(1))).toSet) // Check invert using only metadata - val inverse2 = new StringIndexerInverse() + val inverse2 = new IndexToString() .setInputCol("labelIndex") .setOutputCol("label2") val reversed2 = inverse2.transform(transformed).select("id", "label2") From f602dff7bba82c830f5695a20e2bd83db893a54e Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 13 Aug 2015 08:28:02 -0700 Subject: [PATCH 2/3] update doc --- 
.../apache/spark/ml/feature/StringIndexer.scala | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 49c8ad775dd6f..05ac5f69218dd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -172,8 +172,9 @@ class StringIndexerModel private[ml] ( } /** - * Return a [[IndexToString]] instance to perform the inverse transformation, mapping indices back - * to their string values. + * Returns an [[IndexToString]] transformer that can perform the inverse transformation, mapping + * indices back to their string values. + * Users need to set input/output column names on the transformer returned. */ def inverse: IndexToString = { new IndexToString() @@ -183,10 +184,12 @@ class StringIndexerModel private[ml] ( /** * :: Experimental :: - * Transform a provided column back to the original input types using either the metadata - * on the input column, or if provided using the labels supplied by the user. - * Note: By default we keep the original columns during this transformation, - * so the inverse should only be used on new columns such as predicted labels. + * A [[Transformer]] that maps a column of string indices back to a new column of corresponding + * string values using either the ML attributes of the input column, or if provided using the labels + * supplied by the user. + * All original columns are kept during transformation. 
+ * + * @see [[StringIndexer]] */ @Experimental class IndexToString private[ml] ( @@ -254,7 +257,7 @@ class IndexToString private[ml] ( } val indexer = udf { index: Double => val idx = index.toInt - if (0 <= idx && idx < values.size) { + if (0 <= idx && idx < values.length) { values(idx) } else { throw new SparkException(s"Unseen index: $index ??") From 70bbad431d575a1179f910e296a7c42702a7e39c Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 13 Aug 2015 14:22:51 -0700 Subject: [PATCH 3/3] remove inverse and update tests --- .../spark/ml/feature/StringIndexer.scala | 16 ++---- .../spark/ml/feature/StringIndexerSuite.scala | 52 +++++++++++++------ 2 files changed, 39 insertions(+), 29 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 05ac5f69218dd..25e6d5827b06d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.Transformer -import org.apache.spark.ml.util.{Identifiable, MetadataUtils} +import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, NumericType, StringType, StructType} @@ -59,6 +59,8 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha * If the input column is numeric, we cast it to string and index the string values. * The indices are in [0, numLabels), ordered by label frequencies. * So the most frequent label gets index 0. 
+ * + * @see [[IndexToString]] for the inverse transformation */ @Experimental class StringIndexer(override val uid: String) extends Estimator[StringIndexerModel] @@ -170,16 +172,6 @@ class StringIndexerModel private[ml] ( val copied = new StringIndexerModel(uid, labels) copyValues(copied, extra) } - - /** - * Returns an [[IndexToString]] transformer that can perform the inverse transformation, mapping - * indices back to their string values. - * Users need to set input/output column names on the transformer returned. - */ - def inverse: IndexToString = { - new IndexToString() - .setLabels(labels) - } } /** @@ -189,7 +181,7 @@ class StringIndexerModel private[ml] ( * supplied by the user. * All original columns are kept during transformation. * - * @see [[StringIndexer]] + * @see [[StringIndexer]] for converting strings into indices */ @Experimental class IndexToString private[ml] ( diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index b157b68c45656..4cff89e437860 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -17,11 +17,12 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkException -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Row +import org.apache.spark.sql.functions.col class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -48,21 +49,6 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { // a -> 0, b -> 2, c -> 1 val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) assert(output === 
expected) - // convert reverse our transform - val reversed = indexer.inverse - .setInputCol("labelIndex") - .setOutputCol("label2") - .transform(transformed) - .select("id", "label2") - assert(df.collect().map(r => (r.getInt(0), r.getString(1))).toSet === - reversed.collect().map(r => (r.getInt(0), r.getString(1))).toSet) - // Check invert using only metadata - val inverse2 = new IndexToString() - .setInputCol("labelIndex") - .setOutputCol("label2") - val reversed2 = inverse2.transform(transformed).select("id", "label2") - assert(df.collect().map(r => (r.getInt(0), r.getString(1))).toSet === - reversed2.collect().map(r => (r.getInt(0), r.getString(1))).toSet) } test("StringIndexerUnseen") { @@ -122,4 +108,36 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { val df = sqlContext.range(0L, 10L) assert(indexerModel.transform(df).eq(df)) } + + test("IndexToString params") { + val idxToStr = new IndexToString() + ParamsSuite.checkParams(idxToStr) + } + + test("IndexToString.transform") { + val labels = Array("a", "b", "c") + val df0 = sqlContext.createDataFrame(Seq( + (0, "a"), (1, "b"), (2, "c"), (0, "a") + )).toDF("index", "expected") + + val idxToStr0 = new IndexToString() + .setInputCol("index") + .setOutputCol("actual") + .setLabels(labels) + idxToStr0.transform(df0).select("actual", "expected").collect().foreach { + case Row(actual, expected) => + assert(actual === expected) + } + + val attr = NominalAttribute.defaultAttr.withValues(labels) + val df1 = df0.select(col("index").as("indexWithAttr", attr.toMetadata()), col("expected")) + + val idxToStr1 = new IndexToString() + .setInputCol("indexWithAttr") + .setOutputCol("actual") + idxToStr1.transform(df1).select("actual", "expected").collect().foreach { + case Row(actual, expected) => + assert(actual === expected) + } + } }