Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Davies Liu committed Nov 13, 2014
1 parent f866bcf commit 4dfa2cd
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
4 changes: 2 additions & 2 deletions python/pyspark/rdd.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from pyspark.join import python_join, python_left_outer_join, \
python_right_outer_join, python_full_outer_join, python_cogroup
from pyspark.statcounter import StatCounter
from pyspark.rddsampler import RDDSampler, RDDStratifiedSampler
from pyspark.rddsampler import RDDSampler, RDDRangeSampler, RDDStratifiedSampler
from pyspark.storagelevel import StorageLevel
from pyspark.resultiterable import ResultIterable
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \
Expand Down Expand Up @@ -337,7 +337,7 @@ def randomSplit(self, weights, seed=None):
cweights.append(cweights[-1] + w / s)
if seed is None:
seed = random.randint(0, 2 ** 32 - 1)
return [self.mapPartitionsWithIndex(RDDSampler(False, ub, seed, lb).func, True)
return [self.mapPartitionsWithIndex(RDDRangeSampler(lb, ub, seed).func, True)
for lb, ub in zip(cweights, cweights[1:])]

# this is ported from scala/spark/RDD.scala
Expand Down
18 changes: 15 additions & 3 deletions python/pyspark/rddsampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,9 @@ def shuffle(self, vals):

class RDDSampler(RDDSamplerBase):

def __init__(self, withReplacement, fraction, seed=None, lowbound=0.0):
def __init__(self, withReplacement, fraction, seed=None):
RDDSamplerBase.__init__(self, withReplacement, seed)
self._fraction = fraction
self._lowbound = lowbound

def func(self, split, iterator):
if self._withReplacement:
Expand All @@ -112,10 +111,23 @@ def func(self, split, iterator):
yield obj
else:
for obj in iterator:
if self._lowbound <= self.getUniformSample(split) < self._fraction:
if self.getUniformSample(split) <= self._fraction:
yield obj


class RDDRangeSampler(RDDSamplerBase):

def __init__(self, lowerBound, upperBound, seed=None):
RDDSamplerBase.__init__(self, False, seed)
self._lowerBound = lowerBound
self._upperBound = upperBound

def func(self, split, iterator):
for obj in iterator:
if self._lowerBound <= self.getUniformSample(split) < self._upperBound:
yield obj


class RDDStratifiedSampler(RDDSamplerBase):

def __init__(self, withReplacement, fractions, seed=None):
Expand Down

0 comments on commit 4dfa2cd

Please sign in to comment.