Skip to content

Commit

Permalink
[SPARK-2470] PEP8 fixes to rddsampler.py
Browse files Browse the repository at this point in the history
  • Loading branch information
nchammas committed Jul 20, 2014
1 parent 4dd148f commit f0a7ebf
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions python/pyspark/rddsampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,16 @@
import sys
import random


class RDDSampler(object):
def __init__(self, withReplacement, fraction, seed=None):
try:
import numpy
self._use_numpy = True
except ImportError:
print >> sys.stderr, "NumPy does not appear to be installed. Falling back to default random generator for sampling."
print >> sys.stderr, (
"NumPy does not appear to be installed. "
"Falling back to default random generator for sampling.")
self._use_numpy = False

self._seed = seed if seed is not None else random.randint(0, sys.maxint)
Expand Down Expand Up @@ -61,7 +64,7 @@ def getUniformSample(self, split):
def getPoissonSample(self, split, mean):
if not self._rand_initialized or split != self._split:
self.initRandomGenerator(split)

if self._use_numpy:
return self._random.poisson(mean)
else:
Expand All @@ -80,30 +83,27 @@ def getPoissonSample(self, split, mean):
num_arrivals += 1

return (num_arrivals - 1)

def shuffle(self, vals):
if self._random is None:
self.initRandomGenerator(0) # this should only ever be called on the master, so
# the split does not matter

if self._use_numpy:
self._random.shuffle(vals)
else:
self._random.shuffle(vals, self._random.random)

def func(self, split, iterator):
if self._withReplacement:
if self._withReplacement:
for obj in iterator:
# For large datasets, the expected number of occurrences of each element in a sample with
# replacement is Poisson(frac). We use that to get a count for each element.
count = self.getPoissonSample(split, mean = self._fraction)
# For large datasets, the expected number of occurrences of each element in
# a sample with replacement is Poisson(frac). We use that to get a count for
# each element.
count = self.getPoissonSample(split, mean=self._fraction)
for _ in range(0, count):
yield obj
else:
for obj in iterator:
if self.getUniformSample(split) <= self._fraction:
yield obj




0 comments on commit f0a7ebf

Please sign in to comment.