Skip to content

Commit

Permalink
Added
Browse files Browse the repository at this point in the history
  • Loading branch information
tdas committed May 26, 2018
1 parent 1b1528a commit ac0270a
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,13 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
dataOut.writeInt(context.partitionId())
dataOut.writeInt(context.attemptNumber())
dataOut.writeLong(context.taskAttemptId())
val localProps = context.asInstanceOf[TaskContextImpl].getLocalProperties.asScala
dataOut.writeInt(localProps.size)
localProps.foreach { case (k, v) =>
PythonRDD.writeUTF(k, dataOut)
PythonRDD.writeUTF(v, dataOut)
}

// sparkFilesDir
PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut)
// Python includes (*.zip and *.egg files)
Expand Down
7 changes: 7 additions & 0 deletions python/pyspark/taskcontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class TaskContext(object):
_partitionId = None
_stageId = None
_taskAttemptId = None
_localProperties = None

def __new__(cls):
"""Even if users construct TaskContext instead of using get, give them the singleton."""
Expand Down Expand Up @@ -88,3 +89,9 @@ def taskAttemptId(self):
TaskAttemptID.
"""
return self._taskAttemptId

def getLocalProperty(self, key):
"""
Get a local property set upstream in the driver, or None if it is missing.
"""
return self._localProperties[key]
8 changes: 8 additions & 0 deletions python/pyspark/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,14 @@ def test_tc_on_driver(self):
tc = TaskContext.get()
self.assertTrue(tc is None)

def test_get_local_property(self):
"""Verify that local properties set on the driver are available in TaskContext."""
self.sc.setLocalProperty("testkey", "testvalue")
rdd = self.sc.parallelize(range(1), 1)
prop1 = rdd1.map(lambda x: TaskContext.get().getLocalProperty("testkey")).collect()[0]
self.assertEqual(prop1, "testkey")
prop2 = rdd1.map(lambda x: TaskContext.get().getLocalProperty("otherkey")).collect()[0]
self.assertTrue(prop2 is None)

class RDDTests(ReusedPySparkTestCase):

Expand Down
6 changes: 6 additions & 0 deletions python/pyspark/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,12 @@ def main(infile, outfile):
taskContext._partitionId = read_int(infile)
taskContext._attemptNumber = read_int(infile)
taskContext._taskAttemptId = read_long(infile)
taskContext._localProperties = dict()
for i in range(read_int(infile)):
k = utf8_deserializer.loads(infile)
v = utf8_deserializer.loads(infile)
taskContext._localProperties[k] = v

shuffle.MemoryBytesSpilled = 0
shuffle.DiskBytesSpilled = 0
_accumulatorRegistry.clear()
Expand Down

0 comments on commit ac0270a

Please sign in to comment.