Skip to content

Commit

Permalink
address comment, add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
davies committed Sep 13, 2014
1 parent 0a5b6eb commit 4f8309d
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 5 deletions.
4 changes: 2 additions & 2 deletions python/pyspark/accumulators.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,8 @@ def addInPlace(self, value1, value2):
COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j)


class StatsParam(AccumulatorParam):
"""StatsParam is used to merge pstats.Stats"""
class PStatsParam(AccumulatorParam):
"""PStatsParam is used to merge pstats.Stats"""

@staticmethod
def zero(value):
Expand Down
5 changes: 2 additions & 3 deletions python/pyspark/rdd.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
# limitations under the License.
#

from base64 import standard_b64encode as b64enc
import copy
from collections import defaultdict
from collections import namedtuple
Expand All @@ -35,7 +34,7 @@
from random import Random
from math import sqrt, log, isinf, isnan

from pyspark.accumulators import StatsParam
from pyspark.accumulators import PStatsParam
from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
PickleSerializer, pack_long, CompressedSerializer
Expand Down Expand Up @@ -2114,7 +2113,7 @@ def _jrdd(self):
if self._bypass_serializer:
self._jrdd_deserializer = NoOpSerializer()
enable_profile = self.ctx._conf.get("spark.python.profile", "false") == "true"
profileStats = self.ctx.accumulator(None, StatsParam) if enable_profile else None
profileStats = self.ctx.accumulator(None, PStatsParam) if enable_profile else None
command = (self.func, profileStats, self._prev_jrdd_deserializer,
self._jrdd_deserializer)
ser = CloudPickleSerializer()
Expand Down
25 changes: 25 additions & 0 deletions python/pyspark/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,31 @@ def test_repartitionAndSortWithinPartitions(self):
self.assertEquals(partitions[1], [(1, 3), (3, 8), (3, 8)])


class TestProfiler(PySparkTestCase):

def setUp(self):
self._old_sys_path = list(sys.path)
class_name = self.__class__.__name__
conf = SparkConf().set("spark.python.profile", "true")
self.sc = SparkContext('local[4]', class_name, batchSize=2, conf=conf)

def test_profiler(self):

def heavy_foo(x):
for i in range(1 << 20):
x = 1
rdd = self.sc.parallelize(range(100)).foreach(heavy_foo)
from pyspark.rdd import PipelinedRDD
profiles = PipelinedRDD._created_profiles
self.assertEqual(1, len(profiles))
id, acc = profiles.pop()
stats = acc.value
self.assertTrue(stats is not None)
width, stat_list = stats.get_print_list([])
func_names = [func_name for fname, n, func_name in stat_list]
self.assertTrue("heavy_foo" in func_names)


class TestSQL(PySparkTestCase):

def setUp(self):
Expand Down

0 comments on commit 4f8309d

Please sign in to comment.