Skip to content

Commit

Permalink
Restore accidentatly removed setDefault, add back newline between ini…
Browse files Browse the repository at this point in the history
…t and setters in tree params codegen, regen shared params
  • Loading branch information
holdenk committed Jan 26, 2016
1 parent 0d28922 commit 8396aef
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 2 deletions.
3 changes: 1 addition & 2 deletions python/pyspark/ml/param/_shared_params_code_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,12 +175,11 @@ def __init__(self):
super(DecisionTreeParams, self).__init__()'''
dtParamMethods = ""
dummyPlaceholders = ""
realParams = ""
paramTemplate = """$name = Param($owner, "$name", "$doc")"""
for name, doc in decisionTreeParams:
variable = paramTemplate.replace("$name", name).replace("$doc", doc)
dummyPlaceholders += variable.replace("$owner", "Params._dummy()") + "\n "
dtParamMethods += _gen_param_code(name, doc, None) + "\n"
code.append(decisionTreeCode.replace("$dummyPlaceHolders", dummyPlaceholders) +
code.append(decisionTreeCode.replace("$dummyPlaceHolders", dummyPlaceholders) + "\n" +
dtParamMethods)
print("\n\n\n".join(code))
1 change: 1 addition & 0 deletions python/pyspark/ml/param/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,7 @@ class DecisionTreeParams(Params):

def __init__(self):
super(DecisionTreeParams, self).__init__()

def setMaxDepth(self, value):
"""
Sets the value of :py:attr:`maxDepth`.
Expand Down
3 changes: 3 additions & 0 deletions python/pyspark/ml/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,9 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
"""
super(GBTRegressor, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.regression.GBTRegressor", self.uid)
self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0,
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1)
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)

Expand Down

0 comments on commit 8396aef

Please sign in to comment.