[SPARK-2737] Add retag() method for changing RDDs' ClassTags. #1639
Conversation
/cc @mengxr @jkbradley @mateiz
QA tests have started for PR 1639. This patch merges cleanly.
QA results for PR 1639:
The Java API's use of fake ClassTags doesn't seem to cause any problems for Java users, but it can lead to issues when passing JavaRDDs' underlying RDDs to Scala code (e.g. in the MLlib Java API wrapper code). If we call collect() on a Scala RDD with an incorrect ClassTag, this causes ClassCastExceptions when we try to allocate an array of the wrong type (for example, see SPARK-2197).

There are a few possible fixes here. An API-breaking fix would be to completely remove the fake ClassTags and require Java API users to pass java.lang.Class instances to all parallelize() calls and add returnClass fields to all Function implementations. This would be extremely verbose.

Instead, this patch adds internal APIs to "repair" a Scala RDD with an incorrect ClassTag by wrapping it and overriding its ClassTag. This should be okay for cases where the Scala code that calls collect() knows what type of array should be allocated, which is the case in the MLlib wrappers.
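The failure mode can be reproduced in plain Scala, without Spark. Below is a minimal sketch; `collectLike` is a hypothetical stand-in for `collect()`, invented here to show how allocating the result array through a fake ClassTag blows up at the call site:

```scala
import scala.reflect.ClassTag

// collectLike is a hypothetical stand-in for RDD.collect(): like Spark, it
// allocates the result array through the ClassTag in scope.
def collectLike[T](data: Seq[T])(implicit ct: ClassTag[T]): Array[T] = {
  val arr = ct.newArray(data.length)
  data.copyToArray(arr)
  arr
}

// The Java API's "fake" ClassTag is ClassTag.AnyRef cast to the right type.
val fake = ClassTag.AnyRef.asInstanceOf[ClassTag[String]]

// With the fake tag, the allocated array is really Array[AnyRef]; the cast
// inserted at the call site fails with a ClassCastException.
val fakeTagFails =
  try { val a: Array[String] = collectLike(Seq("a", "b"))(fake); false }
  catch { case _: ClassCastException => true }

// Supplying the correct ClassTag "repairs" the allocation.
val repaired: Array[String] = collectLike(Seq("a", "b"))(ClassTag(classOf[String]))
```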
QA tests have started for PR 1639. This patch merges cleanly.
QA results for PR 1639:
Another option would be to add
```scala
  override protected def getPartitions: Array[Partition] = oldRDD.getPartitions

  override def compute(split: Partition, context: TaskContext): Iterator[T] =
    oldRDD.compute(split, context)
}
```
You also need to preserve the Partitioner and such. It would be better to do this via this.mapPartitions with the preservesPartitioning option set to true.
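In toy form, the suggestion amounts to the following sketch. MiniRDD is an invented model, not Spark API; it only exists to show that rebuilding through mapPartitions(identity, preservesPartitioning = true) under a corrected ClassTag keeps the partitioner:

```scala
import scala.reflect.ClassTag

// MiniRDD is an invented toy model (not Spark API): a dataset split into
// partitions, plus an optional partitioner label.
class MiniRDD[T](val parts: Seq[Seq[T]], val partitioner: Option[String])
                (implicit val ct: ClassTag[T]) {

  def mapPartitions[U](f: Iterator[T] => Iterator[U],
                       preservesPartitioning: Boolean = false)
                      (implicit uct: ClassTag[U]): MiniRDD[U] =
    new MiniRDD[U](parts.map(p => f(p.iterator).toSeq),
                   if (preservesPartitioning) partitioner else None)(uct)

  // collect() allocates its result array from the ClassTag, as Spark does.
  def collect(): Array[T] = {
    val flat = parts.flatten
    val arr = ct.newArray(flat.length)
    flat.copyToArray(arr)
    arr
  }

  // The reviewer's suggestion, written against this toy API: an identity
  // pass over each partition, keeping the partitioner, under a new tag.
  def retag(cls: Class[T]): MiniRDD[T] =
    mapPartitions(identity, preservesPartitioning = true)(ClassTag(cls))
}

val fake = ClassTag.AnyRef.asInstanceOf[ClassTag[String]]
val withFakeTag = new MiniRDD[String](Seq(Seq("a"), Seq("b")), Some("hash"))(fake)

val fixed = withFakeTag.retag(classOf[String])
val out: Array[String] = fixed.collect() // genuine Array[String] now
```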
Would there be any performance impact of running mapPartitions(identity, preservesPartitioning = true)(classTag)? If we have an RDD that's persisted in a serialized format, wouldn't this extra map force an unnecessary deserialization?
Sure, the fix with just passing the partitioner also works.
Actually, compute() just works at the iterator level, so I don't think mapPartitions would hurt. All you do is pass through the parent's iterator. When you call compute() you're already deserializing the RDD, so this won't create extra work.
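This is easy to check outside Spark: an identity pass-through over an iterator does no work until the result is actually consumed. A minimal sketch:

```scala
var produced = 0

// A lazy source, standing in for the parent RDD's compute() iterator.
val parent: Iterator[Int] = Iterator.tabulate(3) { i => produced += 1; i }

// The pass-through that mapPartitions(identity, ...) would insert.
val child: Iterator[Int] = identity(parent)

val producedBeforeConsumption = produced // still zero: nothing evaluated yet
val result = child.toArray               // evaluation happens only here
val producedAfterConsumption = produced
```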
I'm okay with either this or collectSeq, actually.
I'm going to take another pass on this to see if I can implicitly grab the ClassTag from the caller's scope, so hold off on merging this for a bit.
QA tests have started for PR 1639. This patch merges cleanly.
This method is intended to be called by Scala classes that implement Java-friendly wrappers for the Spark Scala API. For instance, MLlib has APIs that accept RDD[LabeledPoint]. Ideally, the Java wrapper code can simply call the underlying Scala methods without having to worry about how they're implemented. Therefore, I think we should prefer this approach. Since this is a private, internal API, we should be able to revisit this decision if we change our minds later.
My last commit made
Sure, sounds good. Did you see my comments on preserving partitions too, though?
QA results for PR 1639:
QA tests have started for PR 1639. This patch merges cleanly.
QA results for PR 1639:
In case you don't see the hidden comment above: I don't think mapPartitions would hurt performance here. All you do is pass through the parent's iterator. When you call compute() you're already deserializing the RDD, so this won't create extra work in that case.
Basically it's a shorter way of writing what you wrote. Take a look at MapPartitionsRDD.
I've updated this to use mapPartitions().
QA tests have started for PR 1639. This patch merges cleanly.
LGTM, feel free to merge it when it passes tests.
Jenkins, retest this please.
QA tests have started for PR 1639. This patch merges cleanly.
QA results for PR 1639:
Jenkins, retest this please.
QA tests have started for PR 1639. This patch merges cleanly.
QA results for PR 1639:
Alright, I've merged this. Thanks for the review!
Author: Josh Rosen <joshrosen@apache.org>

Closes apache#1639 from JoshRosen/SPARK-2737 and squashes the following commits:

572b4c8 [Josh Rosen] Replace newRDD[T] with mapPartitions().
469d941 [Josh Rosen] Preserve partitioner in retag().
af78816 [Josh Rosen] Allow retag() to get classTag implicitly.
d1d54e6 [Josh Rosen] [SPARK-2737] Add retag() method for changing RDDs' ClassTags.