diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index 926e1ff7a874d..52d2ef6e700e1 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -356,16 +356,19 @@ object SparkSubmit {
         args.childArgs = ArrayBuffer(args.primaryResource, args.pyFiles) ++ args.childArgs
         if (clusterManager != YARN) {
           // The YARN backend distributes the primary file differently, so don't merge it.
-          args.files = mergeFileLists(args.files, args.primaryResource)
+          args.files = mergeFileLists(args.files, args.primaryResource, args.pyRequirements)
         }
       }
       if (clusterManager != YARN) {
         // The YARN backend handles python files differently, so don't merge the lists.
-        args.files = mergeFileLists(args.files, args.pyFiles)
+        args.files = mergeFileLists(args.files, args.pyFiles, args.pyRequirements)
       }
       if (args.pyFiles != null) {
         sysProps("spark.submit.pyFiles") = args.pyFiles
       }
+      if (args.pyRequirements != null) {
+        sysProps("spark.submit.pyRequirements") = args.pyRequirements
+      }
     }
 
     // In YARN mode for an R app, add the SparkR package archive and the R package
@@ -542,6 +545,10 @@ object SparkSubmit {
       if (args.pyFiles != null) {
         sysProps("spark.submit.pyFiles") = args.pyFiles
       }
+
+      if (args.pyRequirements != null) {
+        sysProps("spark.submit.pyRequirements") = args.pyRequirements
+      }
     }
 
     // assure a keytab is available from any place in a JVM
@@ -593,6 +600,9 @@ object SparkSubmit {
         if (args.pyFiles != null) {
           sysProps("spark.submit.pyFiles") = args.pyFiles
         }
+        if (args.pyRequirements != null) {
+          sysProps("spark.submit.pyRequirements") = args.pyRequirements
+        }
       } else {
         childArgs += (args.primaryResource, args.mainClass)
       }
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
index ec6d48485f110..3136a729d5933 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -64,6 +64,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
   var verbose: Boolean = false
   var isPython: Boolean = false
   var pyFiles: String = null
+  var pyRequirements: String = null
   var isR: Boolean = false
   var action: SparkSubmitAction = null
   val sparkProperties: HashMap[String, String] = new HashMap[String, String]()
@@ -304,6 +305,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
     |  numExecutors            $numExecutors
     |  files                   $files
     |  pyFiles                 $pyFiles
+    |  pyRequirements          $pyRequirements
     |  archives                $archives
     |  mainClass               $mainClass
     |  primaryResource         $primaryResource
@@ -395,6 +397,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
       case PY_FILES =>
         pyFiles = Utils.resolveURIs(value)
 
+      case PY_REQUIREMENTS =>
+        pyRequirements = Utils.resolveURIs(value)
+
       case ARCHIVES =>
         archives = Utils.resolveURIs(value)
 
@@ -505,6 +510,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
        |                              search for the maven coordinates given with --packages.
        |  --py-files PY_FILES         Comma-separated list of .zip, .egg, or .py files to place
        |                              on the PYTHONPATH for Python apps.
+        |  --py-requirements REQS      Pip requirements file with dependencies that will be fetched
+        |                              and placed on the PYTHONPATH.
        |  --files FILES               Comma-separated list of files to be placed in the working
        |                              directory of each executor.
        |
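One quick way to see the plumbing above end to end: launch an application with the new flag, for example "spark-submit --py-requirements requirements.txt app.py", and the resolved requirements path should surface on the driver as the spark.submit.pyRequirements property. A minimal sketch (not part of the patch), assuming a local PySpark install; the app name is illustrative and sc._conf is the internal SparkConf the driver receives:

    # check_reqs_conf.py -- submit this with --py-requirements to inspect the property
    from pyspark import SparkContext

    sc = SparkContext(appName="reqs-conf-check")
    # Comma-separated, URI-resolved path(s) that SparkSubmit sets above.
    print(sc._conf.get("spark.submit.pyRequirements", "<not set>"))
    sc.stop()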
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
index 271897699201b..8fac39a5cb267 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -570,6 +570,37 @@ class SparkSubmitSuite
       appArgs.executorMemory should be ("2.3g")
     }
   }
+
+  test("py-requirements will be distributed") {
+    val pyReqs = "requirements.txt"
+
+    val clArgsYarn = Seq(
+      "--master", "yarn",
+      "--deploy-mode", "cluster",
+      "--py-requirements", pyReqs,
+      "mister.py"
+    )
+
+    val appArgsYarn = new SparkSubmitArguments(clArgsYarn)
+    val sysPropsYarn = SparkSubmit.prepareSubmitEnvironment(appArgsYarn)._3
+    appArgsYarn.pyRequirements should be (Utils.resolveURIs(pyReqs))
+    sysPropsYarn("spark.yarn.dist.files") should be (
+      PythonRunner.formatPaths(Utils.resolveURIs(pyReqs)).mkString(","))
+    sysPropsYarn("spark.submit.pyRequirements") should be (
+      PythonRunner.formatPaths(Utils.resolveURIs(pyReqs)).mkString(","))
+
+    val clArgs = Seq(
+      "--master", "local",
+      "--py-requirements", pyReqs,
+      "mister.py"
+    )
+
+    val appArgs = new SparkSubmitArguments(clArgs)
+    val sysProps = SparkSubmit.prepareSubmitEnvironment(appArgs)._3
+    appArgs.pyRequirements should be (Utils.resolveURIs(pyReqs))
+    sysProps("spark.submit.pyRequirements") should be (
+      PythonRunner.formatPaths(Utils.resolveURIs(pyReqs)).mkString(","))
+  }
   // scalastyle:on println
 
   // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly.
diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java
index 6767cc5079649..d036ac322809c 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java
@@ -55,6 +55,7 @@ class SparkSubmitOptionParser {
   protected final String PROPERTIES_FILE = "--properties-file";
   protected final String PROXY_USER = "--proxy-user";
   protected final String PY_FILES = "--py-files";
+  protected final String PY_REQUIREMENTS = "--py-requirements";
   protected final String REPOSITORIES = "--repositories";
   protected final String STATUS = "--status";
   protected final String TOTAL_EXECUTOR_CORES = "--total-executor-cores";
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index cb15b4b91f913..acd6ca1d5e7a0 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -21,6 +21,9 @@
 import shutil
 import signal
 import sys
+import tarfile
+import tempfile
+import uuid
 import threading
 from threading import RLock
 from tempfile import NamedTemporaryFile
@@ -72,8 +75,8 @@ class SparkContext(object):
     PACKAGE_EXTENSIONS = ('.zip', '.egg', '.jar')
 
     def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
-                 environment=None, batchSize=0, serializer=PickleSerializer(), conf=None,
-                 gateway=None, jsc=None, profiler_cls=BasicProfiler):
+                 environment=None, batchSize=0, serializer=PickleSerializer(),
+                 conf=None, gateway=None, jsc=None, profiler_cls=BasicProfiler):
         """
         Create a new SparkContext. At least the master and app name should be set,
         either through the named parameters here or through C{conf}.
@@ -111,15 +114,15 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
         self._callsite = first_spark_call() or CallSite(None, None, None)
         SparkContext._ensure_initialized(self, gateway=gateway)
         try:
-            self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
-                          conf, jsc, profiler_cls)
+            self._do_init(master, appName, sparkHome, pyFiles, environment,
+                          batchSize, serializer, conf, jsc, profiler_cls)
         except:
             # If an error occurs, clean up in order to allow future SparkContext creation:
             self.stop()
             raise
 
-    def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
-                 conf, jsc, profiler_cls):
+    def _do_init(self, master, appName, sparkHome, pyFiles, environment,
+                 batchSize, serializer, conf, jsc, profiler_cls):
         self.environment = environment or {}
         self._conf = conf or SparkConf(_jvm=self._jvm)
         self._batchSize = batchSize  # -1 represents an unlimited batch size
@@ -206,6 +209,14 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
                     self._python_includes.append(filename)
                     sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename))
 
+        # Apply requirements file set by spark-submit.
+        for path in self._conf.get("spark.submit.pyRequirements", "").split(","):
+            if path != "":
+                (dirname, filename) = os.path.split(path)
+                reqs_file = os.path.join(SparkFiles.getRootDirectory(), filename)
+                reqs = open(reqs_file).readlines()
+                self.addPyRequirements(reqs)
+
         # Create a temporary directory inside spark.local.dir:
         local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir(self._jsc.sc().conf())
         self._temp_dir = \
@@ -814,6 +825,56 @@ def addPyFile(self, path):
             import importlib
             importlib.invalidate_caches()
 
+    def addPyPackage(self, pkg):
+        """
+        Add a package to the SparkContext. The package must already have been
+        imported by the driver via __import__ semantics. Namespace packages are
+        supported by bundling every entry in the package's __path__ into a
+        single archive shipped via addPyFile. Example follows:
+
+            import pyspark
+            import foolib
+
+            sc = pyspark.SparkContext()
+            sc.addPyPackage(foolib)
+            # foolib is now on the workers' PYTHONPATH
+            rdd = sc.parallelize([1, 2, 3])
+            doubles = rdd.map(lambda x: foolib.double(x))
+        """
+        tmp_dir = tempfile.mkdtemp()
+        try:
+            tar_path = os.path.join(tmp_dir, pkg.__name__ + '.tar.gz')
+            tar = tarfile.open(tar_path, "w:gz")
+            for mod in pkg.__path__[::-1]:
+                # add in reverse to simulate the namespace loading path
+                tar.add(mod, arcname=os.path.basename(mod))
+            tar.close()
+            self.addPyFile(tar_path)
+        finally:
+            shutil.rmtree(tmp_dir)
+
+    def addPyRequirements(self, reqs):
+        """
+        Add a list of pip requirements to distribute to the workers.
+        Each entry in reqs is a pip requirement string; see
+        https://pip.pypa.io/en/latest/user_guide.html#requirements-files
+        Raises ImportError if a requirement cannot be found. Example follows:
+
+            reqs = ['pkg1', 'pkg2', 'pkg3>=1.0,<=2.0']
+            sc.addPyRequirements(reqs)
+            # or load from a requirements file
+            sc.addPyRequirements(open('requirements.txt').readlines())
+        """
+        import pip
+        with tempfile.NamedTemporaryFile() as t:
+            t.write('\n'.join(reqs))
+            t.flush()
+            for req in pip.req.parse_requirements(t.name, session=uuid.uuid1()):
+                if not req.check_if_exists():
+                    pip.main(['install', req.req.__str__()])
+                pkg = __import__(req.name)
+                self.addPyPackage(pkg)
+
     def setCheckpointDir(self, dirName):
         """
         Set the directory under which RDDs are going to be checkpointed.
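The two methods added above can also be called directly from a driver program. A minimal usage sketch (not part of the patch), assuming simplejson stands in for a real pip-installable dependency; any importable package would follow the same path:

    from pyspark import SparkContext

    sc = SparkContext("local[2]", "py-reqs-demo")

    # Install the requirement on the driver if it is missing, then ship it
    # to the executors; internally each resolved requirement goes through
    # addPyPackage, which archives the package and calls addPyFile.
    sc.addPyRequirements(["simplejson>=3.0"])

    import simplejson
    rdd = sc.parallelize([{"a": 1}, {"b": 2}])
    print(rdd.map(simplejson.dumps).collect())
    sc.stop()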
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 97ea39dde05fa..8c8070ba75112 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -23,6 +23,7 @@
 from array import array
 from glob import glob
 import os
+import os.path
 import re
 import shutil
 import subprocess
@@ -57,7 +58,6 @@
 else:
     from StringIO import StringIO
 
-
 from pyspark.conf import SparkConf
 from pyspark.context import SparkContext
 from pyspark.rdd import RDD
@@ -1947,6 +1947,31 @@ def test_with_stop(self):
             sc.stop()
         self.assertEqual(SparkContext._active_spark_context, None)
 
+    def test_add_py_package(self):
+        name = "test_tmp"
+        try:
+            os.mkdir(name)
+            with open(os.path.join(name, "__init__.py"), 'w+') as temp:
+                temp.write("triple = lambda x: 3*x")
+            pkg = __import__(name)
+            with SparkContext() as sc:
+                # trips = sc.parallelize([0, 1, 2, 3]).map(pkg.triple)
+                # sc.addPyPackage(pkg)
+                trips = sc.parallelize([0, 1, 2, 3]).map(lambda x: pkg.triple(x))
+                self.assertSequenceEqual([0, 3, 6, 9], trips.collect())
+        finally:
+            shutil.rmtree(name)
+
+    def test_add_py_requirements(self):
+        import pip
+        reqs = ['requests', 'quadkey>=0.0.5', 'six==1.8.0']
+        with SparkContext() as sc:
+            sc.addPyRequirements(reqs)
+            import quadkey
+            qks = sc.parallelize([(0, 0), (1, 1), (2, 2)]) \
+                .map(lambda pair: quadkey.from_geo(pair, 1).key)
+            self.assertSequenceEqual(['3', '1', '1'], qks.collect())
+
     def test_progress_api(self):
         with SparkContext() as sc:
             sc.setJobGroup('test_progress_api', '', True)
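For completeness, one way to exercise just the two new cases locally, assuming they land in ContextTests next to test_with_stop (as their placement in the hunk suggests) and that pip plus the example packages are available in the environment; this is a sketch, not part of the patch:

    import unittest
    from pyspark.tests import ContextTests

    # Build a suite containing only the new tests and run it with the
    # standard text runner.
    suite = unittest.TestSuite([
        ContextTests("test_add_py_package"),
        ContextTests("test_add_py_requirements"),
    ])
    unittest.TextTestRunner(verbosity=2).run(suite)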