[SPARK-5929][PYSPARK] Context addPyPackage and addPyRequirements #12398

Closed · wants to merge 24 commits
Changes from 19 commits
49 changes: 43 additions & 6 deletions python/pyspark/context.py
@@ -21,6 +21,9 @@
import shutil
import signal
import sys
import tarfile
import tempfile
import uuid
import threading
from threading import RLock
from tempfile import NamedTemporaryFile
@@ -72,8 +75,8 @@ class SparkContext(object):
PACKAGE_EXTENSIONS = ('.zip', '.egg', '.jar')

def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
environment=None, batchSize=0, serializer=PickleSerializer(), conf=None,
gateway=None, jsc=None, profiler_cls=BasicProfiler):
environment=None, batchSize=0, serializer=PickleSerializer(),
conf=None, gateway=None, jsc=None, profiler_cls=BasicProfiler):
"""
Create a new SparkContext. At least the master and app name should be set,
either through the named parameters here or through C{conf}.
@@ -111,15 +114,15 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
self._callsite = first_spark_call() or CallSite(None, None, None)
SparkContext._ensure_initialized(self, gateway=gateway)
try:
self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
conf, jsc, profiler_cls)
self._do_init(master, appName, sparkHome, pyFiles, environment,
batchSize, serializer, conf, jsc, profiler_cls)
except:
# If an error occurs, clean up in order to allow future SparkContext creation:
self.stop()
raise

def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
conf, jsc, profiler_cls):
def _do_init(self, master, appName, sparkHome, pyFiles, environment,
batchSize, serializer, conf, jsc, profiler_cls):
self.environment = environment or {}
self._conf = conf or SparkConf(_jvm=self._jvm)
self._batchSize = batchSize # -1 represents an unlimited batch size
@@ -814,6 +817,40 @@ def addPyFile(self, path):
import importlib
importlib.invalidate_caches()

def addPyPackage(self, pkg):
"""
Add a package to the SparkContext. The package must already have been
imported by the driver via __import__ semantics. Namespace packages are
supported by bundling every directory in the package's __path__ list
into a single archive, simulating the namespace loading path.
"""
Review comment (Contributor):

Could you add an example here?

tmp_dir = tempfile.mkdtemp()
try:
tar_path = os.path.join(tmp_dir, pkg.__name__+'.tar.gz')
tar = tarfile.open(tar_path, "w:gz")
for mod in pkg.__path__[::-1]:
# adds in reverse to simulate namespace loading path
tar.add(mod, arcname=os.path.basename(mod))
tar.close()
self.addPyFile(tar_path)
finally:
shutil.rmtree(tmp_dir)
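
A minimal usage sketch of addPyPackage, mirroring the pattern in the new test below; the mypackage name and its triple function are hypothetical, not part of this diff:

# The package must already be imported on the driver; the module object
# itself is passed to addPyPackage so its __path__ can be archived and shipped.
import mypackage  # hypothetical package available on the driver

sc.addPyPackage(mypackage)
trips = sc.parallelize([0, 1, 2, 3]).map(lambda x: mypackage.triple(x))
print(trips.collect())  # [0, 3, 6, 9] if triple(x) == 3 * x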

def addRequirementsFile(self, path):
Review comment (@davies, Contributor, May 4, 2016):

Would it be better to pass the requirements as a string? Then you could easily keep the requirements together with the Python source code.

Even if you have a txt file, it's as easy as:

sc.addRequirements(open(path).read())
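
A minimal sketch of the string-based variant suggested in this comment; the addRequirements name comes from the comment itself, and delegating to addRequirementsFile via a temporary file is only one possible implementation, not part of this diff:

import os
import tempfile

def addRequirements(self, requirements):
    """Sketch: accept the requirements as a string instead of a file path."""
    fd, tmp_path = tempfile.mkstemp(suffix='-requirements.txt')
    try:
        with os.fdopen(fd, 'w') as f:
            f.write(requirements)
        # Reuse the existing file-based code path below.
        self.addRequirementsFile(tmp_path)
    finally:
        os.remove(tmp_path)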

"""
Add a pip requirements file to distribute dependencies for all future
tasks on this SparkContext. An ImportError is raised if a module in the
file cannot be downloaded.
See https://pip.pypa.io/en/latest/user_guide.html#requirements-files
Raises ImportError if a requirement can't be found.
"""
import pip
for req in pip.req.parse_requirements(path, session=uuid.uuid1()):
if not req.check_if_exists():
pip.main(['install', req.req.__str__()])
Review comment (Contributor):

It seems this can sometimes require elevated privileges, judging by the issues with the previous Jenkins run. What if, at startup, we created a fixed temp directory per context and added it to our path with sys.path.insert(0, self.pipBase), and at install time did something along the lines of pip.main(['install', req.req.__str__(), '--target', self.pipBase]), so that we don't need write permissions to the default pip target? (A sketch of this idea follows the method below.)

pkg = __import__(req.name)
self.addPyPackage(pkg)
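
A minimal sketch of the per-context install directory proposed in the review comment above, written as it might look inside SparkContext; the _ensure_pip_base helper and _pip_base attribute are assumptions, not part of this diff, and the pip calls mirror the pre-pip-10 API already used in this PR:

import sys
import tempfile
import uuid

import pip

def _ensure_pip_base(self):
    # Assumed helper: create a per-context install directory once and put it
    # on sys.path so freshly installed requirements are importable.
    if not hasattr(self, '_pip_base'):
        self._pip_base = tempfile.mkdtemp(prefix='spark-pip-')
        sys.path.insert(0, self._pip_base)
    return self._pip_base

def addRequirementsFile(self, path):
    pip_base = self._ensure_pip_base()
    for req in pip.req.parse_requirements(path, session=uuid.uuid1()):
        if not req.check_if_exists():
            # Install into the per-context directory rather than the default
            # site-packages, so no write access to system site-packages is needed.
            pip.main(['install', str(req.req), '--target', pip_base])
        pkg = __import__(req.name)
        self.addPyPackage(pkg)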

def setCheckpointDir(self, dirName):
"""
Set the directory under which RDDs are going to be checkpointed. The
29 changes: 28 additions & 1 deletion python/pyspark/tests.py
@@ -23,6 +23,7 @@
from array import array
from glob import glob
import os
import os.path
import re
import shutil
import subprocess
@@ -57,7 +58,6 @@
else:
from StringIO import StringIO


from pyspark.conf import SparkConf
from pyspark.context import SparkContext
from pyspark.rdd import RDD
@@ -1947,6 +1947,33 @@ def test_with_stop(self):
sc.stop()
self.assertEqual(SparkContext._active_spark_context, None)

def test_add_py_package(self):
name = "test_tmp"
try:
os.mkdir(name)
with open(os.path.join(name, "__init__.py"), 'w+') as temp:
temp.write("triple = lambda x: 3*x")
pkg = __import__(name)
with SparkContext() as sc:
#trips = sc.parallelize([0, 1, 2, 3]).map(test_tmp.triple)
#sc.addPyPackage(pkg)
trips = sc.parallelize([0, 1, 2, 3]).map(lambda x: pkg.triple(x))
self.assertSequenceEqual([0, 3, 6, 9], trips.collect())
finally:
shutil.rmtree(name)

Review comment (Contributor):

Remove the extra empty line.


def test_requirements_file(self):
import pip
with tempfile.NamedTemporaryFile() as temp:
temp.write('simplejson\nquadkey>=0.0.5\nsix==1.8.0')
with SparkContext() as sc:
sc.addRequirementsFile(temp.name)
import quadkey
qks = sc.parallelize([(0, 0), (1, 1), (2, 2)]) \
.map(lambda pair: quadkey.from_geo(pair, 1).key)
self.assertSequenceEqual(['3', '1', '1'], qks.collect())

def test_progress_api(self):
with SparkContext() as sc:
sc.setJobGroup('test_progress_api', '', True)