[SPARK-5929][PYSPARK] Context addPyPackage and addPyRequirements #12398
python/pyspark/context.py

@@ -21,6 +21,9 @@
 import shutil
 import signal
 import sys
+import tarfile
+import tempfile
+import uuid
 import threading
 from threading import RLock
 from tempfile import NamedTemporaryFile
@@ -72,8 +75,8 @@ class SparkContext(object):

     PACKAGE_EXTENSIONS = ('.zip', '.egg', '.jar')

     def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
-                 environment=None, batchSize=0, serializer=PickleSerializer(), conf=None,
-                 gateway=None, jsc=None, profiler_cls=BasicProfiler):
+                 environment=None, batchSize=0, serializer=PickleSerializer(),
+                 conf=None, gateway=None, jsc=None, profiler_cls=BasicProfiler):
         """
         Create a new SparkContext. At least the master and app name should be set,
         either through the named parameters here or through C{conf}.
@@ -111,15 +114,15 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
         self._callsite = first_spark_call() or CallSite(None, None, None)
         SparkContext._ensure_initialized(self, gateway=gateway)
         try:
-            self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
-                          conf, jsc, profiler_cls)
+            self._do_init(master, appName, sparkHome, pyFiles, environment,
+                          batchSize, serializer, conf, jsc, profiler_cls)
         except:
             # If an error occurs, clean up in order to allow future SparkContext creation:
             self.stop()
             raise

-    def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
-                 conf, jsc, profiler_cls):
+    def _do_init(self, master, appName, sparkHome, pyFiles, environment,
+                 batchSize, serializer, conf, jsc, profiler_cls):
         self.environment = environment or {}
         self._conf = conf or SparkConf(_jvm=self._jvm)
         self._batchSize = batchSize  # -1 represents an unlimited batch size
@@ -814,6 +817,40 @@ def addPyFile(self, path):
         import importlib
         importlib.invalidate_caches()

+    def addPyPackage(self, pkg):
+        """
+        Add a package to the SparkContext. The package must already have been
+        imported by the driver via __import__ semantics. Namespace packages
+        are supported by archiving every entry in the package's __path__ list,
+        in reverse, so that a single archive reproduces the namespace loading
+        order.
+        """
+        tmp_dir = tempfile.mkdtemp()
+        try:
+            tar_path = os.path.join(tmp_dir, pkg.__name__ + '.tar.gz')
+            tar = tarfile.open(tar_path, "w:gz")
+            for mod in pkg.__path__[::-1]:
+                # adds in reverse to simulate namespace loading path
+                tar.add(mod, arcname=os.path.basename(mod))
+            tar.close()
+            self.addPyFile(tar_path)
+        finally:
+            shutil.rmtree(tmp_dir)
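For orientation, a minimal usage sketch of the proposed API; the package name mypkg and its function triple are hypothetical, and the package must already be importable on the driver:

    import mypkg                       # already imported on the driver
    from pyspark import SparkContext

    sc = SparkContext(appName="addPyPackage-demo")
    sc.addPyPackage(mypkg)             # ships mypkg's __path__ entries as a .tar.gz
    rdd = sc.parallelize([1, 2, 3]).map(lambda x: mypkg.triple(x))
    print(rdd.collect())               # tasks import mypkg from the shipped archive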
+    def addRequirementsFile(self, path):

Review comment: Would it be better to pass the requirements as a string? Then you could easily keep the requirements together with the Python source code. Even if you have a txt file, it's as easy as:
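Presumably something along these lines, where sc.addRequirements is the reviewer's hypothetical string-based variant of the API:

    # hypothetical string-based API suggested by the reviewer
    sc.addRequirements("simplejson\nsix==1.8.0")

    # a requirements.txt file then reduces to:
    with open("requirements.txt") as f:
        sc.addRequirements(f.read())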
+        """
+        Add a pip requirements file to distribute dependencies for all tasks
+        on this SparkContext in the future. Raises an ImportError if a module
+        in the file can't be downloaded.
+        See https://pip.pypa.io/en/latest/user_guide.html#requirements-files
+        """
+        import pip
+        for req in pip.req.parse_requirements(path, session=uuid.uuid1()):
+            if not req.check_if_exists():
+                pip.main(['install', req.req.__str__()])
Review comment: So it seems that this can sometimes require elevated privileges, based on the issues with the previous Jenkins run. What about if at startup we created a fixed temp directory per context and added it to our path?
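A sketch of that suggestion, assuming pip's standard --target option is used to install into a per-context directory prepended to sys.path (names are illustrative):

    import sys
    import tempfile
    import pip

    # at context startup: one writable install root per SparkContext
    pkg_dir = tempfile.mkdtemp(prefix='spark-pip-')
    sys.path.insert(0, pkg_dir)

    # later: install into the per-context directory instead of
    # site-packages, so no elevated privileges are needed
    pip.main(['install', '--target', pkg_dir, str(req.req)])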
+            pkg = __import__(req.name)
+            self.addPyPackage(pkg)

     def setCheckpointDir(self, dirName):
         """
         Set the directory under which RDDs are going to be checkpointed. The
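A usage sketch for the method as proposed; the requirements file contents here mirror the test added below:

    # requirements.txt contains, e.g.:
    #     simplejson
    #     six==1.8.0

    sc.addRequirementsFile("requirements.txt")  # installs anything missing, ships it to executors
    import simplejson                           # now importable in driver code and in tasks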
python/pyspark/tests.py
@@ -23,6 +23,7 @@
 from array import array
 from glob import glob
 import os
+import os.path
 import re
 import shutil
 import subprocess
@@ -57,7 +58,6 @@
 else:
     from StringIO import StringIO

-
 from pyspark.conf import SparkConf
 from pyspark.context import SparkContext
 from pyspark.rdd import RDD
@@ -1947,6 +1947,33 @@ def test_with_stop(self):
         sc.stop()
         self.assertEqual(SparkContext._active_spark_context, None)

+    def test_add_py_package(self):
+        name = "test_tmp"
+        try:
+            os.mkdir(name)
+            with open(os.path.join(name, "__init__.py"), 'w+') as temp:
+                temp.write("triple = lambda x: 3*x")
+            pkg = __import__(name)
+            with SparkContext() as sc:
+                # trips = sc.parallelize([0, 1, 2, 3]).map(test_tmp.triple)
+                # sc.addPyPackage(pkg)
+                trips = sc.parallelize([0, 1, 2, 3]).map(lambda x: pkg.triple(x))
+                self.assertSequenceEqual([0, 3, 6, 9], trips.collect())
+        finally:
+            shutil.rmtree(name)
+
+

Review comment: Remove the extra empty line.
+    def test_requirements_file(self):
+        import pip
+        with tempfile.NamedTemporaryFile() as temp:
+            temp.write('simplejson\nquadkey>=0.0.5\nsix==1.8.0')
+            with SparkContext() as sc:
+                sc.addRequirementsFile(temp.name)
+                import quadkey
+                qks = sc.parallelize([(0, 0), (1, 1), (2, 2)]) \
+                    .map(lambda pair: quadkey.from_geo(pair, 1).key)
+                self.assertSequenceEqual(['3', '1', '1'], qks.collect())

     def test_progress_api(self):
         with SparkContext() as sc:
             sc.setJobGroup('test_progress_api', '', True)
Review comment: Could you add an example here?
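Assuming the request is for a usage example of the job-group/progress API exercised by this test, a minimal sketch built on PySpark's statusTracker (job names and counts are illustrative):

    with SparkContext() as sc:
        sc.setJobGroup('demo_group', 'description of the demo job group', True)
        sc.parallelize(range(100)).count()
        tracker = sc.statusTracker()
        for job_id in tracker.getJobIdsForGroup('demo_group'):
            info = tracker.getJobInfo(job_id)
            print(job_id, info.status)   # e.g. SUCCEEDED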