Skip to content

Commit

Permalink
Added checks for isLargerBetter()
Browse files Browse the repository at this point in the history
  • Loading branch information
noel-smith committed Aug 24, 2015
1 parent 63b3835 commit 7794cf7
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
6 changes: 6 additions & 0 deletions python/pyspark/ml/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ def evaluate(self, dataset, params=None):
else:
raise ValueError("Params must be a param map but got %s." % type(params))

def isLargerBetter(self):
return True


@inherit_doc
class JavaEvaluator(Evaluator, JavaWrapper):
Expand All @@ -85,6 +88,9 @@ def _evaluate(self, dataset):
self._transfer_params_to_java()
return self._java_obj.evaluate(dataset._jdf)

def isLargerBetter(self):
return self._java_obj.isLargerBetter()


@inherit_doc
class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPredictionCol):
Expand Down
6 changes: 5 additions & 1 deletion python/pyspark/ml/tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,11 @@ def _fit(self, dataset):
# TODO: duplicate evaluator to take extra params from input
metric = eva.evaluate(model.transform(validation, epm[j]))
metrics[j] += metric
bestIndex = np.argmax(metrics)

if eva.isLargerBetter():
bestIndex = np.argmax(metrics)
else:
bestIndex = np.argmin(metrics)
bestModel = est.fit(dataset, epm[bestIndex])
return CrossValidatorModel(bestModel)

Expand Down

0 comments on commit 7794cf7

Please sign in to comment.