Skip to content

Commit

Permalink
Updated Attribute.fromStructField to allow any NumericType, rather th…
Browse files Browse the repository at this point in the history
…an just DoubleType, and added unit tests for a few of the other NumericTypes.
  • Loading branch information
dusenberrymw committed May 31, 2015
1 parent 4b5f12b commit 87fecb3
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.ml.attribute
import scala.annotation.varargs

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.types.{DoubleType, Metadata, MetadataBuilder, StructField}
import org.apache.spark.sql.types.{DoubleType, NumericType, Metadata, MetadataBuilder, StructField}

/**
* :: DeveloperApi ::
Expand Down Expand Up @@ -127,7 +127,7 @@ private[attribute] trait AttributeFactory {
* Creates an [[Attribute]] from a [[StructField]] instance.
*/
def fromStructField(field: StructField): Attribute = {
require(field.dataType == DoubleType)
require(field.dataType.isInstanceOf[NumericType])
val metadata = field.metadata
val mlAttr = AttributeKeys.ML_ATTR
if (metadata.contains(mlAttr)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,5 +215,10 @@ class AttributeSuite extends SparkFunSuite {
assert(Attribute.fromStructField(fldWithoutMeta) == UnresolvedAttribute)
val fldWithMeta = new StructField("x", DoubleType, false, metadata)
assert(Attribute.fromStructField(fldWithMeta).isNumeric)
// Attribute.fromStructField should accept any NumericType, not just DoubleType
val longFldWithMeta = new StructField("x", LongType, false, metadata)
assert(Attribute.fromStructField(longFldWithMeta).isNumeric)
val decimalFldWithMeta = new StructField("x", DecimalType(None), false, metadata)
assert(Attribute.fromStructField(decimalFldWithMeta).isNumeric)
}
}

0 comments on commit 87fecb3

Please sign in to comment.