[SPARK-10573] [ML] IndexToString output schema should be StringType
Fixes a bug where the IndexToString output schema was DoubleType. Correct me if I'm wrong, but it doesn't seem like the output needs to carry any "ML Attribute" metadata.

Author: Nick Pritchard <nicholas.pritchard@falkonry.com>

Closes #8751 from pnpritchard/SPARK-10573.
pnpritchard authored and mengxr committed Sep 14, 2015
1 parent ce6f3f1 commit 8a634e9
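
For context, a minimal sketch (not part of this commit) of how the fix surfaces through transformSchema: after this change the output field is reported as StringType rather than DoubleType. The column names below are illustrative.

import org.apache.spark.ml.feature.IndexToString
import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}

// Illustrative column names; only the schema transformation is exercised here.
val idxToStr = new IndexToString()
  .setInputCol("categoryIndex")
  .setOutputCol("categoryLabel")

val inSchema = StructType(Seq(StructField("categoryIndex", DoubleType)))
val outSchema = idxToStr.transformSchema(inSchema)

// Before this commit the output field was built from NominalAttribute metadata
// and reported DoubleType; with the fix it is plain StringType, matching the
// string labels the transformer actually emits.
assert(outSchema("categoryLabel").dataType == StringType)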
Showing 2 changed files with 10 additions and 3 deletions.
@@ -27,7 +27,7 @@ import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{DoubleType, NumericType, StringType, StructType}
+import org.apache.spark.sql.types._

@jaceklaskowski (Contributor) commented on Sep 15, 2015:

I'm confused by this change, since the corresponding Suite in the same change had a 4-class import added. What's the proper coding style for imports?

@pnpritchard (Author, Contributor) replied on Sep 15, 2015:

This code change required importing one more class, org.apache.spark.sql.types.StructField, which would have brought this import line to five classes. I believe the style guide suggests replacing imports of five or more classes with a wildcard.
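
For illustration (this example is not part of the review thread), here are the two import forms being weighed under that guideline: the explicit list that would have grown to five classes, and the wildcard this commit adopts.

// Explicit form: five classes from the same package.
import org.apache.spark.sql.types.{DoubleType, NumericType, StringType, StructField, StructType}

// Wildcard form used in this commit, per the five-or-more threshold
// pnpritchard describes above.
import org.apache.spark.sql.types._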

 import org.apache.spark.util.collection.OpenHashMap

 /**
@@ -229,8 +229,7 @@ class IndexToString private[ml] (
     val outputColName = $(outputCol)
     require(inputFields.forall(_.name != outputColName),
       s"Output column $outputColName already exists.")
-    val attr = NominalAttribute.defaultAttr.withName($(outputCol))
-    val outputFields = inputFields :+ attr.toStructField()
+    val outputFields = inputFields :+ StructField($(outputCol), StringType)
     StructType(outputFields)
   }

@@ -17,6 +17,7 @@

 package org.apache.spark.ml.feature

+import org.apache.spark.sql.types.{StringType, StructType, StructField, DoubleType}
 import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
 import org.apache.spark.ml.param.ParamsSuite
@@ -165,4 +166,11 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
       assert(a === b)
     }
   }
+
+  test("IndexToString.transformSchema (SPARK-10573)") {
+    val idxToStr = new IndexToString().setInputCol("input").setOutputCol("output")
+    val inSchema = StructType(Seq(StructField("input", DoubleType)))
+    val outSchema = idxToStr.transformSchema(inSchema)
+    assert(outSchema("output").dataType === StringType)
+  }
 }
