switch to python implementation
Davies Liu committed Nov 13, 2014
1 parent 95a48ac commit c7a2007
Showing 3 changed files with 16 additions and 26 deletions.
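Summary: randomSplit previously converted its weights with Py4J and delegated to the JVM RDD's randomSplit; this commit reimplements the split in pure Python. The normalized weights become cumulative [lb, ub) ranges, and each split is produced by an RDDSampler that keeps only the elements whose uniform draw falls in its range, which is what the new lowbound parameter below is for.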
7 changes: 0 additions & 7 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -758,13 +758,6 @@ private[spark] object PythonRDD extends Logging {
       converted.saveAsHadoopDataset(new JobConf(conf))
     }
   }
-
-  /**
-   * A helper to convert java.util.List[Double] into Array[Double]
-   */
-  def listToArrayDouble(list: JList[Double]): Array[Double] = {
-    list.asScala.toArray
-  }
 }
 
 private
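With the weights no longer crossing into the JVM, the listToArrayDouble helper above loses its only caller in this commit (the Py4J path removed from randomSplit below) and can be deleted.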
30 changes: 13 additions & 17 deletions python/pyspark/rdd.py
@@ -28,7 +28,7 @@
 import warnings
 import heapq
 import bisect
-from random import Random
+import random
 from math import sqrt, log, isinf, isnan
 
 from pyspark.accumulators import PStatsParam
@@ -324,25 +324,21 @@ def randomSplit(self, weights, seed=None):
         :param seed: random seed
         :return: split RDDs in a list
 
-        >>> rdd = sc.parallelize(range(10), 1)
-        >>> rdd1, rdd2, rdd3 = rdd.randomSplit([0.4, 0.6, 1.0], 11)
+        >>> rdd = sc.parallelize(range(5), 1)
+        >>> rdd1, rdd2 = rdd.randomSplit([2.0, 3.0], 101)
         >>> rdd1.collect()
-        [3, 6]
+        [2, 3]
         >>> rdd2.collect()
-        [0, 5, 7]
-        >>> rdd3.collect()
-        [1, 2, 4, 8, 9]
+        [0, 1, 4]
         """
-        ser = BatchedSerializer(PickleSerializer(), 1)
-        rdd = self._reserialize(ser)
-        jweights = ListConverter().convert([float(w) for w in weights],
-                                           self.ctx._gateway._gateway_client)
-        jweights = self.ctx._jvm.PythonRDD.listToArrayDouble(jweights)
+        s = sum(weights)
+        cweights = [0.0]
+        for w in weights:
+            cweights.append(cweights[-1] + w / s)
         if seed is None:
-            jrdds = rdd._jrdd.randomSplit(jweights)
-        else:
-            jrdds = rdd._jrdd.randomSplit(jweights, seed)
-        return [RDD(jrdd, self.ctx, ser) for jrdd in jrdds]
+            seed = random.randint(0, 2 ** 32 - 1)
+        return [self.mapPartitionsWithIndex(RDDSampler(False, ub, seed, lb).func, True)
+                for lb, ub in zip(cweights, cweights[1:])]
 
     # this is ported from scala/spark/RDD.scala
     def takeSample(self, withReplacement, num, seed=None):
@@ -369,7 +365,7 @@ def takeSample(self, withReplacement, num, seed=None):
         if initialCount == 0:
             return []
 
-        rand = Random(seed)
+        rand = random.Random(seed)
 
         if (not withReplacement) and num >= initialCount:
             # shuffle current RDD and return
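Why the pure-Python version yields disjoint, exhaustive splits: every RDDSampler shares the same seed, so within a partition each element receives exactly one uniform draw, and the half-open windows [lb, ub) built from the cumulative weights tile [0, 1). A minimal standalone sketch of that argument in plain Python (illustrative only: random_split and its names are not PySpark API, and the real sampler also derives a per-partition stream from the split index):

import random

def random_split(items, weights, seed):
    # Normalize the weights into cumulative boundaries [0.0, ..., 1.0],
    # mirroring the cweights loop in randomSplit above.
    s = float(sum(weights))
    cweights = [0.0]
    for w in weights:
        cweights.append(cweights[-1] + w / s)
    # One uniform draw per element from a single seeded generator, the
    # analogue of all samplers sharing one seed within a partition.
    rng = random.Random(seed)
    draws = [(x, rng.random()) for x in items]
    # The half-open windows [lb, ub) are disjoint and tile [0.0, 1.0),
    # so each element lands in exactly one split.
    return [[x for x, u in draws if lb <= u < ub]
            for lb, ub in zip(cweights, cweights[1:])]

parts = random_split(range(10), [2.0, 3.0], seed=101)
assert sorted(parts[0] + parts[1]) == list(range(10))  # disjoint and exhaustive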
5 changes: 3 additions & 2 deletions python/pyspark/rddsampler.py
@@ -96,9 +96,10 @@ def shuffle(self, vals):
 
 class RDDSampler(RDDSamplerBase):
 
-    def __init__(self, withReplacement, fraction, seed=None):
+    def __init__(self, withReplacement, fraction, seed=None, lowbound=0.0):
         RDDSamplerBase.__init__(self, withReplacement, seed)
         self._fraction = fraction
+        self._lowbound = lowbound
 
     def func(self, split, iterator):
         if self._withReplacement:
@@ -111,7 +112,7 @@ def func(self, split, iterator):
                     yield obj
         else:
             for obj in iterator:
-                if self.getUniformSample(split) <= self._fraction:
+                if self._lowbound <= self.getUniformSample(split) < self._fraction:
                     yield obj
 
 
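Two details of the new comparison work together: when driven from randomSplit, _fraction acts as the window's upper bound rather than a sampling rate, and the upper test became strict (< instead of <=) so a draw landing exactly on a boundary is claimed by exactly one window; plain sample() calls keep the default lowbound of 0.0 and behave as before. A hedged local check of that complementarity (hypothetical driver code, assuming this era's RDDSampler can be imported and exercised outside a Spark job):

from pyspark.rddsampler import RDDSampler

data = list(range(20))
lower = RDDSampler(False, 0.4, seed=11, lowbound=0.0)  # window [0.0, 0.4)
upper = RDDSampler(False, 1.0, seed=11, lowbound=0.4)  # window [0.4, 1.0)

# Same seed and same split index give identical draw sequences, so the
# two windows select complementary subsets of the same input.
taken = list(lower.func(0, iter(data))) + list(upper.func(0, iter(data)))
assert sorted(taken) == data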
