-
Notifications
You must be signed in to change notification settings - Fork 28.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-10026] [ML] [PySpark] Implement some common Params for regression in PySpark #8508
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,7 +31,8 @@ | |
|
||
@inherit_doc | ||
class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, | ||
HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol): | ||
HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol, | ||
HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds): | ||
""" | ||
Logistic regression. | ||
Currently, this class only supports binary classification. | ||
|
@@ -65,72 +66,44 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti | |
""" | ||
|
||
# a placeholder to make it appear in the generated doc | ||
elasticNetParam = \ | ||
Param(Params._dummy(), "elasticNetParam", | ||
"the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " + | ||
"the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") | ||
fitIntercept = Param(Params._dummy(), "fitIntercept", "whether to fit an intercept term.") | ||
thresholds = Param(Params._dummy(), "thresholds", | ||
"Thresholds in multi-class classification" + | ||
" to adjust the probability of predicting each class." + | ||
" Array must have length equal to the number of classes, with values >= 0." + | ||
" The class with largest value p/t is predicted, where p is the original" + | ||
" probability of that class and t is the class' threshold.") | ||
threshold = Param(Params._dummy(), "threshold", | ||
"Threshold in binary classification prediction, in range [0, 1]." + | ||
" If threshold and thresholds are both set, they must match.") | ||
|
||
@keyword_only | ||
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", | ||
maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, | ||
threshold=0.5, thresholds=None, | ||
probabilityCol="probability", rawPredictionCol="rawPrediction"): | ||
threshold=0.5, thresholds=None, probabilityCol="probability", | ||
rawPredictionCol="rawPrediction", standardization=True): | ||
""" | ||
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ | ||
maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ | ||
threshold=0.5, thresholds=None, \ | ||
probabilityCol="probability", rawPredictionCol="rawPrediction") | ||
threshold=0.5, thresholds=None, probabilityCol="probability", \ | ||
rawPredictionCol="rawPrediction", standardization=True) | ||
If the threshold and thresholds Params are both set, they must be equivalent. | ||
""" | ||
super(LogisticRegression, self).__init__() | ||
self._java_obj = self._new_java_obj( | ||
"org.apache.spark.ml.classification.LogisticRegression", self.uid) | ||
#: param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty | ||
# is an L2 penalty. For alpha = 1, it is an L1 penalty. | ||
self.elasticNetParam = \ | ||
Param(self, "elasticNetParam", | ||
"the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " + | ||
"the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") | ||
#: param for whether to fit an intercept term. | ||
self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.") | ||
#: param for threshold in binary classification, in range [0, 1]. | ||
self.threshold = Param(self, "threshold", | ||
"Threshold in binary classification prediction, in range [0, 1]." + | ||
" If threshold and thresholds are both set, they must match.") | ||
#: param for thresholds or cutoffs in binary or multiclass classification | ||
self.thresholds = \ | ||
Param(self, "thresholds", | ||
"Thresholds in multi-class classification" + | ||
" to adjust the probability of predicting each class." + | ||
" Array must have length equal to the number of classes, with values >= 0." + | ||
" The class with largest value p/t is predicted, where p is the original" + | ||
" probability of that class and t is the class' threshold.") | ||
self._setDefault(maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1E-6, | ||
fitIntercept=True, threshold=0.5) | ||
self._setDefault(maxIter=100, regParam=0.1, tol=1E-6, threshold=0.5) | ||
kwargs = self.__init__._input_kwargs | ||
self.setParams(**kwargs) | ||
self._checkThresholdConsistency() | ||
|
||
@keyword_only | ||
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", | ||
maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, | ||
threshold=0.5, thresholds=None, | ||
probabilityCol="probability", rawPredictionCol="rawPrediction"): | ||
threshold=0.5, thresholds=None, probabilityCol="probability", | ||
rawPredictionCol="rawPrediction", standardization=True): | ||
""" | ||
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ | ||
maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ | ||
threshold=0.5, thresholds=None, \ | ||
probabilityCol="probability", rawPredictionCol="rawPrediction") | ||
threshold=0.5, thresholds=None, probabilityCol="probability", | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. add `\` |
||
rawPredictionCol="rawPrediction", standardization=True) | ||
Sets params for logistic regression. | ||
If the threshold and thresholds Params are both set, they must be equivalent. | ||
""" | ||
|
@@ -142,32 +115,6 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre | |
def _create_model(self, java_model): | ||
return LogisticRegressionModel(java_model) | ||
|
||
def setElasticNetParam(self, value): | ||
""" | ||
Sets the value of :py:attr:`elasticNetParam`. | ||
""" | ||
self._paramMap[self.elasticNetParam] = value | ||
return self | ||
|
||
def getElasticNetParam(self): | ||
""" | ||
Gets the value of elasticNetParam or its default value. | ||
""" | ||
return self.getOrDefault(self.elasticNetParam) | ||
|
||
def setFitIntercept(self, value): | ||
""" | ||
Sets the value of :py:attr:`fitIntercept`. | ||
""" | ||
self._paramMap[self.fitIntercept] = value | ||
return self | ||
|
||
def getFitIntercept(self): | ||
""" | ||
Gets the value of fitIntercept or its default value. | ||
""" | ||
return self.getOrDefault(self.fitIntercept) | ||
|
||
def setThreshold(self, value): | ||
""" | ||
Sets the value of :py:attr:`threshold`. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -432,6 +432,117 @@ def getStepSize(self): | |
return self.getOrDefault(self.stepSize) | ||
|
||
|
||
class HasElasticNetParam(Params): | ||
""" | ||
Mixin for param elasticNetParam: the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.. | ||
""" | ||
|
||
# a placeholder to make it appear in the generated doc | ||
elasticNetParam = Param(Params._dummy(), "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") | ||
|
||
def __init__(self): | ||
super(HasElasticNetParam, self).__init__() | ||
#: param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty. | ||
self.elasticNetParam = Param(self, "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") | ||
self._setDefault(elasticNetParam=0.0) | ||
|
||
def setElasticNetParam(self, value): | ||
""" | ||
Sets the value of :py:attr:`elasticNetParam`. | ||
""" | ||
self._paramMap[self.elasticNetParam] = value | ||
return self | ||
|
||
def getElasticNetParam(self): | ||
""" | ||
Gets the value of elasticNetParam or its default value. | ||
""" | ||
return self.getOrDefault(self.elasticNetParam) | ||
|
||
|
||
class HasFitIntercept(Params): | ||
""" | ||
Mixin for param fitIntercept: whether to fit an intercept term.. | ||
""" | ||
|
||
# a placeholder to make it appear in the generated doc | ||
fitIntercept = Param(Params._dummy(), "fitIntercept", "whether to fit an intercept term.") | ||
|
||
def __init__(self): | ||
super(HasFitIntercept, self).__init__() | ||
#: param for whether to fit an intercept term. | ||
self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.") | ||
self._setDefault(fitIntercept=True) | ||
|
||
def setFitIntercept(self, value): | ||
""" | ||
Sets the value of :py:attr:`fitIntercept`. | ||
""" | ||
self._paramMap[self.fitIntercept] = value | ||
return self | ||
|
||
def getFitIntercept(self): | ||
""" | ||
Gets the value of fitIntercept or its default value. | ||
""" | ||
return self.getOrDefault(self.fitIntercept) | ||
|
||
|
||
class HasStandardization(Params): | ||
""" | ||
Mixin for param standardization: whether to standardize the training features before fitting the model.. | ||
""" | ||
|
||
# a placeholder to make it appear in the generated doc | ||
standardization = Param(Params._dummy(), "standardization", "whether to standardize the training features before fitting the model.") | ||
|
||
def __init__(self): | ||
super(HasStandardization, self).__init__() | ||
#: param for whether to standardize the training features before fitting the model. | ||
self.standardization = Param(self, "standardization", "whether to standardize the training features before fitting the model.") | ||
self._setDefault(standardization=True) | ||
|
||
def setStandardization(self, value): | ||
""" | ||
Sets the value of :py:attr:`standardization`. | ||
""" | ||
self._paramMap[self.standardization] = value | ||
return self | ||
|
||
def getStandardization(self): | ||
""" | ||
Gets the value of standardization or its default value. | ||
""" | ||
return self.getOrDefault(self.standardization) | ||
|
||
|
||
class HasThresholds(Params): | ||
""" | ||
Mixin for param thresholds: Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.. | ||
""" | ||
|
||
# a placeholder to make it appear in the generated doc | ||
thresholds = Param(Params._dummy(), "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.") | ||
|
||
def __init__(self): | ||
super(HasThresholds, self).__init__() | ||
#: param for Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold. | ||
self.thresholds = Param(self, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.") | ||
|
||
def setThresholds(self, value): | ||
""" | ||
Sets the value of :py:attr:`thresholds`. | ||
""" | ||
self._paramMap[self.thresholds] = value | ||
return self | ||
|
||
def getThresholds(self): | ||
""" | ||
Gets the value of thresholds or its default value. | ||
""" | ||
return self.getOrDefault(self.thresholds) | ||
|
||
|
||
class DecisionTreeParams(Params): | ||
""" | ||
Mixin for Decision Tree parameters. | ||
|
@@ -444,7 +555,7 @@ class DecisionTreeParams(Params): | |
minInfoGain = Param(Params._dummy(), "minInfoGain", "Minimum information gain for a split to be considered at a tree node.") | ||
maxMemoryInMB = Param(Params._dummy(), "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.") | ||
cacheNodeIds = Param(Params._dummy(), "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.") | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Remove whitespace |
||
|
||
def __init__(self): | ||
super(DecisionTreeParams, self).__init__() | ||
|
@@ -460,7 +571,7 @@ def __init__(self): | |
self.maxMemoryInMB = Param(self, "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.") | ||
#: param for If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. | ||
self.cacheNodeIds = Param(self, "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.") | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ditto |
||
def setMaxDepth(self, value): | ||
""" | ||
Sets the value of :py:attr:`maxDepth`. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,7 +28,8 @@ | |
|
||
@inherit_doc | ||
class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, | ||
HasRegParam, HasTol): | ||
HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept, | ||
HasStandardization): | ||
""" | ||
Linear regression. | ||
|
||
|
@@ -63,38 +64,30 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction | |
TypeError: Method setParams forces keyword arguments. | ||
""" | ||
|
||
# a placeholder to make it appear in the generated doc | ||
elasticNetParam = \ | ||
Param(Params._dummy(), "elasticNetParam", | ||
"the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " + | ||
"the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") | ||
|
||
@keyword_only | ||
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", | ||
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6): | ||
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, | ||
standardization=True): | ||
""" | ||
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ | ||
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6) | ||
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. add `\` |
||
standardization=True) | ||
""" | ||
super(LinearRegression, self).__init__() | ||
self._java_obj = self._new_java_obj( | ||
"org.apache.spark.ml.regression.LinearRegression", self.uid) | ||
#: param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty | ||
# is an L2 penalty. For alpha = 1, it is an L1 penalty. | ||
self.elasticNetParam = \ | ||
Param(self, "elasticNetParam", | ||
"the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty " + | ||
"is an L2 penalty. For alpha = 1, it is an L1 penalty.") | ||
self._setDefault(maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6) | ||
self._setDefault(maxIter=100, regParam=0.0, tol=1e-6) | ||
kwargs = self.__init__._input_kwargs | ||
self.setParams(**kwargs) | ||
|
||
@keyword_only | ||
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", | ||
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6): | ||
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, | ||
standardization=True): | ||
""" | ||
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ | ||
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6) | ||
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. add `\` |
||
standardization=True) | ||
Sets params for linear regression. | ||
""" | ||
kwargs = self.setParams._input_kwargs | ||
|
@@ -103,19 +96,6 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre | |
def _create_model(self, java_model): | ||
return LinearRegressionModel(java_model) | ||
|
||
def setElasticNetParam(self, value): | ||
""" | ||
Sets the value of :py:attr:`elasticNetParam`. | ||
""" | ||
self._paramMap[self.elasticNetParam] = value | ||
return self | ||
|
||
def getElasticNetParam(self): | ||
""" | ||
Gets the value of elasticNetParam or its default value. | ||
""" | ||
return self.getOrDefault(self.elasticNetParam) | ||
|
||
|
||
class LinearRegressionModel(JavaModel): | ||
""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps we should also extract a `HasThreshold` mixin for binary classifier thresholds.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`threshold` is a deprecated parameter; it has been replaced by `thresholds`. LogisticRegression still keeps `threshold` only for binary compatibility, so I think we don't need to extract `HasThreshold` as a shared Param. @jkbradley
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My understanding is that the `HasThresholds` trait mixin in `ml.LogisticRegression` is actually an artifact resulting from a transitive dependency through `ProbabilisticClassifier`. We don't actually support multi-class classification in `ml.LogisticRegression` at the moment, and we did quite a bit of work to make the API less confusing.

After multi-class is supported I think it makes sense to use `HasThresholds`, but for the time being I would prefer we only use `HasThreshold` in the Python API.