diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index a56f061bef628..ce4cf3d5142b3 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -687,24 +687,15 @@ def groupBy(self, f, numPartitions=None):
         return self.map(lambda x: (f(x), x)).groupByKey(numPartitions)
 
     @ignore_unicode_prefix
-    def pipe(self, command, env={}, mode='permissive'):
+    def pipe(self, command, env={}, checkCode=False):
         """
         Return an RDD created by piping elements to a forked external process.
 
         >>> sc.parallelize(['1', '2', '', '3']).pipe('cat').collect()
         [u'1', u'2', u'', u'3']
+
+        :param checkCode: whether or not to check the return value of the shell command.
         """
-        if mode == 'permissive':
-            def fail_condition(x):
-                return False
-        elif mode == 'strict':
-            def fail_condition(x):
-                return x != 0
-        elif mode == 'grep':
-            def fail_condition(x):
-                return x != 0 and x != 1
-        else:
-            raise ValueError("mode must be one of 'permissive', 'strict' or 'grep'.")
 
         def func(iterator):
             pipe = Popen(
@@ -719,7 +710,7 @@ def pipe_objs(out):
 
         def check_return_code():
             pipe.wait()
-            if fail_condition(pipe.returncode):
+            if checkCode and pipe.returncode:
                 raise Exception("Pipe function `%s' exited "
                                 "with error code %d" % (command, pipe.returncode))
             else:
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 42a14bf6dd292..46368c20d44bd 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -879,13 +879,12 @@ def test_pipe_functions(self):
         rdd = self.sc.parallelize(data)
         with QuietTest(self.sc):
             self.assertEqual([], rdd.pipe('cc').collect())
-            self.assertRaises(Py4JJavaError, rdd.pipe('cc', mode='strict').collect)
+            self.assertRaises(Py4JJavaError, rdd.pipe('cc', checkCode=True).collect)
             result = rdd.pipe('cat').collect()
             result.sort()
             [self.assertEqual(x, y) for x, y in zip(data, result)]
-            self.assertRaises(Py4JJavaError, rdd.pipe('grep 4', mode='strict').collect)
+            self.assertRaises(Py4JJavaError, rdd.pipe('grep 4', checkCode=True).collect)
             self.assertEqual([], rdd.pipe('grep 4').collect())
-            self.assertEqual([], rdd.pipe('grep 4', mode='grep').collect())
 
 
 class ProfilerTests(PySparkTestCase):