From e114eeddedd02547e0b57bd9a00291885b116daa Mon Sep 17 00:00:00 2001
From: WeichenXu
Date: Thu, 12 Jan 2017 04:47:39 +0800
Subject: [PATCH 1/2] init pr

---
 .../org/apache/spark/rdd/CartesianRDD.scala   | 43 +++++++++++++++----
 .../scala/org/apache/spark/rdd/RDDSuite.scala | 17 ++++++++
 2 files changed, 51 insertions(+), 9 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala
index 57108dcedcf0c..80fc868df5dc7 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala
@@ -20,8 +20,8 @@ package org.apache.spark.rdd
 import java.io.{IOException, ObjectOutputStream}
 
 import scala.reflect.ClassTag
-
 import org.apache.spark._
+import org.apache.spark.serializer.Serializer
 import org.apache.spark.util.Utils
 
 private[spark]
@@ -53,6 +53,9 @@ class CartesianRDD[T: ClassTag, U: ClassTag](
   extends RDD[(T, U)](sc, Nil)
   with Serializable {
 
+  var rdd1WithPid = rdd1.mapPartitionsWithIndex((pid, iter) => iter.map(x => (pid, x)))
+  var rdd2WithPid = rdd2.mapPartitionsWithIndex((pid, iter) => iter.map(x => (pid, x)))
+
   val numPartitionsInRdd2 = rdd2.partitions.length
 
   override def getPartitions: Array[Partition] = {
@@ -70,24 +73,46 @@ class CartesianRDD[T: ClassTag, U: ClassTag](
     (rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2)).distinct
   }
 
+  private def getParititionIterator[T: ClassTag](
+      dependency: ShuffleDependency[Int, T, T], partitionIndex: Int, context: TaskContext) = {
+    SparkEnv.get.shuffleManager
+      .getReader[Int, T](dependency.shuffleHandle, partitionIndex, partitionIndex + 1, context)
+      .read().map(x => x._2)
+  }
+
   override def compute(split: Partition, context: TaskContext): Iterator[(T, U)] = {
     val currSplit = split.asInstanceOf[CartesianPartition]
-    for (x <- rdd1.iterator(currSplit.s1, context);
-         y <- rdd2.iterator(currSplit.s2, context)) yield (x, y)
+    for (x <- getParititionIterator[T](dependencies(0).asInstanceOf[ShuffleDependency[Int, T, T]],
+      currSplit.s1.index, context);
+      y <- getParititionIterator[U](dependencies(1).asInstanceOf[ShuffleDependency[Int, U, U]],
+      currSplit.s2.index, context))
+      yield (x, y)
   }
 
-  override def getDependencies: Seq[Dependency[_]] = List(
-    new NarrowDependency(rdd1) {
-      def getParents(id: Int): Seq[Int] = List(id / numPartitionsInRdd2)
-    },
-    new NarrowDependency(rdd2) {
-      def getParents(id: Int): Seq[Int] = List(id % numPartitionsInRdd2)
+  private def getShufflePartitioner(numParts: Int): Partitioner = {
+    return new Partitioner {
+      require(numPartitions > 0)
+      override def getPartition(key: Any): Int = key.asInstanceOf[Int]
+      override def numPartitions: Int = numParts
     }
+  }
+
+  private var serializer: Serializer = SparkEnv.get.serializer
+
+  override def getDependencies: Seq[Dependency[_]] = List(
+    new ShuffleDependency[Int, T, T](
+      rdd1WithPid.asInstanceOf[RDD[_ <: Product2[Int, T]]],
+      getShufflePartitioner(rdd1WithPid.getNumPartitions), serializer),
+    new ShuffleDependency[Int, U, U](
+      rdd2WithPid.asInstanceOf[RDD[_ <: Product2[Int, U]]],
+      getShufflePartitioner(rdd2WithPid.getNumPartitions), serializer)
   )
 
   override def clearDependencies() {
     super.clearDependencies()
     rdd1 = null
     rdd2 = null
+    rdd1WithPid = null
+    rdd2WithPid = null
   }
 }
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index ad56715656c85..fbbfe328efa31 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -230,6 +230,23 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext {
     }
   }
 
+  test("cartesian") {
+    val seq1 = Seq(1, 2, 3, 4)
+    val seq2 = Seq('a', 'b', 'c', 'd', 'e', 'f')
+
+    val rdd1 = sc.makeRDD(seq1, 2)
+    val rdd2 = sc.makeRDD(seq2, 3)
+
+    val result = rdd1.cartesian(rdd2).collect().sortWith((x, y) => {
+      if (x._1 != y._1) x._1 < y._1
+      else x._2 < y._2
+    }).toList
+
+    val expectedResult = (for (i <- seq1; j <- seq2) yield (i, j)).toList
+
+    assert(result === expectedResult)
+  }
+
   test("basic caching") {
     val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
     assert(rdd.collect().toList === List(1, 2, 3, 4))

From 815063b5127857b3e2a76f19ee945ff54d8dd110 Mon Sep 17 00:00:00 2001
From: WeichenXu
Date: Thu, 12 Jan 2017 05:05:14 +0800
Subject: [PATCH 2/2] fix scala style check

---
 .../src/main/scala/org/apache/spark/rdd/CartesianRDD.scala | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala
index 80fc868df5dc7..08c996125d474 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala
@@ -20,12 +20,13 @@ package org.apache.spark.rdd
 import java.io.{IOException, ObjectOutputStream}
 
 import scala.reflect.ClassTag
-import org.apache.spark._
+
+import org.apache.spark.{Dependency, Partition, Partitioner, ShuffleDependency, SparkContext,
+  SparkEnv, TaskContext}
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.util.Utils
 
-private[spark]
-class CartesianPartition(
+private[spark] class CartesianPartition(
     idx: Int,
     @transient private val rdd1: RDD[_],
     @transient private val rdd2: RDD[_],
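Not part of the patch series: a minimal, self-contained sketch of the user-facing contract that the new "cartesian" test pins down, run against a local SparkContext. The object name, app name, and master URL below are illustrative assumptions; the point is that switching CartesianRDD from narrow to shuffle dependencies must not change the pair count or the one-output-partition-per-parent-partition-pair layout.

import org.apache.spark.{SparkConf, SparkContext}

// Illustration only; not part of the patch. Names and master URL are assumptions.
object CartesianSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("cartesian-sketch").setMaster("local[2]"))
    val left = sc.makeRDD(Seq(1, 2, 3, 4), 2)      // 2 partitions
    val right = sc.makeRDD(Seq('a', 'b', 'c'), 3)  // 3 partitions
    val pairs = left.cartesian(right)
    // One output partition per (left partition, right partition) pair: 2 * 3 = 6.
    assert(pairs.getNumPartitions == 6)
    // Every element of `left` paired with every element of `right`: 4 * 3 = 12 pairs.
    assert(pairs.count() == 12)
    sc.stop()
  }
}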