[SPARK-29914][ML] ML models attach metadata in transform/`transformSchema`

### What changes were proposed in this pull request?
1. `predictionCol` in `ml.classification` & `ml.clustering` gains a `NominalAttribute`
2. `rawPredictionCol` in `ml.classification` gains an `AttributeGroup` with vector size = `numClasses`
3. `probabilityCol` in `ml.classification` & `ml.clustering` gains an `AttributeGroup` with vector size = `numClasses`/`k`
4. `leafCol` in GBT/RF gains an `AttributeGroup` with vector size = `numTrees`
5. `leafCol` in DecisionTree gains a `NominalAttribute`
6. `outputCol` in models in `ml.feature` gains an `AttributeGroup` with the inferred vector size
7. `outputCol` in `UnaryTransformer`s in `ml.feature` gains an `AttributeGroup` with the inferred vector size
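A minimal sketch of the pattern behind these items, runnable in `spark-shell` (the column names, dummy UDFs, and `numClasses` value are illustrative assumptions for this sketch, not code from the patch):

```scala
import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute}
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.sql.functions.{col, udf}

val numClasses = 2  // in the real code this comes from the fitted model

// Dummy scoring functions standing in for a model's predict/predictProbability.
val predictUDF = udf { (_: Vector) => 0.0 }
val probUDF = udf { (_: Vector) => Vectors.dense(0.9, 0.1) }

// Build attribute metadata, then pass it through the 3-arg
// withColumn(colName, col, metadata) overload that the patch relies on.
val predMeta = NominalAttribute.defaultAttr
  .withName("prediction").withNumValues(numClasses).toMetadata()
val probMeta = new AttributeGroup("probability", numClasses).toMetadata()

val df = Seq(Tuple1(Vectors.dense(1.0, 0.0))).toDF("features")
val out = df
  .withColumn("prediction", predictUDF(col("features")), predMeta)
  .withColumn("probability", probUDF(col("features")), probMeta)

out.schema.foreach(f => println(s"${f.name}: ${f.metadata}"))
```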

### Why are the changes needed?
Appended metadata can be used in downstream ops, such as `Classifier.getNumClasses` (see the sketch below).

Many implementations in `.ml` (such as `Binarizer`/`Bucketizer`/`VectorAssembler`/`OneHotEncoder`/`FeatureHasher`/`HashingTF`/`VectorSlicer`/...) already append appropriate metadata in their `transform`/`transformSchema` methods.

However, many other implementations return no metadata from transformation, even though metadata such as `vector.size`/`numAttrs`/`attrs` can be easily inferred.
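As a concrete (hypothetical) illustration of such a downstream read, runnable in `spark-shell`: once a model attaches metadata, sizes can be recovered from the schema alone, without scanning the data. `LogisticRegression` and the tiny dataset here are assumptions for the example:

```scala
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.linalg.Vectors

// Tiny two-class dataset; names and values are illustrative only.
val training = Seq(
  (Vectors.dense(0.0, 1.0), 0.0),
  (Vectors.dense(1.0, 0.0), 1.0)
).toDF("features", "label")

val model = new LogisticRegression().fit(training)
val predictions = model.transform(training)

// predictionCol now carries a NominalAttribute with numValues = numClasses ...
val predAttr = Attribute.fromStructField(predictions.schema(model.getPredictionCol))
println(predAttr.asInstanceOf[NominalAttribute].getNumValues)  // Some(2)

// ... and probabilityCol/rawPredictionCol carry AttributeGroups of size numClasses.
val probGroup = AttributeGroup.fromStructField(predictions.schema(model.getProbabilityCol))
println(probGroup.size)  // 2
```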

### Does this PR introduce any user-facing change?
Yes, some metadata is added to the transformed datasets.

### How was this patch tested?
Existing test suites and newly added test suites.

Closes #26547 from zhengruifeng/add_output_vecSize.

Authored-by: zhengruifeng <ruifengz@foxmail.com>
Signed-off-by: zhengruifeng <ruifengz@foxmail.com>
zhengruifeng committed Dec 4, 2019
1 parent 55132ae commit 710ddab
Showing 51 changed files with 593 additions and 105 deletions.
17 changes: 12 additions & 5 deletions mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -88,8 +88,9 @@ private[ml] trait PredictorParams extends Params
    * and put it in an RDD with strong types.
    * Validate the output instances with the given function.
    */
-  protected def extractInstances(dataset: Dataset[_],
-                                 validateInstance: Instance => Unit): RDD[Instance] = {
+  protected def extractInstances(
+      dataset: Dataset[_],
+      validateInstance: Instance => Unit): RDD[Instance] = {
     extractInstances(dataset).map { instance =>
       validateInstance(instance)
       instance
@@ -222,7 +223,11 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
   protected def featuresDataType: DataType = new VectorUDT

   override def transformSchema(schema: StructType): StructType = {
-    validateAndTransformSchema(schema, fitting = false, featuresDataType)
+    var outputSchema = validateAndTransformSchema(schema, fitting = false, featuresDataType)
+    if ($(predictionCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateNumeric(outputSchema, $(predictionCol))
+    }
+    outputSchema
   }

   /**
@@ -244,10 +249,12 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
   }

   protected def transformImpl(dataset: Dataset[_]): DataFrame = {
-    val predictUDF = udf { (features: Any) =>
+    val outputSchema = transformSchema(dataset.schema, logging = true)
+    val predictUDF = udf { features: Any =>
       predict(features.asInstanceOf[FeaturesType])
     }
-    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))),
+      outputSchema($(predictionCol)).metadata)
   }

   /**
5 changes: 3 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -117,9 +117,10 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
   }

   override def transform(dataset: Dataset[_]): DataFrame = {
-    transformSchema(dataset.schema, logging = true)
+    val outputSchema = transformSchema(dataset.schema, logging = true)
     val transformUDF = udf(this.createTransformFunc, outputDataType)
-    dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol))))
+    dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol))),
+      outputSchema($(outputCol)).metadata)
   }

   override def copy(extra: ParamMap): T = defaultCopy(extra)
mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -48,8 +48,9 @@ private[spark] trait ClassifierParams
    * and put it in an RDD with strong types.
    * Validates the label on the classifier is a valid integer in the range [0, numClasses).
    */
-  protected def extractInstances(dataset: Dataset[_],
-                                 numClasses: Int): RDD[Instance] = {
+  protected def extractInstances(
+      dataset: Dataset[_],
+      numClasses: Int): RDD[Instance] = {
     val validateInstance = (instance: Instance) => {
       val label = instance.label
       require(label.toLong == label && label >= 0 && label < numClasses, s"Classifier was given" +
@@ -183,6 +184,19 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
   /** Number of classes (values which the label can take). */
   def numClasses: Int

+  override def transformSchema(schema: StructType): StructType = {
+    var outputSchema = super.transformSchema(schema)
+    if ($(predictionCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateNumValues(schema,
+        $(predictionCol), numClasses)
+    }
+    if ($(rawPredictionCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
+        $(rawPredictionCol), numClasses)
+    }
+    outputSchema
+  }
+
   /**
    * Transforms dataset by reading from [[featuresCol]], and appending new columns as specified by
    * parameters:
@@ -193,29 +207,31 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
    * @return transformed dataset
    */
   override def transform(dataset: Dataset[_]): DataFrame = {
-    transformSchema(dataset.schema, logging = true)
+    val outputSchema = transformSchema(dataset.schema, logging = true)

     // Output selected columns only.
     // This is a bit complicated since it tries to avoid repeated computation.
     var outputData = dataset
     var numColsOutput = 0
     if (getRawPredictionCol != "") {
-      val predictRawUDF = udf { (features: Any) =>
+      val predictRawUDF = udf { features: Any =>
         predictRaw(features.asInstanceOf[FeaturesType])
       }
-      outputData = outputData.withColumn(getRawPredictionCol, predictRawUDF(col(getFeaturesCol)))
+      outputData = outputData.withColumn(getRawPredictionCol, predictRawUDF(col(getFeaturesCol)),
+        outputSchema($(rawPredictionCol)).metadata)
       numColsOutput += 1
     }
     if (getPredictionCol != "") {
-      val predUDF = if (getRawPredictionCol != "") {
+      val predCol = if (getRawPredictionCol != "") {
         udf(raw2prediction _).apply(col(getRawPredictionCol))
       } else {
-        val predictUDF = udf { (features: Any) =>
+        val predictUDF = udf { features: Any =>
           predict(features.asInstanceOf[FeaturesType])
         }
         predictUDF(col(getFeaturesCol))
       }
-      outputData = outputData.withColumn(getPredictionCol, predUDF)
+      outputData = outputData.withColumn(getPredictionCol, predCol,
+        outputSchema($(predictionCol)).metadata)
       numColsOutput += 1
     }

mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -36,6 +36,7 @@ import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeMo
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.types.StructType

 /**
  * Decision tree learning algorithm (http://en.wikipedia.org/wiki/Decision_tree_learning)
@@ -202,13 +203,23 @@ class DecisionTreeClassificationModel private[ml] (
     rootNode.predictImpl(features).prediction
   }

+  @Since("3.0.0")
+  override def transformSchema(schema: StructType): StructType = {
+    var outputSchema = super.transformSchema(schema)
+    if ($(leafCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateField(outputSchema, getLeafField($(leafCol)))
+    }
+    outputSchema
+  }
+
   override def transform(dataset: Dataset[_]): DataFrame = {
-    transformSchema(dataset.schema, logging = true)
+    val outputSchema = transformSchema(dataset.schema, logging = true)

     val outputData = super.transform(dataset)
     if ($(leafCol).nonEmpty) {
       val leafUDF = udf { features: Vector => predictLeaf(features) }
-      outputData.withColumn($(leafCol), leafUDF(col($(featuresCol))))
+      outputData.withColumn($(leafCol), leafUDF(col($(featuresCol))),
+        outputSchema($(leafCol)).metadata)
     } else {
       outputData
     }
mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -36,6 +36,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.StructType

 /**
  * Gradient-Boosted Trees (GBTs) (http://en.wikipedia.org/wiki/Gradient_boosting)
@@ -291,13 +292,23 @@ class GBTClassificationModel private[ml](
   @Since("1.4.0")
   override def treeWeights: Array[Double] = _treeWeights

+  @Since("1.6.0")
+  override def transformSchema(schema: StructType): StructType = {
+    var outputSchema = super.transformSchema(schema)
+    if ($(leafCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateField(outputSchema, getLeafField($(leafCol)))
+    }
+    outputSchema
+  }
+
   override def transform(dataset: Dataset[_]): DataFrame = {
-    transformSchema(dataset.schema, logging = true)
+    val outputSchema = transformSchema(dataset.schema, logging = true)

     val outputData = super.transform(dataset)
     if ($(leafCol).nonEmpty) {
       val leafUDF = udf { features: Vector => predictLeaf(features) }
-      outputData.withColumn($(leafCol), leafUDF(col($(featuresCol))))
+      outputData.withColumn($(leafCol), leafUDF(col($(featuresCol))),
+        outputSchema($(leafCol)).metadata)
     } else {
       outputData
     }
mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -161,13 +161,23 @@ final class OneVsRestModel private[ml] (

   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
-    validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType)
+    var outputSchema = validateAndTransformSchema(schema, fitting = false,
+      getClassifier.featuresDataType)
+    if ($(predictionCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateNumValues(outputSchema,
+        $(predictionCol), numClasses)
+    }
+    if ($(rawPredictionCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
+        $(rawPredictionCol), numClasses)
+    }
+    outputSchema
   }

   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     // Check schema
-    transformSchema(dataset.schema, logging = true)
+    val outputSchema = transformSchema(dataset.schema, logging = true)

     if (getPredictionCol.isEmpty && getRawPredictionCol.isEmpty) {
       logWarning(s"$uid: OneVsRestModel.transform() does nothing" +
@@ -230,6 +240,7 @@ final class OneVsRestModel private[ml] (

       predictionColNames :+= getRawPredictionCol
       predictionColumns :+= rawPredictionUDF(col(accColName))
+        .as($(rawPredictionCol), outputSchema($(rawPredictionCol)).metadata)
     }

     if (getPredictionCol.nonEmpty) {
mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
@@ -90,6 +90,15 @@ abstract class ProbabilisticClassificationModel[
     set(thresholds, value).asInstanceOf[M]
   }

+  override def transformSchema(schema: StructType): StructType = {
+    var outputSchema = super.transformSchema(schema)
+    if ($(probabilityCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
+        $(probabilityCol), numClasses)
+    }
+    outputSchema
+  }
+
   /**
    * Transforms dataset by reading from [[featuresCol]], and appending new columns as specified by
    * parameters:
@@ -101,7 +110,7 @@ abstract class ProbabilisticClassificationModel[
    * @return transformed dataset
    */
   override def transform(dataset: Dataset[_]): DataFrame = {
-    transformSchema(dataset.schema, logging = true)
+    val outputSchema = transformSchema(dataset.schema, logging = true)
     if (isDefined(thresholds)) {
       require($(thresholds).length == numClasses, this.getClass.getSimpleName +
         ".transform() called with non-matching numClasses and thresholds.length." +
@@ -113,36 +122,39 @@ abstract class ProbabilisticClassificationModel[
     var outputData = dataset
     var numColsOutput = 0
     if ($(rawPredictionCol).nonEmpty) {
-      val predictRawUDF = udf { (features: Any) =>
+      val predictRawUDF = udf { features: Any =>
         predictRaw(features.asInstanceOf[FeaturesType])
       }
-      outputData = outputData.withColumn(getRawPredictionCol, predictRawUDF(col(getFeaturesCol)))
+      outputData = outputData.withColumn(getRawPredictionCol, predictRawUDF(col(getFeaturesCol)),
+        outputSchema($(rawPredictionCol)).metadata)
       numColsOutput += 1
     }
     if ($(probabilityCol).nonEmpty) {
-      val probUDF = if ($(rawPredictionCol).nonEmpty) {
+      val probCol = if ($(rawPredictionCol).nonEmpty) {
         udf(raw2probability _).apply(col($(rawPredictionCol)))
       } else {
-        val probabilityUDF = udf { (features: Any) =>
+        val probabilityUDF = udf { features: Any =>
           predictProbability(features.asInstanceOf[FeaturesType])
         }
         probabilityUDF(col($(featuresCol)))
       }
-      outputData = outputData.withColumn($(probabilityCol), probUDF)
+      outputData = outputData.withColumn($(probabilityCol), probCol,
+        outputSchema($(probabilityCol)).metadata)
       numColsOutput += 1
     }
     if ($(predictionCol).nonEmpty) {
-      val predUDF = if ($(rawPredictionCol).nonEmpty) {
+      val predCol = if ($(rawPredictionCol).nonEmpty) {
         udf(raw2prediction _).apply(col($(rawPredictionCol)))
       } else if ($(probabilityCol).nonEmpty) {
         udf(probability2prediction _).apply(col($(probabilityCol)))
       } else {
-        val predictUDF = udf { (features: Any) =>
+        val predictUDF = udf { features: Any =>
           predict(features.asInstanceOf[FeaturesType])
         }
         predictUDF(col($(featuresCol)))
       }
-      outputData = outputData.withColumn($(predictionCol), predUDF)
+      outputData = outputData.withColumn($(predictionCol), predCol,
+        outputSchema($(predictionCol)).metadata)
       numColsOutput += 1
     }

mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -36,6 +36,7 @@ import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestMo
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.types.StructType

 /**
  * <a href="http://en.wikipedia.org/wiki/Random_forest">Random Forest</a> learning algorithm for
@@ -210,13 +211,23 @@ class RandomForestClassificationModel private[ml] (
   @Since("1.4.0")
   override def treeWeights: Array[Double] = _treeWeights

+  @Since("1.4.0")
+  override def transformSchema(schema: StructType): StructType = {
+    var outputSchema = super.transformSchema(schema)
+    if ($(leafCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateField(outputSchema, getLeafField($(leafCol)))
+    }
+    outputSchema
+  }
+
   override def transform(dataset: Dataset[_]): DataFrame = {
-    transformSchema(dataset.schema, logging = true)
+    val outputSchema = transformSchema(dataset.schema, logging = true)

     val outputData = super.transform(dataset)
     if ($(leafCol).nonEmpty) {
       val leafUDF = udf { features: Vector => predictLeaf(features) }
-      outputData.withColumn($(leafCol), leafUDF(col($(featuresCol))))
+      outputData.withColumn($(leafCol), leafUDF(col($(featuresCol))),
+        outputSchema($(leafCol)).metadata)
     } else {
       outputData
     }
mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -110,15 +110,21 @@ class BisectingKMeansModel private[ml] (

   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
-    transformSchema(dataset.schema, logging = true)
+    val outputSchema = transformSchema(dataset.schema, logging = true)
     val predictUDF = udf((vector: Vector) => predict(vector))
     dataset.withColumn($(predictionCol),
-      predictUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol)))
+      predictUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol)),
+      outputSchema($(predictionCol)).metadata)
   }

   @Since("2.0.0")
   override def transformSchema(schema: StructType): StructType = {
-    validateAndTransformSchema(schema)
+    var outputSchema = validateAndTransformSchema(schema)
+    if ($(predictionCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateNumValues(outputSchema,
+        $(predictionCol), parentModel.k)
+    }
+    outputSchema
   }

   @Since("3.0.0")
