[SPARK-39916][SQL][MLLIB][REFACTOR] Merge ml SchemaUtils to SQL #37336
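The diff below switches the ml callers from `org.apache.spark.ml.util.SchemaUtils` to `org.apache.spark.sql.util.SchemaUtils`, and moves the attribute-aware helpers (`updateNumeric`, `updateNumValues`, `updateAttributeGroupSize`, `validateVectorCompatibleColumn`) onto the matching `ml.attribute`/`ml.linalg` objects. The sql-side `updateField` itself is not part of this excerpt; the following is only a rough sketch of the replace-or-append behaviour implied by the new Scaladoc and by the `SchemaUtils.updateField(schema, field, true)` call sites — the parameter name and the exact meaning of the boolean are assumptions, not the merged implementation.

```scala
import org.apache.spark.sql.types.{StructField, StructType}

// Rough sketch (not the actual PR code): replace the field with the same name if it
// already exists, otherwise append it. The boolean mirrors the `true` passed at the
// ml call sites and is assumed to mean "overwrite the existing metadata".
def updateField(schema: StructType, field: StructField, overwriteMetadata: Boolean): StructType = {
  if (schema.fieldNames.contains(field.name)) {
    StructType(schema.fields.map { f =>
      if (f.name != field.name) {
        f
      } else if (overwriteMetadata) {
        field
      } else {
        field.copy(metadata = f.metadata)
      }
    })
  } else {
    StructType(schema.fields :+ field)
  }
}
```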

5 changes: 3 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -18,13 +18,14 @@
package org.apache.spark.ml

import org.apache.spark.annotation.Since
import org.apache.spark.ml.attribute.NumericAttribute
import org.apache.spark.ml.linalg.VectorUDT
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
import org.apache.spark.sql.util.SchemaUtils

/**
* (private[ml]) Trait for parameters for prediction (regression and classification).
@@ -175,7 +176,7 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
override def transformSchema(schema: StructType): StructType = {
var outputSchema = validateAndTransformSchema(schema, fitting = false, featuresDataType)
if ($(predictionCol).nonEmpty) {
outputSchema = SchemaUtils.updateNumeric(outputSchema, $(predictionCol))
outputSchema = NumericAttribute.updateNumeric(outputSchema, $(predictionCol))
}
outputSchema
}
@@ -20,7 +20,8 @@ package org.apache.spark.ml.attribute
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.ml.linalg.VectorUDT
import org.apache.spark.sql.types.{Metadata, MetadataBuilder, StructField}
import org.apache.spark.sql.types.{Metadata, MetadataBuilder, StructField, StructType}
import org.apache.spark.sql.util.SchemaUtils

/**
* Attributes that describe a vector ML column.
@@ -243,4 +244,21 @@ object AttributeGroup {
new AttributeGroup(field.name)
}
}

/**
* Update the size of an ML Vector column. If this column does not exist, append it.
* @param schema input schema
* @param colName column name
* @param size number of features
* @return new schema
*/
def updateAttributeGroupSize(
schema: StructType,
colName: String,
size: Int): StructType = {
require(size > 0)
val attrGroup = new AttributeGroup(colName, size)
val field = attrGroup.toStructField
SchemaUtils.updateField(schema, field, true)
}
}
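A minimal usage sketch of the relocated helper, assuming it stays visible to caller code after the move; the column name and size are made-up example values:

```scala
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.linalg.VectorUDT
import org.apache.spark.sql.types.StructType

val schema = new StructType().add("features", new VectorUDT)

// "rawPrediction" is not in the schema yet, so a vector column sized for 3 attributes is appended.
val updated = AttributeGroup.updateAttributeGroupSize(schema, "rawPrediction", 3)

// The size travels in the column metadata and can be read back as an AttributeGroup.
assert(AttributeGroup.fromStructField(updated("rawPrediction")).size == 3)
```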
@@ -19,7 +19,8 @@ package org.apache.spark.ml.attribute

import scala.annotation.varargs

import org.apache.spark.sql.types.{DoubleType, Metadata, MetadataBuilder, NumericType, StructField}
import org.apache.spark.sql.types.{DoubleType, Metadata, MetadataBuilder, NumericType, StructField, StructType}
import org.apache.spark.sql.util.SchemaUtils

/**
* Abstract class for ML attributes.
@@ -307,6 +308,21 @@ object NumericAttribute extends AttributeFactory {
val sparsity = if (metadata.contains(SPARSITY)) Some(metadata.getDouble(SPARSITY)) else None
new NumericAttribute(name, index, min, max, std, sparsity)
}

/**
* Update the numeric metadata of an existing column. If this column does not exist, append it.
* @param schema input schema
* @param colName column name
* @return new schema
*/
def updateNumeric(
schema: StructType,
colName: String): StructType = {
val attr = NumericAttribute.defaultAttr
.withName(colName)
val field = attr.toStructField
SchemaUtils.updateField(schema, field, true)
}
}
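The numeric variant is thinner: it stamps a default `NumericAttribute` onto the named column, appending a `DoubleType` field when the column is missing. An illustrative snippet (the column name is an example value):

```scala
import org.apache.spark.ml.attribute.NumericAttribute
import org.apache.spark.sql.types.{DoubleType, StructType}

// "prediction" is absent, so a DoubleType column carrying default numeric
// ML-attribute metadata is appended to the empty schema.
val updated = NumericAttribute.updateNumeric(new StructType(), "prediction")
assert(updated("prediction").dataType == DoubleType)
```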

/**
@@ -469,6 +485,24 @@ object NominalAttribute extends AttributeFactory {
if (metadata.contains(VALUES)) Some(metadata.getStringArray(VALUES)) else None
new NominalAttribute(name, index, isOrdinal, numValues, values)
}

/**
* Update the number of values of an existing column. If this column does not exist, append it.
* @param schema input schema
* @param colName column name
* @param numValues number of values.
* @return new schema
*/
def updateNumValues(
schema: StructType,
colName: String,
numValues: Int): StructType = {
val attr = NominalAttribute.defaultAttr
.withName(colName)
.withNumValues(numValues)
val field = attr.toStructField
SchemaUtils.updateField(schema, field, true)
}
}
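Chained together, these helpers reproduce the pattern used by callers such as `ClassificationModel.transformSchema` further down in this diff. A sketch of that call pattern with hard-coded column names (the real callers read them from params):

```scala
import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute}
import org.apache.spark.sql.types.StructType

def annotateClassifierOutput(schema: StructType, numClasses: Int): StructType = {
  // prediction is a nominal (label index) column with numClasses values
  var outputSchema = NominalAttribute.updateNumValues(schema, "prediction", numClasses)
  // rawPrediction is a vector column with one score per class
  outputSchema = AttributeGroup.updateAttributeGroupSize(outputSchema, "rawPrediction", numClasses)
  outputSchema
}
```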

/**
@@ -19,13 +19,14 @@ package org.apache.spark.ml.classification

import org.apache.spark.annotation.Since
import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams}
import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute}
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared.HasRawPredictionCol
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.util.SchemaUtils

/**
* (private[spark]) Params for classification.
@@ -81,11 +82,11 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
override def transformSchema(schema: StructType): StructType = {
var outputSchema = super.transformSchema(schema)
if ($(predictionCol).nonEmpty) {
outputSchema = SchemaUtils.updateNumValues(schema,
outputSchema = NominalAttribute.updateNumValues(schema,
$(predictionCol), numClasses)
}
if ($(rawPredictionCol).nonEmpty) {
outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
outputSchema = AttributeGroup.updateAttributeGroupSize(outputSchema,
$(rawPredictionCol), numClasses)
}
outputSchema
@@ -37,6 +37,7 @@ import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeMo
import org.apache.spark.sql._
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.SchemaUtils

/**
* Decision tree learning algorithm (http://en.wikipedia.org/wiki/Decision_tree_learning)
@@ -36,6 +36,7 @@ import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTMod
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.SchemaUtils

/**
* Gradient-Boosted Trees (GBTs) (http://en.wikipedia.org/wiki/Gradient_boosting)
@@ -164,11 +164,11 @@ final class OneVsRestModel private[ml] (
var outputSchema = validateAndTransformSchema(schema, fitting = false,
getClassifier.featuresDataType)
if ($(predictionCol).nonEmpty) {
outputSchema = SchemaUtils.updateNumValues(outputSchema,
outputSchema = NominalAttribute.updateNumValues(outputSchema,
$(predictionCol), numClasses)
}
if ($(rawPredictionCol).nonEmpty) {
outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
outputSchema = AttributeGroup.updateAttributeGroupSize(outputSchema,
$(rawPredictionCol), numClasses)
}
outputSchema
@@ -18,13 +18,14 @@
package org.apache.spark.ml.classification

import org.apache.spark.annotation.Since
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.linalg.{DenseVector, Vector, VectorUDT}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.util.SchemaUtils

/**
* (private[classification]) Params for probabilistic classification.
@@ -88,7 +89,7 @@ abstract class ProbabilisticClassificationModel[
override def transformSchema(schema: StructType): StructType = {
var outputSchema = super.transformSchema(schema)
if ($(probabilityCol).nonEmpty) {
outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
outputSchema = AttributeGroup.updateAttributeGroupSize(outputSchema,
$(probabilityCol), numClasses)
}
outputSchema
@@ -36,6 +36,7 @@ import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestMo
import org.apache.spark.sql._
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.SchemaUtils

/**
* <a href="http://en.wikipedia.org/wiki/Random_forest">Random Forest</a> learning algorithm for
@@ -21,19 +21,20 @@ import org.apache.hadoop.fs.Path

import org.apache.spark.annotation.Since
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.ml.util.DatasetUtils._
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.clustering.{BisectingKMeans => MLlibBisectingKMeans,
BisectingKMeansModel => MLlibBisectingKMeansModel}
import org.apache.spark.mllib.clustering.{BisectingKMeans => MLlibBisectingKMeans, BisectingKMeansModel => MLlibBisectingKMeansModel}
import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{IntegerType, StructType}
import org.apache.spark.sql.util.SchemaUtils
import org.apache.spark.storage.StorageLevel


@@ -79,7 +80,7 @@ private[clustering] trait BisectingKMeansParams extends Params with HasMaxIter
* @return output schema
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
SchemaUtils.validateVectorCompatibleColumn(schema, getFeaturesCol)
VectorUDT.validateVectorCompatibleColumn(schema, getFeaturesCol)
SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
}
}
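`validateVectorCompatibleColumn` now lives on `VectorUDT`, but its body is not shown in this excerpt. The sketch below only approximates the check — a features column must either already be an ML Vector or be a numeric array that the ml utilities can convert to one — and the real implementation and error message may differ:

```scala
import org.apache.spark.ml.linalg.VectorUDT
import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType, StructType}

// Approximate check only: accept an ML Vector column or a numeric array column.
def validateVectorCompatibleColumn(schema: StructType, colName: String): Unit = {
  val dataType = schema(colName).dataType
  val compatible = dataType match {
    case _: VectorUDT => true
    case ArrayType(DoubleType, _) | ArrayType(FloatType, _) => true
    case _ => false
  }
  require(compatible,
    s"Column $colName must be a Vector or an array of numeric values, but is $dataType.")
}
```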
@@ -126,7 +127,7 @@ class BisectingKMeansModel private[ml] (
override def transformSchema(schema: StructType): StructType = {
var outputSchema = validateAndTransformSchema(schema)
if ($(predictionCol).nonEmpty) {
outputSchema = SchemaUtils.updateNumValues(outputSchema,
outputSchema = NominalAttribute.updateNumValues(outputSchema,
$(predictionCol), parentModel.k)
}
outputSchema
@@ -151,7 +152,7 @@ class BisectingKMeansModel private[ml] (
"ClusteringEvaluator instead. You can also get the cost on the training dataset in the " +
"summary.", "3.0.0")
def computeCost(dataset: Dataset[_]): Double = {
SchemaUtils.validateVectorCompatibleColumn(dataset.schema, getFeaturesCol)
VectorUDT.validateVectorCompatibleColumn(dataset.schema, getFeaturesCol)
val data = columnToOldVector(dataset, getFeaturesCol)
parentModel.computeCost(data)
}
@@ -22,6 +22,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.Since
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute}
import org.apache.spark.ml.impl.Utils.{unpackUpperTriangular, EPSILON}
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param._
@@ -30,15 +31,14 @@ import org.apache.spark.ml.stat.distribution.MultivariateGaussian
import org.apache.spark.ml.util._
import org.apache.spark.ml.util.DatasetUtils._
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.linalg.{Matrices => OldMatrices, Matrix => OldMatrix,
Vector => OldVector, Vectors => OldVectors}
import org.apache.spark.mllib.linalg.{Matrices => OldMatrices, Matrix => OldMatrix, Vector => OldVector, Vectors => OldVectors}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SchemaUtils
import org.apache.spark.storage.StorageLevel


/**
* Common params for GaussianMixture and GaussianMixtureModel
*/
@@ -68,7 +68,7 @@ private[clustering] trait GaussianMixtureParams extends Params with HasMaxIter w
* @return output schema
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
SchemaUtils.validateVectorCompatibleColumn(schema, getFeaturesCol)
VectorUDT.validateVectorCompatibleColumn(schema, getFeaturesCol)
val schemaWithPredictionCol = SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
SchemaUtils.appendColumn(schemaWithPredictionCol, $(probabilityCol), new VectorUDT)
}
@@ -152,11 +152,11 @@ class GaussianMixtureModel private[ml] (
override def transformSchema(schema: StructType): StructType = {
var outputSchema = validateAndTransformSchema(schema)
if ($(predictionCol).nonEmpty) {
outputSchema = SchemaUtils.updateNumValues(outputSchema,
outputSchema = NominalAttribute.updateNumValues(outputSchema,
$(predictionCol), weights.length)
}
if ($(probabilityCol).nonEmpty) {
outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
outputSchema = AttributeGroup.updateAttributeGroupSize(outputSchema,
$(probabilityCol), weights.length)
}
outputSchema
@@ -23,6 +23,7 @@ import org.apache.hadoop.fs.Path

import org.apache.spark.annotation.Since
import org.apache.spark.ml.{Estimator, Model, PipelineStage}
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.feature.{Instance, InstanceBlock}
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param._
@@ -36,6 +37,7 @@ import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SchemaUtils
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.VersionUtils.majorVersion

@@ -118,7 +120,7 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
* @return output schema
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
SchemaUtils.validateVectorCompatibleColumn(schema, getFeaturesCol)
VectorUDT.validateVectorCompatibleColumn(schema, getFeaturesCol)
SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
}
}
@@ -167,7 +169,7 @@ class KMeansModel private[ml] (
override def transformSchema(schema: StructType): StructType = {
var outputSchema = validateAndTransformSchema(schema)
if ($(predictionCol).nonEmpty) {
outputSchema = SchemaUtils.updateNumValues(outputSchema,
outputSchema = NominalAttribute.updateNumValues(outputSchema,
$(predictionCol), parentModel.k)
}
outputSchema
11 changes: 5 additions & 6 deletions mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -29,17 +29,15 @@ import org.json4s.jackson.JsonMethods._
import org.apache.spark.annotation.Since
import org.apache.spark.internal.Logging
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasFeaturesCol, HasMaxIter, HasSeed}
import org.apache.spark.ml.util._
import org.apache.spark.ml.util.DatasetUtils._
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedLDAModel,
EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel,
LDAOptimizer => OldLDAOptimizer, LDAUtils => OldLDAUtils, LocalLDAModel => OldLocalLDAModel,
OnlineLDAOptimizer => OldOnlineLDAOptimizer}
import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedLDAModel, EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel, LDAOptimizer => OldLDAOptimizer, LDAUtils => OldLDAUtils, LocalLDAModel => OldLocalLDAModel, OnlineLDAOptimizer => OldOnlineLDAOptimizer}
import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
import org.apache.spark.mllib.linalg.MatrixImplicits._
import org.apache.spark.mllib.linalg.VectorImplicits._
@@ -48,6 +46,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.functions.{monotonically_increasing_id, udf}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.SchemaUtils
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.PeriodicCheckpointer
import org.apache.spark.util.VersionUtils
@@ -353,7 +352,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
s" must be >= 1. Found value: $getTopicConcentration")
}
}
SchemaUtils.validateVectorCompatibleColumn(schema, getFeaturesCol)
VectorUDT.validateVectorCompatibleColumn(schema, getFeaturesCol)
SchemaUtils.appendColumn(schema, $(topicDistributionCol), new VectorUDT)
}

@@ -511,7 +510,7 @@ abstract class LDAModel private[ml] (
override def transformSchema(schema: StructType): StructType = {
var outputSchema = validateAndTransformSchema(schema)
if ($(topicDistributionCol).nonEmpty) {
outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
outputSchema = AttributeGroup.updateAttributeGroupSize(outputSchema,
$(topicDistributionCol), oldLocalModel.k)
}
outputSchema
@@ -26,6 +26,7 @@ import org.apache.spark.mllib.clustering.{PowerIterationClustering => MLlibPower
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SchemaUtils

/**
* Common params for PowerIterationClustering