update
jiangxb1987 committed Aug 10, 2018
1 parent d508fc5 commit ea2330b
Showing 3 changed files with 12 additions and 21 deletions.
6 changes: 5 additions & 1 deletion core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -45,7 +45,8 @@ import org.apache.spark.util._
 private[spark] class PythonRDD(
     parent: RDD[_],
     func: PythonFunction,
-    preservePartitoning: Boolean)
+    preservePartitoning: Boolean,
+    isFromBarrier: Boolean = false)
   extends RDD[Array[Byte]](parent) {
 
   val bufferSize = conf.getInt("spark.buffer.size", 65536)
@@ -63,6 +64,9 @@ private[spark] class PythonRDD(
     val runner = PythonRunner(func, bufferSize, reuseWorker)
     runner.compute(firstParent.iterator(split, context), split.index, context)
   }
+
+  @transient protected lazy override val isBarrier_ : Boolean =
+    isFromBarrier || dependencies.exists(_.rdd.isBarrier())
 }
 
 /**
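Note: on the JVM side, PythonRDD now accepts an isFromBarrier flag and overrides isBarrier_, so the RDD reports itself as a barrier RDD when it was created from a barrier stage or when any of its dependencies is one. A toy Python sketch of that propagation rule (illustration only, not Spark code; ToyRDD and its fields are hypothetical names):

# Toy model of the rule encoded by isBarrier_ above: an RDD is a barrier
# RDD if it was created from a barrier stage, or if any RDD it depends on
# is a barrier RDD.
class ToyRDD:
    def __init__(self, deps=(), is_from_barrier=False):
        self.deps = list(deps)              # upstream ToyRDDs
        self.is_from_barrier = is_from_barrier

    def is_barrier(self):
        return self.is_from_barrier or any(d.is_barrier() for d in self.deps)

source = ToyRDD()
barrier = ToyRDD(deps=[source], is_from_barrier=True)
mapped = ToyRDD(deps=[barrier])             # plain downstream transformation

assert not source.is_barrier()
assert barrier.is_barrier()
assert mapped.is_barrier()                  # the flag propagates downstream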
14 changes: 0 additions & 14 deletions core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala
@@ -21,7 +21,6 @@ import scala.reflect.ClassTag
 
 import org.apache.spark.TaskContext
 import org.apache.spark.annotation.{Experimental, Since}
-import org.apache.spark.api.java.JavaRDD
 
 /** Represents an RDD barrier, which forces Spark to launch tasks of this stage together. */
 class RDDBarrier[T: ClassTag](rdd: RDD[T]) {
@@ -47,18 +46,5 @@ class RDDBarrier[T: ClassTag](rdd: RDD[T]) {
     )
   }
 
-  /**
-   * Expose a JavaRDD that wraps a barrier RDD generated from the prev RDD, to support launch
-   * barrier stage from python side.
-   */
-  private[spark] def toJavaRDD(): JavaRDD[T] = {
-    val barrierRDD = new MapPartitionsRDD[T, T](
-      rdd,
-      (context, pid, iter) => iter,
-      preservesPartitioning = false,
-      isFromBarrier = true)
-    JavaRDD.fromRDD(barrierRDD)
-  }
-
   /** TODO extra conf(e.g. timeout) */
 }
13 changes: 7 additions & 6 deletions python/pyspark/rdd.py
@@ -2462,7 +2462,6 @@ class RDDBarrier(object):
 
     def __init__(self, rdd):
         self.rdd = rdd
-        self._jrdd = rdd._jrdd
 
     def mapPartitions(self, f, preservesPartitioning=False):
         """
@@ -2474,9 +2473,7 @@ def mapPartitions(self, f, preservesPartitioning=False):
         """
         def func(s, iterator):
             return f(iterator)
-        jBarrierRdd = self._jrdd.rdd().barrier().toJavaRDD()
-        pyBarrierRdd = RDD(jBarrierRdd, self.rdd.ctx, self.rdd._jrdd_deserializer)
-        return pyBarrierRdd.mapPartitions(f, preservesPartitioning)
+        return PipelinedRDD(self.rdd, func, preservesPartitioning, isFromBarrier=True)
 
 
 class PipelinedRDD(RDD):
@@ -2498,7 +2495,7 @@ class PipelinedRDD(RDD):
     20
     """
 
-    def __init__(self, prev, func, preservesPartitioning=False):
+    def __init__(self, prev, func, preservesPartitioning=False, isFromBarrier=False):
         if not isinstance(prev, PipelinedRDD) or not prev._is_pipelinable():
             # This transformation is the first in its stage:
             self.func = func
@@ -2524,6 +2521,7 @@ def pipeline_func(split, iterator):
         self._jrdd_deserializer = self.ctx.serializer
         self._bypass_serializer = False
         self.partitioner = prev.partitioner if self.preservesPartitioning else None
+        self.is_barrier = prev.isBarrier() or isFromBarrier
 
     def getNumPartitions(self):
         return self._prev_jrdd.partitions().size()
@@ -2543,7 +2541,7 @@ def _jrdd(self):
         wrapped_func = _wrap_function(self.ctx, self.func, self._prev_jrdd_deserializer,
                                       self._jrdd_deserializer, profiler)
         python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(), wrapped_func,
-                                             self.preservesPartitioning)
+                                             self.preservesPartitioning, self.is_barrier)
         self._jrdd_val = python_rdd.asJavaRDD()
 
         if profiler:
@@ -2559,6 +2557,9 @@ def id(self):
     def _is_pipelinable(self):
         return not (self.is_cached or self.is_checkpointed)
 
+    def isBarrier(self):
+        return self.is_barrier
+
 
 def _test():
     import doctest
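End to end, the Python path is now: RDDBarrier.mapPartitions wraps the user function in a PipelinedRDD created with isFromBarrier=True, and PipelinedRDD._jrdd forwards self.is_barrier to the JVM-side PythonRDD constructor, replacing the removed toJavaRDD() bridge. A minimal usage sketch, assuming a running SparkContext sc and that RDD.barrier() is exposed on the Python RDD class elsewhere in this PR (SPARK-24822):

from pyspark import SparkContext

sc = SparkContext(appName="barrier-sketch")
rdd = sc.parallelize(range(8), 4)

def sum_partition(iterator):
    # After this change, the stage below runs as a barrier stage: all 4
    # tasks are launched together rather than scheduled independently.
    return [sum(iterator)]

# barrier() returns an RDDBarrier; mapPartitions now yields a PipelinedRDD
# built with isFromBarrier=True, so is_barrier reaches the JVM PythonRDD.
barrier_rdd = rdd.barrier().mapPartitions(sum_partition)
print(barrier_rdd.collect())   # e.g. [1, 5, 9, 13]
sc.stop()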
