Skip to content

Commit

Permalink
Add fold number check in Python.
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Jun 10, 2020
1 parent 2d8956f commit aa7c8d0
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 7 deletions.
31 changes: 31 additions & 0 deletions python/pyspark/ml/tests/test_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,37 @@ def test_user_specified_folds(self):
loadedCV = CrossValidator.load(cvPath)
self.assertEqual(loadedCV.getFoldCol(), cv_with_user_folds.getFoldCol())

def test_invalid_user_specified_folds(self):
    """Verify that CrossValidator.fit fails on unusable user-specified folds.

    The dataset's ``fold`` column only contains the values 0, 1 and 2, so:

    * with ``numFolds=2`` the value 2 is outside ``[0, numFolds)`` and the
      fit is expected to fail with a "Fold number must be in range" error;
    * with ``numFolds=4`` no row is assigned to fold 3, and the fit is
      expected to fail because the validation data at fold 3 is empty.
    """
    # Five distinct rows repeated ten times; fold values present: 0, 1, 2.
    dataset_with_folds = self.spark.createDataFrame(
        [(Vectors.dense([0.0]), 0.0, 0),
         (Vectors.dense([0.4]), 1.0, 1),
         (Vectors.dense([0.5]), 0.0, 2),
         (Vectors.dense([0.6]), 1.0, 0),
         (Vectors.dense([1.0]), 1.0, 1)] * 10,
        ["features", "label", "fold"])

    lr = LogisticRegression()
    grid = ParamGridBuilder().addGrid(lr.maxIter, [20]).build()
    evaluator = BinaryClassificationEvaluator()

    # NOTE(review): assertRaisesRegexp is a deprecated alias of
    # assertRaisesRegex (removed in Python 3.12); it is kept here for
    # compatibility with the interpreter versions this suite targets --
    # switch once only Python 3 is supported.

    # Case 1: fold value 2 is out of range for numFolds=2.
    cv = CrossValidator(estimator=lr,
                        estimatorParamMaps=grid,
                        evaluator=evaluator,
                        numFolds=2,
                        foldCol="fold")
    with self.assertRaisesRegexp(Exception, "Fold number must be in range"):
        cv.fit(dataset_with_folds)

    # Case 2: every fold value is in range for numFolds=4, but fold 3 has
    # no rows, leaving its validation split empty.
    cv = CrossValidator(estimator=lr,
                        estimatorParamMaps=grid,
                        evaluator=evaluator,
                        numFolds=4,
                        foldCol="fold")
    with self.assertRaisesRegexp(Exception, "The validation data at fold 3 is empty"):
        cv.fit(dataset_with_folds)


class TrainValidationSplitTests(SparkSessionTestCase):

Expand Down
26 changes: 19 additions & 7 deletions python/pyspark/ml/tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
from pyspark.ml.param.shared import HasCollectSubModels, HasParallelism, HasSeed
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaParams
from pyspark.sql.functions import col, lit, rand
from pyspark.sql.functions import col, lit, rand, UserDefinedFunction
from pyspark.sql.types import BooleanType

__all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel', 'TrainValidationSplit',
'TrainValidationSplitModel']
Expand Down Expand Up @@ -203,9 +204,8 @@ class _CrossValidatorParams(_ValidatorParams):
foldCol = Param(Params._dummy(), "foldCol", "Param for the column name of user " +
"specified fold number. Once this is specified, :py:class:`CrossValidator` " +
"won't do random k-fold split. Note that this column should be integer type " +
"with range [0, numFolds) and Spark will also mod the user-specified " +
"fold values with the value of :py:attr:`numFolds` param.",
typeConverter=TypeConverters.toString)
"with range [0, numFolds) and Spark will throw exception on out-of-range " +
"fold numbers.", typeConverter=TypeConverters.toString)

@since("1.4.0")
def getNumFolds(self):
Expand Down Expand Up @@ -399,10 +399,22 @@ def _kFold(self, dataset):
datasets.append((train, validation))
else:
# Use user-specified fold numbers.
dfWithMod = dataset.withColumn(foldCol, col(foldCol) % lit(nFolds))
def checker(foldNum):
if foldNum < 0 or foldNum >= nFolds:
raise ValueError(
"Fold number must be in range [0, %s), but got %s." % (nFolds, foldNum))
return True

checker_udf = UserDefinedFunction(checker, BooleanType())
for i in range(nFolds):
datasets.append((dfWithMod.filter(col(foldCol) != lit(i)),
dfWithMod.filter(col(foldCol) == lit(i))))
training = dataset.filter(checker_udf(dataset[foldCol]) & (col(foldCol) != lit(i)))
validation = dataset.filter(checker_udf(dataset[foldCol]) & (col(foldCol) == lit(i)))
if training.rdd.getNumPartitions() == 0 or len(training.take(1)) == 0:
raise ValueError("The training data at fold %s is empty." % i)
if validation.rdd.getNumPartitions() == 0 or len(validation.take(1)) == 0:
raise ValueError("The validation data at fold %s is empty." % i)
datasets.append((training, validation))

return datasets

@since("1.4.0")
Expand Down

0 comments on commit aa7c8d0

Please sign in to comment.