diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 84b026d956a6b..4ec4646411c7b 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -213,7 +213,9 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment,
         for path in self._conf.get("spark.submit.pyRequirements", "").split(","):
             if path != "":
                 (dirname, filename) = os.path.split(path)
-                self.addRequirementsFile(os.path.join(SparkFiles.getRootDirectory(), filename))
+                reqs_file = os.path.join(SparkFiles.getRootDirectory(), filename)
+                reqs = open(reqs_file).readlines()
+                self.addPyRequirements(reqs)
 
         # Create a temporary directory inside spark.local.dir:
         local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir(self._jsc.sc().conf())
@@ -851,20 +853,27 @@ def addPyPackage(self, pkg):
         finally:
             shutil.rmtree(tmp_dir)
 
-    def addRequirementsFile(self, path):
+    def addPyRequirements(self, reqs):
         """
-        Add a pip requirements file to distribute dependencies for all tasks
-        on thie SparkContext in the future. An ImportError will be thrown if
-        a module in the file can't be downloaded.
+        Add a list of pip requirements to distribute to workers.
+        Each entry in reqs is a pip requirement string.
         See https://pip.pypa.io/en/latest/user_guide.html#requirements-files
-        Raises ImportError if the requirement can't be found
+        Raises ImportError if a requirement can't be found. For example:
+
+        reqs = ['pkg1', 'pkg2', 'pkg3>=1.0,<=2.0']
+        sc.addPyRequirements(reqs)
+        # or load from a requirements file
+        sc.addPyRequirements(open('requirements.txt').readlines())
         """
         import pip
-        for req in pip.req.parse_requirements(path, session=uuid.uuid1()):
-            if not req.check_if_exists():
-                pip.main(['install', req.req.__str__()])
-            pkg = __import__(req.name)
-            self.addPyPackage(pkg)
+        with tempfile.NamedTemporaryFile() as t:
+            t.write('\n'.join(reqs))
+            t.flush()
+            for req in pip.req.parse_requirements(t.name, session=uuid.uuid1()):
+                if not req.check_if_exists():
+                    pip.main(['install', str(req.req)])
+                pkg = __import__(req.name)
+                self.addPyPackage(pkg)
 
     def setCheckpointDir(self, dirName):
         """
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index e9d35154e9d56..a6ba9f5cf257a 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -1955,23 +1955,22 @@ def test_add_py_package(self):
                 temp.write("triple = lambda x: 3*x")
             pkg = __import__(name)
             with SparkContext() as sc:
-                #trips = sc.parallelize([0, 1, 2, 3]).map(test_tmp.triple)
+                #trips = sc.parallelize([0, 1, 2, 3]).map(pkg.triple)
                 #sc.addPyPackage(pkg)
                 trips = sc.parallelize([0, 1, 2, 3]).map(lambda x: pkg.triple(x))
                 self.assertSequenceEqual([0, 3, 6, 9], trips.collect())
         finally:
             shutil.rmtree(name)
 
-    def test_requirements_file(self):
+    def test_add_py_requirements(self):
         import pip
-        with tempfile.NamedTemporaryFile() as temp:
-            temp.write('simplejson\nquadkey>=0.0.5\nsix==1.8.0')
-            with SparkContext() as sc:
-                sc.addRequirementsFile(temp.name)
-                import quadkey
-                qks = sc.parallelize([(0, 0), (1, 1), (2, 2)]) \
-                    .map(lambda pair: quadkey.from_geo(pair, 1).key)
-                self.assertSequenceEqual(['3', '1', '1'], qks.collect())
+        reqs = ['requests', 'quadkey>=0.0.5', 'six==1.8.0']
+        with SparkContext() as sc:
+            sc.addPyRequirements(reqs)
+            import quadkey
+            qks = sc.parallelize([(0, 0), (1, 1), (2, 2)]) \
+                .map(lambda pair: quadkey.from_geo(pair, 1).key)
+            self.assertSequenceEqual(['3', '1', '1'], qks.collect())
 
     def test_progress_api(self):
         with SparkContext() as sc:
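
Usage sketch (not part of the patch): the snippet below shows how a driver program might call the new SparkContext.addPyRequirements API introduced above, mirroring the requirements exercised in test_add_py_requirements. The local master URL, app name, and package versions are illustrative assumptions.

# Illustrative driver script; assumes a PySpark build that includes the
# addPyRequirements change above. Master URL and app name are examples.
from pyspark import SparkContext

sc = SparkContext("local[2]", "py-requirements-demo")

# Pass pip requirement strings directly; addPyRequirements installs any
# missing packages on the driver and ships them to workers via addPyPackage.
sc.addPyRequirements(['six==1.8.0', 'quadkey>=0.0.5'])

# An existing requirements file can be reused the same way:
#     sc.addPyRequirements(open('requirements.txt').readlines())

# The shipped package is now importable inside tasks on the executors.
import quadkey
keys = (sc.parallelize([(0, 0), (1, 1), (2, 2)])
          .map(lambda pair: quadkey.from_geo(pair, 1).key)
          .collect())
print(keys)  # ['3', '1', '1'], matching the patched test's expectation
sc.stop()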