-
Notifications
You must be signed in to change notification settings - Fork 28.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
[SPARK-10509][PYSPARK] Reduce excessive param boiler plate code #10216
Changes from 11 commits
564339b
429172b
a976b78
384fd0d
8e8cbae
79642e0
e0f3f00
7aecb59
c4a2919
cb7d468
53edd3d
69025f1
a39bea5
10ed8da
b755008
0d28922
8396aef
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -72,7 +72,6 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti | |
.. versionadded:: 1.3.0 | ||
""" | ||
|
||
# a placeholder to make it appear in the generated doc | ||
threshold = Param(Params._dummy(), "threshold", | ||
"Threshold in binary classification prediction, in range [0, 1]." + | ||
" If threshold and thresholds are both set, they must match.") | ||
|
@@ -93,9 +92,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred | |
self._java_obj = self._new_java_obj( | ||
"org.apache.spark.ml.classification.LogisticRegression", self.uid) | ||
#: param for threshold in binary classification, in range [0, 1]. | ||
self.threshold = Param(self, "threshold", | ||
"Threshold in binary classification prediction, in range [0, 1]." + | ||
" If threshold and thresholds are both set, they must match.") | ||
self.threshold = LogisticRegression.threshold._copy_new_parent(self) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

Reviewer reply: oh good point - I'll remove the comments about them just being dummy params.

There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

Reviewer reply: I hope we can eliminate the need for these calls in concrete classes' init methods. |
||
self._setDefault(maxIter=100, regParam=0.1, tol=1E-6, threshold=0.5) | ||
kwargs = self.__init__._input_kwargs | ||
self.setParams(**kwargs) | ||
|
@@ -232,7 +229,6 @@ class TreeClassifierParams(object): | |
""" | ||
supportedImpurities = ["entropy", "gini"] | ||
|
||
# a placeholder to make it appear in the generated doc | ||
impurity = Param(Params._dummy(), "impurity", | ||
"Criterion used for information gain calculation (case-insensitive). " + | ||
"Supported options: " + | ||
|
@@ -241,9 +237,7 @@ class TreeClassifierParams(object): | |
def __init__(self): | ||
super(TreeClassifierParams, self).__init__() | ||
#: param for Criterion used for information gain calculation (case-insensitive). | ||
self.impurity = Param(self, "impurity", "Criterion used for information " + | ||
"gain calculation (case-insensitive). Supported options: " + | ||
", ".join(self.supportedImpurities)) | ||
self.impurity = TreeClassifierParams.impurity._copy_new_parent(self) | ||
|
||
@since("1.6.0") | ||
def setImpurity(self, value): | ||
|
@@ -485,7 +479,6 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol | |
.. versionadded:: 1.4.0 | ||
""" | ||
|
||
# a placeholder to make it appear in the generated doc | ||
lossType = Param(Params._dummy(), "lossType", | ||
"Loss function which GBT tries to minimize (case-insensitive). " + | ||
"Supported options: " + ", ".join(GBTParams.supportedLossTypes)) | ||
|
@@ -505,9 +498,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred | |
self._java_obj = self._new_java_obj( | ||
"org.apache.spark.ml.classification.GBTClassifier", self.uid) | ||
#: param for Loss function which GBT tries to minimize (case-insensitive). | ||
self.lossType = Param(self, "lossType", | ||
"Loss function which GBT tries to minimize (case-insensitive). " + | ||
"Supported options: " + ", ".join(GBTParams.supportedLossTypes)) | ||
self.lossType = GBTClassifier.lossType._copy_new_parent(self) | ||
self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, | ||
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, | ||
lossType="logistic", maxIter=20, stepSize=0.1) | ||
|
@@ -597,7 +588,6 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H | |
.. versionadded:: 1.5.0 | ||
""" | ||
|
||
# a placeholder to make it appear in the generated doc | ||
smoothing = Param(Params._dummy(), "smoothing", "The smoothing parameter, should be >= 0, " + | ||
"default is 1.0") | ||
modelType = Param(Params._dummy(), "modelType", "The model type which is a string " + | ||
|
@@ -616,12 +606,9 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred | |
self._java_obj = self._new_java_obj( | ||
"org.apache.spark.ml.classification.NaiveBayes", self.uid) | ||
#: param for the smoothing parameter. | ||
self.smoothing = Param(self, "smoothing", "The smoothing parameter, should be >= 0, " + | ||
"default is 1.0") | ||
self.smoothing = NaiveBayes.smoothing._copy_new_parent(self) | ||
#: param for the model type. | ||
self.modelType = Param(self, "modelType", "The model type which is a string " + | ||
"(case-sensitive). Supported options: multinomial (default) " + | ||
"and bernoulli.") | ||
self.modelType = NaiveBayes.modelType._copy_new_parent(self) | ||
self._setDefault(smoothing=1.0, modelType="multinomial") | ||
kwargs = self.__init__._input_kwargs | ||
self.setParams(**kwargs) | ||
|
@@ -734,7 +721,6 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, | |
.. versionadded:: 1.6.0 | ||
""" | ||
|
||
# a placeholder to make it appear in the generated doc | ||
layers = Param(Params._dummy(), "layers", "Sizes of layers from input layer to output layer " + | ||
"E.g., Array(780, 100, 10) means 780 inputs, one hidden layer with 100 " + | ||
"neurons and output layer of 10 neurons, default is [1, 1].") | ||
|
@@ -753,14 +739,8 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred | |
super(MultilayerPerceptronClassifier, self).__init__() | ||
self._java_obj = self._new_java_obj( | ||
"org.apache.spark.ml.classification.MultilayerPerceptronClassifier", self.uid) | ||
self.layers = Param(self, "layers", "Sizes of layers from input layer to output layer " + | ||
"E.g., Array(780, 100, 10) means 780 inputs, one hidden layer with " + | ||
"100 neurons and output layer of 10 neurons, default is [1, 1].") | ||
self.blockSize = Param(self, "blockSize", "Block size for stacking input data in " + | ||
"matrices. Data is stacked within partitions. If block size is " + | ||
"more than remaining data in a partition then it is adjusted to " + | ||
"the size of this data. Recommended size is between 10 and 1000, " + | ||
"default is 128.") | ||
self.layers = MultilayerPerceptronClassifier.layers._copy_new_parent(self) | ||
self.blockSize = MultilayerPerceptronClassifier.blockSize._copy_new_parent(self) | ||
self._setDefault(maxIter=100, tol=1E-4, layers=[1, 1], blockSize=128) | ||
kwargs = self.__init__._input_kwargs | ||
self.setParams(**kwargs) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This can hopefully remain the same.