From b7a7b9b105357e3d8725bdb295388abff762648a Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Sun, 3 May 2015 11:51:19 -0700 Subject: [PATCH 1/2] simplify grid build --- python/pyspark/ml/tuning.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index a383bd0c0d26f..45ee862bff164 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -15,6 +15,8 @@ # limitations under the License. # +import itertools + __all__ = ['ParamGridBuilder'] @@ -76,17 +78,9 @@ def build(self): Builds and returns all combinations of parameters specified by the param grid. """ - param_maps = [{}] - for (param, values) in self._param_grid.items(): - new_param_maps = [] - for value in values: - for old_map in param_maps: - copied_map = old_map.copy() - copied_map[param] = value - new_param_maps.append(copied_map) - param_maps = new_param_maps - - return param_maps + keys = self._param_grid.keys() + grid_values = self._param_grid.values() + return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)] if __name__ == "__main__": From d08f9cfcc9cd2e4bb9e9ce83f885a722d302969f Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Sun, 3 May 2015 11:56:05 -0700 Subject: [PATCH 2/2] simplify tests --- python/pyspark/ml/tuning.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 45ee862bff164..1773ab5bdcdb1 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -39,14 +39,10 @@ class ParamGridBuilder(object): {lr.regParam: 1.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}, \ {lr.regParam: 2.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}, \ {lr.regParam: 3.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}] - >>> fail_count = 0 - >>> for e in expected: - ... if e not in output: - ... fail_count += 1 - >>> if len(expected) != len(output): - ... fail_count += 1 - >>> fail_count - 0 + >>> len(output) == len(expected) + True + >>> all([m in expected for m in output]) + True """ def __init__(self):