Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-5956][MLLIB] Pipeline components should be copyable. #5820

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import org.apache.spark.ml.classification.ClassificationModel;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.Params$;
import org.apache.spark.mllib.linalg.BLAS;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
Expand Down Expand Up @@ -129,16 +128,16 @@ MyJavaLogisticRegression setMaxIter(int value) {

// This method is used by fit().
// In Java, we have to make it public since Java does not understand Scala's protected modifier.
public MyJavaLogisticRegressionModel train(DataFrame dataset, ParamMap paramMap) {
public MyJavaLogisticRegressionModel train(DataFrame dataset) {
// Extract columns from data using helper method.
JavaRDD<LabeledPoint> oldDataset = extractLabeledPoints(dataset, paramMap).toJavaRDD();
JavaRDD<LabeledPoint> oldDataset = extractLabeledPoints(dataset).toJavaRDD();

// Do learning to estimate the weight vector.
int numFeatures = oldDataset.take(1).get(0).features().size();
Vector weights = Vectors.zeros(numFeatures); // Learning would happen here.

// Create a model, and return it.
return new MyJavaLogisticRegressionModel(this, paramMap, weights);
return new MyJavaLogisticRegressionModel(this, weights);
}
}

Expand All @@ -155,18 +154,11 @@ class MyJavaLogisticRegressionModel
private MyJavaLogisticRegression parent_;
public MyJavaLogisticRegression parent() { return parent_; }

private ParamMap fittingParamMap_;
public ParamMap fittingParamMap() { return fittingParamMap_; }

private Vector weights_;
public Vector weights() { return weights_; }

public MyJavaLogisticRegressionModel(
MyJavaLogisticRegression parent_,
ParamMap fittingParamMap_,
Vector weights_) {
public MyJavaLogisticRegressionModel(MyJavaLogisticRegression parent_, Vector weights_) {
this.parent_ = parent_;
this.fittingParamMap_ = fittingParamMap_;
this.weights_ = weights_;
}

Expand Down Expand Up @@ -210,10 +202,8 @@ public Vector predictRaw(Vector features) {
* In Java, we have to make this method public since Java does not understand Scala's protected
* modifier.
*/
public MyJavaLogisticRegressionModel copy() {
MyJavaLogisticRegressionModel m =
new MyJavaLogisticRegressionModel(parent_, fittingParamMap_, weights_);
Params$.MODULE$.inheritValues(this.extractParamMap(), this, m);
return m;
@Override
public MyJavaLogisticRegressionModel copy(ParamMap extra) {
  // Build a fresh model carrying the same parent and learned weights,
  // then merge this instance's param values (plus any in `extra`) into it.
  MyJavaLogisticRegressionModel newModel =
    new MyJavaLogisticRegressionModel(parent_, weights_);
  return copyValues(newModel, extra);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ public static void main(String[] args) {
// we can view the parameters it used during fit().
// This prints the parameter (name: value) pairs, where names are unique IDs for this
// LogisticRegression instance.
System.out.println("Model 1 was fit using parameters: " + model1.fittingParamMap());
System.out.println("Model 1 was fit using parameters: " + model1.parent().extractParamMap());

// We may alternatively specify parameters using a ParamMap.
ParamMap paramMap = new ParamMap();
Expand All @@ -87,7 +87,7 @@ public static void main(String[] args) {
// Now learn a new model using the paramMapCombined parameters.
// paramMapCombined overrides all parameters set earlier via lr.set* methods.
LogisticRegressionModel model2 = lr.fit(training, paramMapCombined);
System.out.println("Model 2 was fit using parameters: " + model2.fittingParamMap());
System.out.println("Model 2 was fit using parameters: " + model2.parent().extractParamMap());

// Prepare test documents.
List<LabeledPoint> localTest = Lists.newArrayList(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -276,16 +276,14 @@ object DecisionTreeExample {
// Get the trained Decision Tree from the fitted PipelineModel
algo match {
case "classification" =>
val treeModel = pipelineModel.getModel[DecisionTreeClassificationModel](
dt.asInstanceOf[DecisionTreeClassifier])
val treeModel = pipelineModel.stages.last.asInstanceOf[DecisionTreeClassificationModel]
if (treeModel.numNodes < 20) {
println(treeModel.toDebugString) // Print full model.
} else {
println(treeModel) // Print model summary.
}
case "regression" =>
val treeModel = pipelineModel.getModel[DecisionTreeRegressionModel](
dt.asInstanceOf[DecisionTreeRegressor])
val treeModel = pipelineModel.stages.last.asInstanceOf[DecisionTreeRegressionModel]
if (treeModel.numNodes < 20) {
println(treeModel.toDebugString) // Print full model.
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,12 @@
package org.apache.spark.examples.ml

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.classification.{Classifier, ClassifierParams, ClassificationModel}
import org.apache.spark.ml.param.{Params, IntParam, ParamMap}
import org.apache.spark.ml.classification.{ClassificationModel, Classifier, ClassifierParams}
import org.apache.spark.ml.param.{IntParam, ParamMap}
import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.{DataFrame, Row, SQLContext}


/**
* A simple example demonstrating how to write your own learning algorithm using Estimator,
* Transformer, and other abstractions.
Expand Down Expand Up @@ -99,7 +98,7 @@ private trait MyLogisticRegressionParams extends ClassifierParams {
* class since the maxIter parameter is only used during training (not in the Model).
*/
val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations")
def getMaxIter: Int = getOrDefault(maxIter)
def getMaxIter: Int = $(maxIter)
}

/**
Expand All @@ -117,18 +116,16 @@ private class MyLogisticRegression
def setMaxIter(value: Int): this.type = set(maxIter, value)

// This method is used by fit()
override protected def train(
dataset: DataFrame,
paramMap: ParamMap): MyLogisticRegressionModel = {
override protected def train(dataset: DataFrame): MyLogisticRegressionModel = {
// Extract columns from data using helper method.
val oldDataset = extractLabeledPoints(dataset, paramMap)
val oldDataset = extractLabeledPoints(dataset)

// Do learning to estimate the weight vector.
val numFeatures = oldDataset.take(1)(0).features.size
val weights = Vectors.zeros(numFeatures) // Learning would happen here.

// Create a model, and return it.
new MyLogisticRegressionModel(this, paramMap, weights)
new MyLogisticRegressionModel(this, weights)
}
}

Expand All @@ -139,7 +136,6 @@ private class MyLogisticRegression
*/
private class MyLogisticRegressionModel(
override val parent: MyLogisticRegression,
override val fittingParamMap: ParamMap,
val weights: Vector)
extends ClassificationModel[Vector, MyLogisticRegressionModel]
with MyLogisticRegressionParams {
Expand Down Expand Up @@ -176,9 +172,7 @@ private class MyLogisticRegressionModel(
*
* This is used for the default implementation of [[transform()]].
*/
override protected def copy(): MyLogisticRegressionModel = {
val m = new MyLogisticRegressionModel(parent, fittingParamMap, weights)
Params.inheritValues(extractParamMap(), this, m)
m
override def copy(extra: ParamMap): MyLogisticRegressionModel = {
  // Create a structurally identical model, then propagate this instance's
  // param values into it, applying any overrides supplied in `extra`.
  val copied = new MyLogisticRegressionModel(parent, weights)
  copyValues(copied, extra)
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think it will be more consistent to most of Java/Scala APIs by asking users to implement clone method, and we just have default copy method?

def copy(extra: ParamMap): MyLogisticRegressionModel = {
  copyValues(this.clone(), extra)       
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

}
Original file line number Diff line number Diff line change
Expand Up @@ -201,14 +201,14 @@ object GBTExample {
// Get the trained GBT from the fitted PipelineModel
algo match {
case "classification" =>
val rfModel = pipelineModel.getModel[GBTClassificationModel](dt.asInstanceOf[GBTClassifier])
val rfModel = pipelineModel.stages.last.asInstanceOf[GBTClassificationModel]
if (rfModel.totalNumNodes < 30) {
println(rfModel.toDebugString) // Print full model.
} else {
println(rfModel) // Print model summary.
}
case "regression" =>
val rfModel = pipelineModel.getModel[GBTRegressionModel](dt.asInstanceOf[GBTRegressor])
val rfModel = pipelineModel.stages.last.asInstanceOf[GBTRegressionModel]
if (rfModel.totalNumNodes < 30) {
println(rfModel.toDebugString) // Print full model.
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,16 +209,14 @@ object RandomForestExample {
// Get the trained Random Forest from the fitted PipelineModel
algo match {
case "classification" =>
val rfModel = pipelineModel.getModel[RandomForestClassificationModel](
dt.asInstanceOf[RandomForestClassifier])
val rfModel = pipelineModel.stages.last.asInstanceOf[RandomForestClassificationModel]
if (rfModel.totalNumNodes < 30) {
println(rfModel.toDebugString) // Print full model.
} else {
println(rfModel) // Print model summary.
}
case "regression" =>
val rfModel = pipelineModel.getModel[RandomForestRegressionModel](
dt.asInstanceOf[RandomForestRegressor])
val rfModel = pipelineModel.stages.last.asInstanceOf[RandomForestRegressionModel]
if (rfModel.totalNumNodes < 30) {
println(rfModel.toDebugString) // Print full model.
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ object SimpleParamsExample {
// we can view the parameters it used during fit().
// This prints the parameter (name: value) pairs, where names are unique IDs for this
// LogisticRegression instance.
println("Model 1 was fit using parameters: " + model1.fittingParamMap)
println("Model 1 was fit using parameters: " + model1.parent.extractParamMap())

// We may alternatively specify parameters using a ParamMap,
// which supports several methods for specifying parameters.
Expand All @@ -78,7 +78,7 @@ object SimpleParamsExample {
// Now learn a new model using the paramMapCombined parameters.
// paramMapCombined overrides all parameters set earlier via lr.set* methods.
val model2 = lr.fit(training.toDF(), paramMapCombined)
println("Model 2 was fit using parameters: " + model2.fittingParamMap)
println("Model 2 was fit using parameters: " + model2.parent.extractParamMap())

// Prepare test data.
val test = sc.parallelize(Seq(
Expand Down
26 changes: 20 additions & 6 deletions mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,16 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params {
* Fits a single model to the input data with optional parameters.
*
* @param dataset input dataset
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo. should be override, not overwrites

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

* @param paramPairs Optional list of param pairs.
* These values override any specified in this Estimator's embedded ParamMap.
* @param firstParamPair the first param pair, overrides embedded params
* @param otherParamPairs other param pairs. These values override any specified in this
* Estimator's embedded ParamMap.
* @return fitted model
*/
@varargs
def fit(dataset: DataFrame, paramPairs: ParamPair[_]*): M = {
val map = ParamMap(paramPairs: _*)
def fit(dataset: DataFrame, firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): M = {
val map = new ParamMap()
.put(firstParamPair)
.put(otherParamPairs: _*)
fit(dataset, map)
}

Expand All @@ -52,12 +55,19 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params {
* These values override any specified in this Estimator's embedded ParamMap.
* @return fitted model
*/
def fit(dataset: DataFrame, paramMap: ParamMap): M
def fit(dataset: DataFrame, paramMap: ParamMap): M = {
  // Apply the overrides on a copy of this estimator, then fit the copy,
  // leaving this instance's embedded params untouched.
  val estimatorWithOverrides = copy(paramMap)
  estimatorWithOverrides.fit(dataset)
}

/**
 * Fits a model to the input data.
 *
 * @param dataset input dataset
 * @return fitted model
 */
def fit(dataset: DataFrame): M

/**
* Fits multiple models to the input data with multiple sets of parameters.
* The default implementation uses a for loop on each parameter map.
* Subclasses could overwrite this to optimize multi-model training.
* Subclasses could override this to optimize multi-model training.
*
* @param dataset input dataset
* @param paramMaps An array of parameter maps.
Expand All @@ -67,4 +77,8 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params {
def fit(dataset: DataFrame, paramMaps: Array[ParamMap]): Seq[M] = {
  // Train one model per parameter map, sequentially, preserving input order.
  paramMaps.toSeq.map(oneMap => fit(dataset, oneMap))
}

override def copy(extra: ParamMap): Estimator[M] = {
  // Params.copy returns the base Params type; narrow it back to Estimator[M].
  val copied = super.copy(extra)
  copied.asInstanceOf[Estimator[M]]
}
}
20 changes: 16 additions & 4 deletions mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,15 @@
package org.apache.spark.ml

import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.param.{ParamMap, Params}
import org.apache.spark.sql.DataFrame

/**
* :: AlphaComponent ::
* Abstract class for evaluators that compute metrics from predictions.
*/
@AlphaComponent
abstract class Evaluator extends Identifiable {
abstract class Evaluator extends Params {

/**
* Evaluates the output.
Expand All @@ -36,5 +35,18 @@ abstract class Evaluator extends Identifiable {
* @param paramMap parameter map that specifies the input columns and output metrics
* @return metric
*/
def evaluate(dataset: DataFrame, paramMap: ParamMap): Double
def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = {
  // Evaluate on a copy carrying the param overrides, so this instance's
  // embedded params are not mutated.
  val evaluatorWithOverrides = this.copy(paramMap)
  evaluatorWithOverrides.evaluate(dataset)
}

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One question. Why Evaluator returns type Double? For example, confusion matrix can be an evaluator as well, and it can not be represented by type Double. Also, in our use-case, we often evaluate the models by the histogram of recommended popular titles, so it will be Array[Long]. Why don't we have type T here, and we specify the type in the implementation?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Evaluator is used for tuning model, where we need a scalar metric to select the best model. We can try returning a comparable T, but this is not part of this PR.

/**
 * Evaluates the output.
 *
 * @param dataset a dataset that contains labels/observations and predictions
 * @return metric
 */
def evaluate(dataset: DataFrame): Double

override def copy(extra: ParamMap): Evaluator = {
  // Params.copy returns the base Params type; narrow it back to Evaluator.
  val copied = super.copy(extra)
  copied.asInstanceOf[Evaluator]
}
}
9 changes: 4 additions & 5 deletions mllib/src/main/scala/org/apache/spark/ml/Model.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,8 @@ abstract class Model[M <: Model[M]] extends Transformer {
*/
val parent: Estimator[M]

/**
* Fitting parameters, such that parent.fit(..., fittingParamMap) could reproduce the model.
* Note: For ensembles' component Models, this value can be null.
*/
val fittingParamMap: ParamMap
override def copy(extra: ParamMap): M = {
  // The default implementation of Params.copy doesn't work for models, so
  // concrete Model subclasses must provide their own copy implementation.
  val message = s"${this.getClass} doesn't implement copy(extra: ParamMap)"
  throw new NotImplementedError(message)
}
}
Loading