-
Notifications
You must be signed in to change notification settings - Fork 28.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-5956][MLLIB] Pipeline components should be copyable. #5820
Changes from all commits
f082a31
d882afc
9ee004e
53e0973
9286a22
0f4fd64
c76b4d1
5a67779
b642872
819dd2d
282a1a8
465dd12
2b954c3
463ecae
93e7924
f14456b
b2927b1
05229c3
7bef88d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,13 +34,16 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { | |
* Fits a single model to the input data with optional parameters. | ||
* | ||
* @param dataset input dataset | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. typo. should be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
* @param paramPairs Optional list of param pairs. | ||
* These values override any specified in this Estimator's embedded ParamMap. | ||
* @param firstParamPair the first param pair, overrides embedded params | ||
* @param otherParamPairs other param pairs. These values override any specified in this | ||
* Estimator's embedded ParamMap. | ||
* @return fitted model | ||
*/ | ||
@varargs | ||
def fit(dataset: DataFrame, paramPairs: ParamPair[_]*): M = { | ||
val map = ParamMap(paramPairs: _*) | ||
def fit(dataset: DataFrame, firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): M = { | ||
val map = new ParamMap() | ||
.put(firstParamPair) | ||
.put(otherParamPairs: _*) | ||
fit(dataset, map) | ||
} | ||
|
||
|
@@ -52,12 +55,19 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { | |
* These values override any specified in this Estimator's embedded ParamMap. | ||
* @return fitted model | ||
*/ | ||
def fit(dataset: DataFrame, paramMap: ParamMap): M | ||
def fit(dataset: DataFrame, paramMap: ParamMap): M = { | ||
copy(paramMap).fit(dataset) | ||
} | ||
|
||
/** | ||
* Fits a model to the input data. | ||
*/ | ||
def fit(dataset: DataFrame): M | ||
|
||
/** | ||
* Fits multiple models to the input data with multiple sets of parameters. | ||
* The default implementation uses a for loop on each parameter map. | ||
* Subclasses could overwrite this to optimize multi-model training. | ||
* Subclasses could override this to optimize multi-model training. | ||
* | ||
* @param dataset input dataset | ||
* @param paramMaps An array of parameter maps. | ||
|
@@ -67,4 +77,8 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { | |
def fit(dataset: DataFrame, paramMaps: Array[ParamMap]): Seq[M] = { | ||
paramMaps.map(fit(dataset, _)) | ||
} | ||
|
||
override def copy(extra: ParamMap): Estimator[M] = { | ||
super.copy(extra).asInstanceOf[Estimator[M]] | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,16 +18,15 @@ | |
package org.apache.spark.ml | ||
|
||
import org.apache.spark.annotation.AlphaComponent | ||
import org.apache.spark.ml.param.ParamMap | ||
import org.apache.spark.ml.util.Identifiable | ||
import org.apache.spark.ml.param.{ParamMap, Params} | ||
import org.apache.spark.sql.DataFrame | ||
|
||
/** | ||
* :: AlphaComponent :: | ||
* Abstract class for evaluators that compute metrics from predictions. | ||
*/ | ||
@AlphaComponent | ||
abstract class Evaluator extends Identifiable { | ||
abstract class Evaluator extends Params { | ||
|
||
/** | ||
* Evaluates the output. | ||
|
@@ -36,5 +35,18 @@ abstract class Evaluator extends Identifiable { | |
* @param paramMap parameter map that specifies the input columns and output metrics | ||
* @return metric | ||
*/ | ||
def evaluate(dataset: DataFrame, paramMap: ParamMap): Double | ||
def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = { | ||
this.copy(paramMap).evaluate(dataset) | ||
} | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. One question. Why There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
/** | ||
* Evaluates the output. | ||
* @param dataset a dataset that contains labels/observations and predictions. | ||
* @return metric | ||
*/ | ||
def evaluate(dataset: DataFrame): Double | ||
|
||
override def copy(extra: ParamMap): Evaluator = { | ||
super.copy(extra).asInstanceOf[Evaluator] | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you think it will be more consistent with most Java/Scala APIs to ask users to implement the
clone
method, and we just have a default copy
method? There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There have been many issues with
clone
. So I would avoid touching it. See:https://developers.google.com/java-dev-tools/codepro/doc/features/audit/audit_rules_com.instantiations.assist.eclipse.auditGroup.cloneUsage