[SPARK-2737] Add retag() method for changing RDDs' ClassTags.
The Java API's use of fake ClassTags doesn't seem to cause any problems for
Java users, but it can lead to issues when passing JavaRDDs' underlying RDDs to
Scala code (e.g. in the MLlib Java API wrapper code). If we call collect() on
a Scala RDD with an incorrect ClassTag, the result array is allocated with the
wrong runtime type, which leads to ClassCastExceptions (for example, see SPARK-2197).
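To make the failure concrete, here is a minimal sketch of the problem (not part of this commit; spark-shell style, assuming a local master). It fakes the ClassTag the way the Java API does, so collect() builds its result as an Object[] and the cast at the binding site fails:

    import scala.reflect.ClassTag
    import org.apache.spark.SparkContext

    val sc = new SparkContext("local", "fake-tag-sketch")
    // Simulate the Java API's fake tag: the elements are Strings, but the
    // ClassTag claims AnyRef, so collect() allocates its result as Object[].
    val fakeTag = ClassTag.AnyRef.asInstanceOf[ClassTag[String]]
    val rdd = sc.parallelize(Seq("a", "b", "c"))(fakeTag)
    // Throws ClassCastException: [Ljava.lang.Object; cannot be cast to [Ljava.lang.String;
    val strings: Array[String] = rdd.collect()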

There are a few possible fixes here. An API-breaking fix would be to completely
remove the fake ClassTags and require Java API users to pass java.lang.Class
instances to all parallelize() calls and add returnClass fields to all Function
implementations. This would be extremely verbose.

Instead, this patch adds internal APIs to "repair" a Scala RDD with an
incorrect ClassTag by wrapping it and overriding its ClassTag. This should be
okay for cases where the Scala code that calls collect() knows what type of
array should be allocated, which is the case in the MLlib wrappers.
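As a usage illustration, a sketch of such a wrapper (hypothetical: SomePoint, WrapperSketch, and the package name are invented for this example). Since retag() is private[spark], the caller must live under the org.apache.spark package, as the MLlib wrappers do:

    package org.apache.spark.sketch  // hypothetical subpackage, so private[spark] retag() is visible

    import org.apache.spark.api.java.JavaRDD

    class SomePoint extends Serializable  // stand-in element type

    object WrapperSketch {
      // The JavaRDD's underlying Scala RDD carries a fake ClassTag; retag()
      // repairs it so collect() allocates a SomePoint[] instead of an Object[].
      def collectPoints(javaRDD: JavaRDD[SomePoint]): Array[SomePoint] =
        javaRDD.rdd.retag(classOf[SomePoint]).collect()
    }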
JoshRosen committed Jul 29, 2014
1 parent dc96536 commit d1d54e6
Showing 2 changed files with 39 additions and 0 deletions.
core/src/main/scala/org/apache/spark/rdd/RDD.scala (22 additions, 0 deletions)
@@ -1239,6 +1239,28 @@ abstract class RDD[T: ClassTag](
  /** The [[org.apache.spark.SparkContext]] that this RDD was created on. */
  def context = sc

  /**
   * Private API for changing an RDD's ClassTag.
   * Used for internal Java <-> Scala API compatibility.
   */
  private[spark] def retag(cls: Class[T]): RDD[T] = {
    val classTag: ClassTag[T] = ClassTag.apply(cls)
    this.retag(classTag)
  }

  /**
   * Private API for changing an RDD's ClassTag.
   * Used for internal Java <-> Scala API compatibility.
   */
  private[spark] def retag(classTag: ClassTag[T]): RDD[T] = {
    val oldRDD = this
    new RDD[T](sc, Seq(new OneToOneDependency(this)))(classTag) {
      override protected def getPartitions: Array[Partition] = oldRDD.getPartitions
      override def compute(split: Partition, context: TaskContext): Iterator[T] =
        oldRDD.compute(split, context)
    }
  }

  // Avoid handling doCheckpoint multiple times to prevent excessive recursion
  @transient private var doCheckpointCalled = false

core/src/test/java/org/apache/spark/JavaAPISuite.java (17 additions, 0 deletions)
@@ -1245,4 +1245,21 @@ public Tuple2<Integer, Integer> call(Integer i) {
    Assert.assertTrue(worExactCounts.get(0) == 2);
    Assert.assertTrue(worExactCounts.get(1) == 4);
  }

  private static class SomeCustomClass implements Serializable {
    public SomeCustomClass() {
      // Intentionally left blank
    }
  }

  @Test
  public void collectUnderlyingScalaRDD() {
    List<SomeCustomClass> data = new ArrayList<SomeCustomClass>();
    for (int i = 0; i < 100; i++) {
      data.add(new SomeCustomClass());
    }
    JavaRDD<SomeCustomClass> rdd = sc.parallelize(data);
    SomeCustomClass[] collected = (SomeCustomClass[]) rdd.rdd().retag(SomeCustomClass.class).collect();
    Assert.assertEquals(data.size(), collected.length);
  }
}
