7 changes: 7 additions & 0 deletions python/pyspark/taskcontext.py
@@ -53,6 +53,10 @@ def _getOrCreate(cls):
            cls._taskContext = TaskContext()
        return cls._taskContext

    @classmethod
    def _setTaskContext(cls, taskContext):
        cls._taskContext = taskContext

    @classmethod
    def get(cls):
        """
@@ -162,7 +166,10 @@ def get(cls):
        running tasks.

        .. note:: Must be called on the worker, not the driver. Returns None if not initialized.
            An Exception will be raised if it is not in a barrier stage.
        """
        if not isinstance(cls._taskContext, BarrierTaskContext):
            raise Exception('It is not in a barrier stage')
        return cls._taskContext

    @classmethod
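For reference, a minimal usage sketch of the behavior this hunk gives BarrierTaskContext.get(); this is an illustration, not part of the diff, and it assumes an already-running SparkContext named `sc`:

# Hedged sketch: `sc` is an assumed, already-created SparkContext.
from pyspark import BarrierTaskContext, TaskContext

def which(iterator):
    try:
        # After this change, get() only succeeds inside a barrier stage.
        ctx = BarrierTaskContext.get()
        yield ('barrier', ctx.partitionId())
    except Exception:
        # In a normal stage the active context is a plain TaskContext, so get() raises.
        yield ('normal', TaskContext.get().partitionId())

rdd = sc.parallelize(range(4), 2)
print(rdd.mapPartitions(which).collect())            # [('normal', 0), ('normal', 1)]
print(rdd.barrier().mapPartitions(which).collect())  # [('barrier', 0), ('barrier', 1)]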
85 changes: 85 additions & 0 deletions python/pyspark/tests/test_taskcontext.py
@@ -25,6 +25,9 @@
from pyspark import SparkConf, SparkContext, TaskContext, BarrierTaskContext
from pyspark.testing.utils import PySparkTestCase, SPARK_HOME

if sys.version_info[0] >= 3:
    xrange = range


class TaskContextTests(PySparkTestCase):

@@ -146,6 +149,49 @@ def f(iterator):
        self.assertTrue(len(taskInfos) == 4)
        self.assertTrue(len(taskInfos[0]) == 4)

    def test_context_get(self):
        """
        Verify that TaskContext.get() works both inside and outside of a barrier stage.
        """
        rdd = self.sc.parallelize(range(10), 4)

        def f(iterator):
            taskContext = TaskContext.get()
            if isinstance(taskContext, BarrierTaskContext):
                yield taskContext.partitionId() + 1
            elif isinstance(taskContext, TaskContext):
                yield taskContext.partitionId() + 2
            else:
                yield -1

        # for normal stage
        result1 = rdd.mapPartitions(f).collect()
        self.assertTrue(result1 == [2, 3, 4, 5])
        # for barrier stage
        result2 = rdd.barrier().mapPartitions(f).collect()
        self.assertTrue(result2 == [1, 2, 3, 4])

    def test_barrier_context_get(self):
        """
        Verify that BarrierTaskContext.get() only works in a barrier stage.
        """
        rdd = self.sc.parallelize(range(10), 4)

        def f(iterator):
            try:
                taskContext = BarrierTaskContext.get()
            except Exception:
                yield -1
            else:
                yield taskContext.partitionId()

        # for normal stage
        result1 = rdd.mapPartitions(f).collect()
        self.assertTrue(result1 == [-1, -1, -1, -1])
        # for barrier stage
        result2 = rdd.barrier().mapPartitions(f).collect()
        self.assertTrue(result2 == [0, 1, 2, 3])


class TaskContextTestsWithWorkerReuse(unittest.TestCase):

@@ -181,6 +227,45 @@ def context_barrier(x):
        for pid in pids:
            self.assertTrue(pid in worker_pids)

    def test_task_context_correct_with_python_worker_reuse(self):
        """Verify that the task context is correct when python workers are reused"""
        # start a normal job first to start all workers and get all worker pids
        worker_pids = self.sc.parallelize(xrange(2), 2).map(lambda x: os.getpid()).collect()
        # the workers will be reused in this barrier job
        rdd = self.sc.parallelize(xrange(10), 2)

        def context(iterator):
            tp = TaskContext.get().partitionId()
            try:
                bp = BarrierTaskContext.get().partitionId()
            except Exception:
                bp = -1

            yield (tp, bp, os.getpid())

        # normal stage after normal stage
        normal_result = rdd.mapPartitions(context).collect()
        tps, bps, pids = zip(*normal_result)
        self.assertTrue(tps == (0, 1))
        self.assertTrue(bps == (-1, -1))
        for pid in pids:
            self.assertTrue(pid in worker_pids)
        # barrier stage after normal stage
        barrier_result = rdd.barrier().mapPartitions(context).collect()
        tps, bps, pids = zip(*barrier_result)
        self.assertTrue(tps == (0, 1))
        self.assertTrue(bps == (0, 1))
        for pid in pids:
            self.assertTrue(pid in worker_pids)
        # normal stage after barrier stage
        normal_result2 = rdd.mapPartitions(context).collect()
        tps, bps, pids = zip(*normal_result2)
        self.assertTrue(tps == (0, 1))
        self.assertTrue(bps == (-1, -1))
        for pid in pids:
            self.assertTrue(pid in worker_pids)

    def tearDown(self):
        self.sc.stop()

8 changes: 8 additions & 0 deletions python/pyspark/worker.py
@@ -503,6 +503,9 @@ def main(infile, outfile):
        if isBarrier:
            taskContext = BarrierTaskContext._getOrCreate()
            BarrierTaskContext._initialize(boundPort, secret)
            # Set the task context instance here, so that it can be retrieved through
            # TaskContext.get() for both TaskContext and BarrierTaskContext.
            TaskContext._setTaskContext(taskContext)
        else:
            taskContext = TaskContext._getOrCreate()
        # read inputs for TaskContext info
@@ -596,6 +599,11 @@ def process():
            profiler.profile(process)
        else:
            process()

        # Reset the task context to None. This is a guard to avoid residual context when the
        # worker is reused.
        TaskContext._setTaskContext(None)
        BarrierTaskContext._setTaskContext(None)
HyukjinKwon (Member) commented on Oct 28, 2019:
Hm, what happens if it fails with exceptions in the middle of execution in this worker?

A Contributor commented:
Is it really needed? We always set the global TaskContext and never reset it previously.

ConeyLiu (Contributor, Author) replied on Oct 29, 2019:
> Hm, what happens if it fails with exceptions in the middle of execution in this worker?

If an exception occurs, the worker will be closed with sys.exit(-1).

> Is it really needed? We always set the global TaskContext and never reset it previously.

Previously:

val rdd = ...
val barriered = rdd.barrier().mapPartitions(...)

barriered.mapPartitions(...)  # here the BarrierTaskContext still existed.

This code is just a guard; it shouldn't add extra overhead or change behavior.

HyukjinKwon (Member) replied:
> This code is just a guard; it shouldn't add extra overhead or change behavior.

I guess that's only when the worker is reused. Can you clarify it with comments here?

ConeyLiu (Contributor, Author) replied:
thanks for reviewing. updated.
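
To make the residual-context concern discussed above concrete, here is a minimal sketch of the worker-reuse scenario; it is an illustration only, assuming a running SparkContext `sc` with the default spark.python.worker.reuse=true:

# Hedged sketch: `sc` is an assumed SparkContext; spark.python.worker.reuse defaults to true.
from pyspark import TaskContext

def context_type(iterator):
    # Reports which context class the (possibly reused) worker currently holds.
    yield type(TaskContext.get()).__name__

rdd = sc.parallelize(range(4), 2)
print(rdd.barrier().mapPartitions(context_type).collect())  # ['BarrierTaskContext', 'BarrierTaskContext']
# Without the reset guard, a reused worker could keep returning the stale
# BarrierTaskContext here; with the guard, a later normal stage sees a plain TaskContext.
print(rdd.mapPartitions(context_type).collect())            # ['TaskContext', 'TaskContext']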

    except Exception:
        try:
            exc_info = traceback.format_exc()