Skip to content

Commit

Permalink
[SPARK-7426] [MLLIB] [ML] Updated Attribute.fromStructField to allow …
Browse files Browse the repository at this point in the history
…any NumericType.

Updated `Attribute.fromStructField` to allow any `NumericType`, rather than just `DoubleType`, and added unit tests for a few of the other NumericTypes.

Author: Mike Dusenberry <dusenberrymw@gmail.com>

Closes #6540 from dusenberrymw/SPARK-7426_AttributeFactory.fromStructField_Should_Allow_NumericTypes and squashes the following commits:

87fecb3 [Mike Dusenberry] Updated Attribute.fromStructField to allow any NumericType, rather than just DoubleType, and added unit tests for a few of the other NumericTypes.
  • Loading branch information
dusenberrymw authored and jkbradley committed Jun 22, 2015
1 parent a189442 commit 47c1d56
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.ml.attribute
import scala.annotation.varargs

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.types.{DoubleType, Metadata, MetadataBuilder, StructField}
import org.apache.spark.sql.types.{DoubleType, NumericType, Metadata, MetadataBuilder, StructField}

/**
* :: DeveloperApi ::
Expand Down Expand Up @@ -127,7 +127,7 @@ private[attribute] trait AttributeFactory {
* Creates an [[Attribute]] from a [[StructField]] instance.
*/
def fromStructField(field: StructField): Attribute = {
require(field.dataType == DoubleType)
require(field.dataType.isInstanceOf[NumericType])
val metadata = field.metadata
val mlAttr = AttributeKeys.ML_ATTR
if (metadata.contains(mlAttr)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,5 +215,10 @@ class AttributeSuite extends SparkFunSuite {
assert(Attribute.fromStructField(fldWithoutMeta) == UnresolvedAttribute)
val fldWithMeta = new StructField("x", DoubleType, false, metadata)
assert(Attribute.fromStructField(fldWithMeta).isNumeric)
// Attribute.fromStructField should accept any NumericType, not just DoubleType
val longFldWithMeta = new StructField("x", LongType, false, metadata)
assert(Attribute.fromStructField(longFldWithMeta).isNumeric)
val decimalFldWithMeta = new StructField("x", DecimalType(None), false, metadata)
assert(Attribute.fromStructField(decimalFldWithMeta).isNumeric)
}
}

0 comments on commit 47c1d56

Please sign in to comment.