[SPARK-5929] Pyspark: Register a pip requirements file with spark_context #4743

Closed · wants to merge 4 commits

Changes from all commits

47 changes: 41 additions & 6 deletions python/pyspark/context.py
@@ -65,8 +65,9 @@ class SparkContext(object):
     _python_includes = None  # zip and egg files that need to be added to PYTHONPATH

     def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
-                 environment=None, batchSize=0, serializer=PickleSerializer(), conf=None,
-                 gateway=None, jsc=None, profiler_cls=BasicProfiler):
+                 requirementsFile=None, environment=None, batchSize=0,
+                 serializer=PickleSerializer(), conf=None, gateway=None,
+                 jsc=None, profiler_cls=BasicProfiler):
         """
         Create a new SparkContext. At least the master and app name should be set,
         either through the named parameters here or through C{conf}.
@@ -78,6 +79,8 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
         :param pyFiles: Collection of .zip or .py files to send to the cluster
               and add to PYTHONPATH. These can be paths on the local file
               system or HDFS, HTTP, HTTPS, or FTP URLs.
+        :param requirementsFile: Path to a pip requirements file. The listed
+              dependencies are sent to the cluster and added to PYTHONPATH.
         :param environment: A dictionary of environment variables to set on
               worker nodes.
         :param batchSize: The number of Python objects represented as a single
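As an aside, here is a minimal sketch of how the new constructor parameter would be used once this patch is applied; the app name and the "requirements.txt" path are illustrative, not part of the diff:

    from pyspark import SparkConf, SparkContext

    conf = SparkConf().setAppName("reqs-demo").setMaster("local[2]")
    # Dependencies listed in requirements.txt are packaged and shipped to executors
    # before any tasks run (the constructor forwards requirementsFile to addRequirementsFile).
    sc = SparkContext(conf=conf, requirementsFile="requirements.txt")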
@@ -104,15 +107,15 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
         self._callsite = first_spark_call() or CallSite(None, None, None)
         SparkContext._ensure_initialized(self, gateway=gateway)
         try:
-            self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
-                          conf, jsc, profiler_cls)
+            self._do_init(master, appName, sparkHome, pyFiles, requirementsFile, environment,
+                          batchSize, serializer, conf, jsc, profiler_cls)
         except:
             # If an error occurs, clean up in order to allow future SparkContext creation:
             self.stop()
             raise

-    def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
-                 conf, jsc, profiler_cls):
+    def _do_init(self, master, appName, sparkHome, pyFiles, requirementsFile, environment,
+                 batchSize, serializer, conf, jsc, profiler_cls):
         self.environment = environment or {}
         self._conf = conf or SparkConf(_jvm=self._jvm)
         self._batchSize = batchSize  # -1 represents an unlimited batch size
@@ -180,6 +183,10 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
         for path in (pyFiles or []):
             self.addPyFile(path)

+        # Deploy code dependencies from the requirements file given to the constructor
+        if requirementsFile:
+            self.addRequirementsFile(requirementsFile)
+
         # Deploy code dependencies set by spark-submit; these will already have been added
         # with SparkContext.addFile, so we just need to add them to the PYTHONPATH
         for path in self._conf.get("spark.submit.pyFiles", "").split(","):
@@ -710,6 +717,34 @@ def addPyFile(self, path):
             # for tests in local mode
             sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename))

+    def addRequirementsFile(self, path):
+        """
+        Add a pip requirements file to distribute dependencies for all tasks
+        on this SparkContext in the future. An ImportError will be thrown if
+        a module in the file can't be downloaded.
+        See https://pip.pypa.io/en/latest/user_guide.html#requirements-files
+        """
+        import importlib
+        import pip
+        import tarfile
+        import tempfile
+        import uuid
+        # os and shutil are module-level imports in context.py.
+        tar_dir = tempfile.mkdtemp()
+        try:
+            for req in pip.req.parse_requirements(path, session=uuid.uuid1()):
+                # Install the requirement locally if it is not already present.
+                if not req.check_if_exists():
+                    pip.main(['install', str(req.req)])
+                # Importing the module raises ImportError if the download failed.
+                mod = importlib.import_module(req.name)
+                mod_path = mod.__path__[0]
+                # Package the installed module and ship it to the executors.
+                tar_path = os.path.join(tar_dir, req.name + '.tar.gz')
+                tar = tarfile.open(tar_path, "w:gz")
+                tar.add(mod_path, arcname=os.path.basename(mod_path))
+                tar.close()
+                self.addPyFile(tar_path)
+        finally:
+            shutil.rmtree(tar_dir)
+
     def setCheckpointDir(self, dirName):
         """
         Set the directory under which RDDs are going to be checkpointed. The
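For completeness, a usage sketch of the new method on an existing SparkContext, assuming this patch is applied; the file name "requirements.txt" and the simplejson package listed in it are made-up examples, not part of the patch:

    # requirements.txt (hypothetical contents):
    #     simplejson==3.6.5

    from pyspark import SparkContext

    sc = SparkContext("local[2]", "requirements-file-demo")
    # Installs any missing packages locally with pip, tars each installed module,
    # and distributes the archives to executors via addPyFile.
    sc.addRequirementsFile("requirements.txt")

    def parse(line):
        import simplejson  # importable inside tasks after addRequirementsFile
        return simplejson.loads(line)

    print(sc.parallelize(['{"a": 1}', '{"b": 2}']).map(parse).collect())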