Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-12346] [ML] Missing attribute names in GLM for vector-type features #10323

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -70,19 +70,19 @@ class VectorAssembler(override val uid: String)
val group = AttributeGroup.fromStructField(field)
if (group.attributes.isDefined) {
// If attributes are defined, copy them with updated names.
group.attributes.get.map { attr =>
group.attributes.get.zipWithIndex.map { case (attr, i) =>
if (attr.name.isDefined) {
// TODO: Define a rigorous naming scheme.
attr.withName(c + "_" + attr.name.get)
} else {
attr
attr.withName(c + "_" + i)
}
}
} else {
// Otherwise, treat all attributes as numeric. If we cannot get the number of attributes
// from metadata, check the first row.
val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size)
Array.fill(numAttrs)(NumericAttribute.defaultAttr)
Array.tabulate(numAttrs)(i => NumericAttribute.defaultAttr.withName(c + "_" + i))
}
case otherType =>
throw new SparkException(s"VectorAssembler does not support the $otherType type")
Expand Down
Expand Up @@ -143,6 +143,44 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(attrs === expectedAttrs)
}

test("vector attribute generation") {
val formula = new RFormula().setFormula("id ~ vec")
val original = sqlContext.createDataFrame(
Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we support terms of vector type in an R formula? I think that's illegal in R.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it makes sense when using RFormula in an ML pipeline (not necessarily in R).

).toDF("id", "vec")
val model = formula.fit(original)
val result = model.transform(original)
val attrs = AttributeGroup.fromStructField(result.schema("features"))
val expectedAttrs = new AttributeGroup(
"features",
Array[Attribute](
new NumericAttribute(Some("vec_0"), Some(1)),
new NumericAttribute(Some("vec_1"), Some(2))))
assert(attrs === expectedAttrs)
}

test("vector attribute generation with unnamed input attrs") {
// Covers the case where the vector column already carries attribute metadata, but the
// attributes in that metadata are the default (unnamed) numeric attributes.
val formula = new RFormula().setFormula("id ~ vec2")
val base = sqlContext.createDataFrame(
Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0)))
).toDF("id", "vec")
// Attribute group for the renamed column: two entries, one per vector element, both
// NumericAttribute.defaultAttr.
val metadata = new AttributeGroup(
"vec2",
Array[Attribute](
NumericAttribute.defaultAttr,
NumericAttribute.defaultAttr)).toMetadata
// Re-alias "vec" as "vec2" with the metadata attached, so the formula term "vec2" resolves
// to a column that has this attribute group.
val original = base.select(base.col("id"), base.col("vec").as("vec2", metadata))
val model = formula.fit(original)
val result = model.transform(original)
val attrs = AttributeGroup.fromStructField(result.schema("features"))
// The output attributes are expected to be named <column>_<position>: vec2_0 and vec2_1.
val expectedAttrs = new AttributeGroup(
"features",
Array[Attribute](
new NumericAttribute(Some("vec2_0"), Some(1)),
new NumericAttribute(Some("vec2_1"), Some(2))))
assert(attrs === expectedAttrs)
}

test("numeric interaction") {
val formula = new RFormula().setFormula("a ~ b:c:d")
val original = sqlContext.createDataFrame(
Expand Down
Expand Up @@ -111,8 +111,8 @@ class VectorAssemblerSuite
assert(userGenderOut === user.getAttr("gender").withName("user_gender").withIndex(3))
val userSalaryOut = features.getAttr(4)
assert(userSalaryOut === user.getAttr("salary").withName("user_salary").withIndex(4))
assert(features.getAttr(5) === NumericAttribute.defaultAttr.withIndex(5))
assert(features.getAttr(6) === NumericAttribute.defaultAttr.withIndex(6))
assert(features.getAttr(5) === NumericAttribute.defaultAttr.withIndex(5).withName("ad_0"))
assert(features.getAttr(6) === NumericAttribute.defaultAttr.withIndex(6).withName("ad_1"))
}

test("read/write") {
Expand Down