[SPARK-5929][PYSPARK] Context addPyPackage and addPyRequirements #12398

Closed · wants to merge 24 commits
14 changes: 12 additions & 2 deletions core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -356,16 +356,19 @@ object SparkSubmit {
args.childArgs = ArrayBuffer(args.primaryResource, args.pyFiles) ++ args.childArgs
if (clusterManager != YARN) {
// The YARN backend distributes the primary file differently, so don't merge it.
args.files = mergeFileLists(args.files, args.primaryResource)
args.files = mergeFileLists(args.files, args.primaryResource, args.pyRequirements)
}
}
if (clusterManager != YARN) {
// The YARN backend handles python files differently, so don't merge the lists.
args.files = mergeFileLists(args.files, args.pyFiles)
args.files = mergeFileLists(args.files, args.pyFiles, args.pyRequirements)
}
if (args.pyFiles != null) {
sysProps("spark.submit.pyFiles") = args.pyFiles
}
if (args.pyRequirements != null) {
sysProps("spark.submit.pyRequirements") = args.pyRequirements
}
}

// In YARN mode for an R app, add the SparkR package archive and the R package
@@ -542,6 +545,10 @@ object SparkSubmit {
if (args.pyFiles != null) {
sysProps("spark.submit.pyFiles") = args.pyFiles
}

if (args.pyRequirements != null) {
sysProps("spark.submit.pyRequirements") = args.pyRequirements
}
}

// assure a keytab is available from any place in a JVM
@@ -593,6 +600,9 @@ object SparkSubmit {
if (args.pyFiles != null) {
sysProps("spark.submit.pyFiles") = args.pyFiles
}
if (args.pyRequirements != null) {
sysProps("spark.submit.pyRequirements") = args.pyRequirements
}
} else {
childArgs += (args.primaryResource, args.mainClass)
}

@@ -64,6 +64,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
var verbose: Boolean = false
var isPython: Boolean = false
var pyFiles: String = null
var pyRequirements: String = null
var isR: Boolean = false
var action: SparkSubmitAction = null
val sparkProperties: HashMap[String, String] = new HashMap[String, String]()
@@ -304,6 +305,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
| numExecutors $numExecutors
| files $files
| pyFiles $pyFiles
| pyRequirements $pyRequirements
| archives $archives
| mainClass $mainClass
| primaryResource $primaryResource
@@ -395,6 +397,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
case PY_FILES =>
pyFiles = Utils.resolveURIs(value)

case PY_REQUIREMENTS =>
pyRequirements = Utils.resolveURIs(value)

case ARCHIVES =>
archives = Utils.resolveURIs(value)

@@ -505,6 +510,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
| search for the maven coordinates given with --packages.
| --py-files PY_FILES Comma-separated list of .zip, .egg, or .py files to place
| on the PYTHONPATH for Python apps.
| --py-requirements REQS Pip requirements file with dependencies that will be fetched
| and placed on the PYTHONPATH.
| --files FILES Comma-separated list of files to be placed in the working
| directory of each executor.
|
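For illustration, reusing the file and application names from the tests in this PR, a requirements file such as

    requests
    six==1.8.0

saved as requirements.txt could then be shipped with a Python application via

    ./bin/spark-submit --master yarn --deploy-mode cluster \
        --py-requirements requirements.txt mister.py

and its contents would be fetched with pip and placed on the PYTHONPATH, as the help text above describes.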
31 changes: 31 additions & 0 deletions core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -570,6 +570,37 @@ class SparkSubmitSuite
appArgs.executorMemory should be ("2.3g")
}
}

test("py-requirements will be distributed") {
val pyReqs = "requirements.txt"

val clArgsYarn = Seq(
"--master", "yarn",
"--deploy-mode", "cluster",
"--py-requirements", pyReqs,
"mister.py"
)

val appArgsYarn = new SparkSubmitArguments(clArgsYarn)
val sysPropsYarn = SparkSubmit.prepareSubmitEnvironment(appArgsYarn)._3
appArgsYarn.pyRequirements should be (Utils.resolveURIs(pyReqs))
sysPropsYarn("spark.yarn.dist.files") should be (
PythonRunner.formatPaths(Utils.resolveURIs(pyReqs)).mkString(","))
sysPropsYarn("spark.submit.pyRequirements") should be (
PythonRunner.formatPaths(Utils.resolveURIs(pyReqs)).mkString(","))

val clArgs = Seq(
"--master", "local",
"--py-requirements", pyReqs,
"mister.py"
)

val appArgs = new SparkSubmitArguments(clArgs)
val sysProps = SparkSubmit.prepareSubmitEnvironment(appArgs)._3
appArgs.pyRequirements should be (Utils.resolveURIs(pyReqs))
sysProps("spark.submit.pyRequirements") should be (
PythonRunner.formatPaths(Utils.resolveURIs(pyReqs)).mkString(","))
}
// scalastyle:on println

// NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly.

@@ -55,6 +55,7 @@ class SparkSubmitOptionParser {
protected final String PROPERTIES_FILE = "--properties-file";
protected final String PROXY_USER = "--proxy-user";
protected final String PY_FILES = "--py-files";
protected final String PY_REQUIREMENTS = "--py-requirements";
protected final String REPOSITORIES = "--repositories";
protected final String STATUS = "--status";
protected final String TOTAL_EXECUTOR_CORES = "--total-executor-cores";
73 changes: 67 additions & 6 deletions python/pyspark/context.py
@@ -21,6 +21,9 @@
import shutil
import signal
import sys
import tarfile
import tempfile
import uuid
import threading
from threading import RLock
from tempfile import NamedTemporaryFile
@@ -72,8 +75,8 @@ class SparkContext(object):
PACKAGE_EXTENSIONS = ('.zip', '.egg', '.jar')

def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
environment=None, batchSize=0, serializer=PickleSerializer(), conf=None,
gateway=None, jsc=None, profiler_cls=BasicProfiler):
environment=None, batchSize=0, serializer=PickleSerializer(),
conf=None, gateway=None, jsc=None, profiler_cls=BasicProfiler):
"""
Create a new SparkContext. At least the master and app name should be set,
either through the named parameters here or through C{conf}.
@@ -111,15 +114,15 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
self._callsite = first_spark_call() or CallSite(None, None, None)
SparkContext._ensure_initialized(self, gateway=gateway)
try:
self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
conf, jsc, profiler_cls)
self._do_init(master, appName, sparkHome, pyFiles, environment,
batchSize, serializer, conf, jsc, profiler_cls)
except:
# If an error occurs, clean up in order to allow future SparkContext creation:
self.stop()
raise

def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
conf, jsc, profiler_cls):
def _do_init(self, master, appName, sparkHome, pyFiles, environment,
batchSize, serializer, conf, jsc, profiler_cls):
self.environment = environment or {}
self._conf = conf or SparkConf(_jvm=self._jvm)
self._batchSize = batchSize # -1 represents an unlimited batch size
@@ -206,6 +209,14 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
self._python_includes.append(filename)
sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename))

# Apply requirements file set by spark-submit.
for path in self._conf.get("spark.submit.pyRequirements", "").split(","):
if path != "":
(dirname, filename) = os.path.split(path)
reqs_file = os.path.join(SparkFiles.getRootDirectory(), filename)
reqs = open(reqs_file).readlines()
self.addPyRequirements(reqs)

# Create a temporary directory inside spark.local.dir:
local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir(self._jsc.sc().conf())
self._temp_dir = \
@@ -814,6 +825,56 @@ def addPyFile(self, path):
import importlib
importlib.invalidate_caches()

def addPyPackage(self, pkg):
"""
Add a package to the SparkContext. The package must already have been
imported by the driver via normal __import__ semantics. Namespace
packages are supported by archiving every entry of the package's
__path__ into a single distributed package. Example follows:

import pyspark
import foolib

sc = pyspark.SparkContext()
sc.addPyPackage(foolib)
# foolib is now on the workers' PYTHONPATH
rdd = sc.parallelize([1, 2, 3])
doubles = rdd.map(lambda x: foolib.double(x))
"""
Contributor (review comment): Could you add an example here?

tmp_dir = tempfile.mkdtemp()
try:
tar_path = os.path.join(tmp_dir, pkg.__name__+'.tar.gz')
tar = tarfile.open(tar_path, "w:gz")
for mod in pkg.__path__[::-1]:
# adds in reverse to simulate namespace loading path
tar.add(mod, arcname=os.path.basename(mod))
tar.close()
self.addPyFile(tar_path)
finally:
shutil.rmtree(tmp_dir)

def addPyRequirements(self, reqs):
"""
Add a list of pip requirements to distribute to workers.
The reqs list is composed of pip requirements strings.
See https://pip.pypa.io/en/latest/user_guide.html#requirements-files
Raises ImportError if the requirement can't be found. Example follows:

reqs = ['pkg1', 'pkg2', 'pkg3>=1.0,<=2.0']
sc.addPyRequirements(reqs)
# or load from a requirements file
sc.addPyRequirements(open('requirements.txt').readlines())
"""
import pip
with tempfile.NamedTemporaryFile() as t:
t.write('\n'.join(reqs))
t.flush()
for req in pip.req.parse_requirements(t.name, session=uuid.uuid1()):
if not req.check_if_exists():
pip.main(['install', req.req.__str__()])
pkg = __import__(req.name)
self.addPyPackage(pkg)

def setCheckpointDir(self, dirName):
"""
Set the directory under which RDDs are going to be checkpointed. The
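As a side note for readers of the diff, the packaging step inside addPyPackage can be restated as a small standalone sketch; the helper name package_to_tarball and the use of the stdlib email package below are illustrative only, not part of this patch.

import os
import tarfile
import tempfile

def package_to_tarball(pkg, out_dir):
    # Mirrors addPyPackage's packaging step: archive every entry of
    # pkg.__path__ (in reverse, to keep namespace-package precedence)
    # into <package name>.tar.gz and return the archive's path.
    tar_path = os.path.join(out_dir, pkg.__name__ + '.tar.gz')
    with tarfile.open(tar_path, 'w:gz') as tar:
        for mod in pkg.__path__[::-1]:
            tar.add(mod, arcname=os.path.basename(mod))
    return tar_path

# Standalone usage, no SparkContext required: archive an importable package.
import email
print(package_to_tarball(email, tempfile.mkdtemp()))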
27 changes: 26 additions & 1 deletion python/pyspark/tests.py
@@ -23,6 +23,7 @@
from array import array
from glob import glob
import os
import os.path
import re
import shutil
import subprocess
@@ -57,7 +58,6 @@
else:
from StringIO import StringIO


from pyspark.conf import SparkConf
from pyspark.context import SparkContext
from pyspark.rdd import RDD
@@ -1947,6 +1947,31 @@ def test_with_stop(self):
sc.stop()
self.assertEqual(SparkContext._active_spark_context, None)

def test_add_py_package(self):
name = "test_tmp"
try:
os.mkdir(name)
with open(os.path.join(name, "__init__.py"), 'w+') as temp:
temp.write("triple = lambda x: 3*x")
pkg = __import__(name)
with SparkContext() as sc:
sc.addPyPackage(pkg)
trips = sc.parallelize([0, 1, 2, 3]).map(lambda x: pkg.triple(x))
self.assertSequenceEqual([0, 3, 6, 9], trips.collect())
finally:
shutil.rmtree(name)

def test_add_py_requirements(self):
import pip
reqs = ['requests', 'quadkey>=0.0.5', 'six==1.8.0']
with SparkContext() as sc:
sc.addPyRequirements(reqs)
import quadkey
qks = sc.parallelize([(0, 0), (1, 1), (2, 2)]) \
.map(lambda pair: quadkey.from_geo(pair, 1).key)
self.assertSequenceEqual(['3', '1', '1'], qks.collect())

def test_progress_api(self):
with SparkContext() as sc:
sc.setJobGroup('test_progress_api', '', True)
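Putting the two new driver APIs together, application-side usage would look roughly like the sketch below; foolib is the hypothetical package from the addPyPackage docstring, and the requirements list mirrors test_add_py_requirements.

import pyspark
import foolib  # hypothetical package, already importable on the driver

sc = pyspark.SparkContext()

# Ship an already-imported package to the executors.
sc.addPyPackage(foolib)

# Resolve pip requirements on the driver and distribute them as well.
sc.addPyRequirements(['requests', 'quadkey>=0.0.5', 'six==1.8.0'])

# Both foolib and the pip-installed packages are now importable in tasks.
doubles = sc.parallelize([1, 2, 3]).map(lambda x: foolib.double(x))
print(doubles.collect())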