diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index 6b0a9ffde9f42..9594cb344ec18 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -66,6 +66,9 @@ def evaluate(self, dataset, params=None): else: raise ValueError("Params must be a param map but got %s." % type(params)) + def isLargerBetter(self): + return True + @inherit_doc class JavaEvaluator(Evaluator, JavaWrapper): @@ -85,6 +88,9 @@ def _evaluate(self, dataset): self._transfer_params_to_java() return self._java_obj.evaluate(dataset._jdf) + def isLargerBetter(self): + return self._java_obj.isLargerBetter() + @inherit_doc class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPredictionCol): diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index dcfee6a3170ab..7f6bbebfe0f63 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -223,7 +223,11 @@ def _fit(self, dataset): # TODO: duplicate evaluator to take extra params from input metric = eva.evaluate(model.transform(validation, epm[j])) metrics[j] += metric - bestIndex = np.argmax(metrics) + + if eva.isLargerBetter(): + bestIndex = np.argmax(metrics) + else: + bestIndex = np.argmin(metrics) bestModel = est.fit(dataset, epm[bestIndex]) return CrossValidatorModel(bestModel)