-
Notifications
You must be signed in to change notification settings - Fork 28.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-7446][MLLIB] Add inverse transform for string indexer #6339
Changes from all commits
bb16a6a
78b28c1
88779c1
557bef8
c3fdee1
e8bf3ad
d4bcb20
16cc3c3
a42d773
ecc65e0
03ebf95
ebed932
f3e0c64
5aa38bf
868b1a9
7464019
8450d0b
71e8d66
1987b95
a60c0e3
4f06c59
64dd3a3
21c8cfa
9e241d8
6a38edb
b9cffb6
7cdf915
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,13 +20,14 @@ package org.apache.spark.ml.feature | |
import org.apache.spark.SparkException | ||
import org.apache.spark.annotation.Experimental | ||
import org.apache.spark.ml.{Estimator, Model} | ||
import org.apache.spark.ml.attribute.NominalAttribute | ||
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} | ||
import org.apache.spark.ml.param._ | ||
import org.apache.spark.ml.param.shared._ | ||
import org.apache.spark.ml.util.Identifiable | ||
import org.apache.spark.ml.Transformer | ||
import org.apache.spark.ml.util.{Identifiable, MetadataUtils} | ||
import org.apache.spark.sql.DataFrame | ||
import org.apache.spark.sql.functions._ | ||
import org.apache.spark.sql.types.{NumericType, StringType, StructType} | ||
import org.apache.spark.sql.types.{DoubleType, NumericType, StringType, StructType} | ||
import org.apache.spark.util.collection.OpenHashMap | ||
|
||
/** | ||
|
@@ -151,4 +152,105 @@ class StringIndexerModel private[ml] ( | |
val copied = new StringIndexerModel(uid, labels) | ||
copyValues(copied, extra) | ||
} | ||
|
||
/** | ||
* Return a model to perform the inverse transformation. | ||
* Note: By default we keep the original columns during this transformation, so the inverse | ||
* should only be used on new columns such as predicted labels. | ||
*/ | ||
def invert(inputCol: String, outputCol: String): StringIndexerInverse = { | ||
new StringIndexerInverse() | ||
.setInputCol(inputCol) | ||
.setOutputCol(outputCol) | ||
.setLabels(labels) | ||
} | ||
} | ||
|
||
/** | ||
* :: Experimental :: | ||
 * Transform a provided column back to the original input types using either the metadata | ||
 * on the input column or, if provided, the labels supplied by the user. | ||
* Note: By default we keep the original columns during this transformation, | ||
* so the inverse should only be used on new columns such as predicted labels. | ||
*/ | ||
@Experimental | ||
class StringIndexerInverse private[ml] ( | ||
override val uid: String) extends Transformer | ||
with HasInputCol with HasOutputCol { | ||
|
||
def this() = | ||
this(Identifiable.randomUID("strIdxInv")) | ||
|
||
/** @group setParam */ | ||
def setInputCol(value: String): this.type = set(inputCol, value) | ||
|
||
/** @group setParam */ | ||
def setOutputCol(value: String): this.type = set(outputCol, value) | ||
|
||
/** | ||
 * Optional labels to be provided by the user; if not supplied, column | ||
 * metadata is read for labels. The default value is an empty array, | ||
 * but the empty array is ignored and column metadata is used instead. | ||
* @group setParam | ||
*/ | ||
def setLabels(value: Array[String]): this.type = set(labels, value) | ||
|
||
/** | ||
* Param for array of labels. | ||
* Optional labels to be provided by the user, if not supplied column | ||
* metadata is read for labels. | ||
* @group param | ||
*/ | ||
final val labels: StringArrayParam = new StringArrayParam(this, "labels", | ||
"array of labels, if not provided metadata from inputCol is used instead.") | ||
setDefault(labels, Array.empty[String]) | ||
|
||
/** | ||
* Optional labels to be provided by the user, if not supplied column | ||
* metadata is read for labels. | ||
* @group getParam | ||
*/ | ||
final def getLabels: Array[String] = $(labels) | ||
|
||
/** Transform the schema for the inverse transformation */ | ||
override def transformSchema(schema: StructType): StructType = { | ||
val inputColName = $(inputCol) | ||
val inputDataType = schema(inputColName).dataType | ||
require(inputDataType.isInstanceOf[NumericType], | ||
s"The input column $inputColName must be a numeric type, " + | ||
s"but got $inputDataType.") | ||
val inputFields = schema.fields | ||
val outputColName = $(outputCol) | ||
require(inputFields.forall(_.name != outputColName), | ||
s"Output column $outputColName already exists.") | ||
val attr = NominalAttribute.defaultAttr.withName($(outputCol)) | ||
val outputFields = inputFields :+ attr.toStructField() | ||
StructType(outputFields) | ||
} | ||
|
||
override def transform(dataset: DataFrame): DataFrame = { | ||
val inputColSchema = dataset.schema($(inputCol)) | ||
// If the labels array is empty use column metadata | ||
val values = if ($(labels).isEmpty) { | ||
Attribute.fromStructField(inputColSchema) | ||
.asInstanceOf[NominalAttribute].values.get | ||
} else { | ||
$(labels) | ||
} | ||
val indexer = udf { index: Double => | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can all NumericTypes be cast to Double like this? Either test this in a unit test, or just switch to supporting DoubleType only for now. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't ever explicitly ask for it as a NumericType, is there something I'm missing here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The UDF method take a Double, and I'm wondering if it will work if the input data type is something odd like FractionalType. But now that you have the cast to DoubleType, I bet it's fine. |
||
val idx = index.toInt | ||
if (0 <= idx && idx < values.size) { | ||
values(idx) | ||
} else { | ||
throw new SparkException(s"Unseen index: $index ??") | ||
} | ||
} | ||
val outputColName = $(outputCol) | ||
dataset.select(col("*"), | ||
indexer(dataset($(inputCol)).cast(DoubleType)).as(outputColName)) | ||
} | ||
|
||
override def copy(extra: ParamMap): StringIndexerInverse = { | ||
defaultCopy(extra) | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -47,6 +47,19 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { | |
// a -> 0, b -> 2, c -> 1 | ||
val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) | ||
assert(output === expected) | ||
// invert our transform to convert the indices back to labels | ||
val reversed = indexer.invert("labelIndex", "label2") | ||
.transform(transformed) | ||
.select("id", "label2") | ||
assert(df.collect().map(r => (r.getInt(0), r.getString(1))).toSet === | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It'd be nice to check more carefully by zipping the original label column with the newly created one, and checking for equality. |
||
reversed.collect().map(r => (r.getInt(0), r.getString(1))).toSet) | ||
// Check invert using only metadata | ||
val inverse2 = new StringIndexerInverse() | ||
.setInputCol("labelIndex") | ||
.setOutputCol("label2") | ||
val reversed2 = inverse2.transform(transformed).select("id", "label2") | ||
assert(df.collect().map(r => (r.getInt(0), r.getString(1))).toSet === | ||
reversed2.collect().map(r => (r.getInt(0), r.getString(1))).toSet) | ||
} | ||
|
||
test("StringIndexer with a numeric input column") { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Copy doc here explaining that, if labels are not given, they will be taken from the inputCol metadata.
Same for built-in Param doc.