addRequirementsFile -> addPyRequirements
buck heroux committed May 13, 2016
1 parent 1d5d25f commit f4af842
Showing 2 changed files with 29 additions and 21 deletions.
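In short, the commit replaces the file-path API with a list-based one. A rough before/after sketch of the driver-side call (illustrative only, paraphrasing the old and new signatures shown in the diff below; the requirement strings and file name are taken from the test and docstring in this commit):

    from pyspark import SparkContext

    sc = SparkContext('local[2]', 'requirements-demo')

    # before this commit: pass the path of a pip requirements file
    sc.addRequirementsFile('requirements.txt')

    # after this commit: pass a list of pip requirement strings ...
    sc.addPyRequirements(['simplejson', 'quadkey>=0.0.5', 'six==1.8.0'])
    # ... or build that list from an existing requirements file
    sc.addPyRequirements(open('requirements.txt').readlines())
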
31 changes: 20 additions & 11 deletions python/pyspark/context.py
@@ -213,7 +213,9 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment,
for path in self._conf.get("spark.submit.pyRequirements", "").split(","):
if path != "":
(dirname, filename) = os.path.split(path)
self.addRequirementsFile(os.path.join(SparkFiles.getRootDirectory(), filename))
reqs_file = os.path.join(SparkFiles.getRootDirectory(), filename)
reqs = open(reqs_file).readlines()
self.addPyRequirements(reqs)

# Create a temporary directory inside spark.local.dir:
local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir(self._jsc.sc().conf())
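The hunk above is the start-up path: _do_init replays any requirements files named in the spark.submit.pyRequirements configuration entry, resolving each file against SparkFiles.getRootDirectory() (i.e. the files are expected to have been shipped with the job) and feeding its lines to the new addPyRequirements. A hedged sketch of the driver-side configuration this code reads; the conf key comes from the diff, the file names are placeholders, and how the listed files actually reach the SparkFiles root is outside this hunk:

    from pyspark import SparkConf, SparkContext

    conf = SparkConf()
    # comma-separated list of requirements files, split exactly as _do_init does above
    conf.set('spark.submit.pyRequirements', 'requirements.txt,extra-requirements.txt')
    sc = SparkContext(conf=conf)
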
@@ -851,20 +853,27 @@ def addPyPackage(self, pkg):
finally:
shutil.rmtree(tmp_dir)

def addRequirementsFile(self, path):
def addPyRequirements(self, reqs):
"""
Add a pip requirements file to distribute dependencies for all tasks
on this SparkContext in the future. An ImportError will be thrown if
a module in the file can't be downloaded.
Add a list of pip requirements to distribute to workers.
The reqs list is composed of pip requirements strings.
See https://pip.pypa.io/en/latest/user_guide.html#requirements-files
Raises ImportError if the requirement can't be found
Raises ImportError if the requirement can't be found. Example follows:
reqs = ['pkg1', 'pkg2', 'pkg3>=1.0,<=2.0']
sc.addPyRequirements(reqs)
// or load from requirements file
sc.addPyRequirements(open('requirements.txt').readlines())
"""
import pip
for req in pip.req.parse_requirements(path, session=uuid.uuid1()):
if not req.check_if_exists():
pip.main(['install', req.req.__str__()])
pkg = __import__(req.name)
self.addPyPackage(pkg)
with tempfile.NamedTemporaryFile() as t:
t.write('\n'.join(reqs))
t.flush()
for req in pip.req.parse_requirements(t.name, session=uuid.uuid1()):
if not req.check_if_exists():
pip.main(['install', req.req.__str__()])
pkg = __import__(req.name)
self.addPyPackage(pkg)

def setCheckpointDir(self, dirName):
"""
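The new method stages the requirement strings in a temporary file so they can go through pip's own parser, installs anything that is missing, then imports each package and ships it with addPyPackage. A minimal standalone sketch of that pattern, assuming a pip release from this era (pre-10) where the internal pip.req.parse_requirements, InstallRequirement.check_if_exists() and pip.main() are still available, and using a hypothetical install_package callback in place of SparkContext.addPyPackage:

    import tempfile
    import uuid

    import pip  # old pip (< 10): pip.req and pip.main are importable

    def add_py_requirements(reqs, install_package):
        # reqs is a list of pip requirement strings, e.g. ['six==1.8.0'].
        # Stage the strings in a temp file so pip.req.parse_requirements can read them
        # (text mode so the sketch also runs on Python 3).
        with tempfile.NamedTemporaryFile(mode='w') as t:
            t.write('\n'.join(reqs))
            t.flush()
            for req in pip.req.parse_requirements(t.name, session=uuid.uuid1()):
                # Install any requirement that is not already present locally.
                if not req.check_if_exists():
                    pip.main(['install', str(req.req)])
                # Import the installed package and hand the module to the
                # distributor (sc.addPyPackage in the real code).
                pkg = __import__(req.name)
                install_package(pkg)
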
19 changes: 9 additions & 10 deletions python/pyspark/tests.py
@@ -1955,23 +1955,22 @@ def test_add_py_package(self):
temp.write("triple = lambda x: 3*x")
pkg = __import__(name)
with SparkContext() as sc:
#trips = sc.parallelize([0, 1, 2, 3]).map(test_tmp.triple)
#trips = sc.parallelize([0, 1, 2, 3]).map(pkg.triple)
#sc.addPyPackage(pkg)
trips = sc.parallelize([0, 1, 2, 3]).map(lambda x: pkg.triple(x))
self.assertSequenceEqual([0, 3, 6, 9], trips.collect())
finally:
shutil.rmtree(name)

def test_requirements_file(self):
def test_add_py_requirements(self):
import pip
with tempfile.NamedTemporaryFile() as temp:
temp.write('simplejson\nquadkey>=0.0.5\nsix==1.8.0')
with SparkContext() as sc:
sc.addRequirementsFile(temp.name)
import quadkey
qks = sc.parallelize([(0, 0), (1, 1), (2, 2)]) \
.map(lambda pair: quadkey.from_geo(pair, 1).key)
self.assertSequenceEqual(['3', '1', '1'], qks.collect())
reqs = ['requests', 'quadkey>=0.0.5', 'six==1.8.0']
with SparkContext() as sc:
sc.addPyRequirements(reqs)
import quadkey
qks = sc.parallelize([(0, 0), (1, 1), (2, 2)]) \
.map(lambda pair: quadkey.from_geo(pair, 1).key)
self.assertSequenceEqual(['3', '1', '1'], qks.collect())

def test_progress_api(self):
with SparkContext() as sc:
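The rewritten test passes the requirement list straight to the context instead of routing it through a temporary requirements file. A trimmed, self-contained variant of the same pattern (hedged: it assumes a local Spark installation and network access for pip, and uses only six so the extra dependency stays small):

    from pyspark import SparkContext

    reqs = ['six==1.8.0']  # pip requirement strings, as in the test above
    with SparkContext('local[2]', 'py-requirements-demo') as sc:
        sc.addPyRequirements(reqs)  # install locally if needed, then ship to workers
        import six
        # Use the freshly distributed dependency inside a task.
        strings = sc.parallelize([1, 2, 3]).map(lambda x: six.text_type(x))
        assert strings.collect() == [u'1', u'2', u'3']
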
