
[SPARK-2737] Add retag() method for changing RDDs' ClassTags. #1639

Closed

wants to merge 4 commits

Conversation

JoshRosen
Contributor

The Java API's use of fake ClassTags doesn't seem to cause any problems for Java users, but it can lead to issues when passing JavaRDDs' underlying RDDs to Scala code (e.g. in the MLlib Java API wrapper code). If we call collect() on a Scala RDD with an incorrect ClassTag, this causes ClassCastExceptions when we try to allocate an array of the wrong type (for example, see SPARK-2197).

There are a few possible fixes here. An API-breaking fix would be to completely remove the fake ClassTags and require Java API users to pass java.lang.Class instances to all parallelize() calls and add returnClass fields to all Function implementations. This would be extremely verbose.

Instead, this patch adds internal APIs to "repair" a Scala RDD with an incorrect ClassTag by wrapping it and overriding its ClassTag. This should be okay for cases where the Scala code that calls collect() knows what type of array should be allocated, which is the case in the MLlib wrappers.
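The failure mode is reproducible without Spark. The following self-contained sketch (the `toArray` helper and `FakeTagDemo` name are illustrative, not Spark code) shows how a fake ClassTag leads a generic method to allocate an `Object[]` that then fails the erased checkcast at a call site expecting `String[]`:

```scala
import scala.reflect.ClassTag

object FakeTagDemo {
  // Generic collect-like helper: the runtime array class comes from the ClassTag.
  def toArray[T](it: Iterator[T])(implicit ct: ClassTag[T]): Array[T] = it.toArray

  def main(args: Array[String]): Unit = {
    // A "fake" tag like the Java API's: claims String but really describes AnyRef,
    // so the allocated array is Object[].
    val fake = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[String]]
    try {
      // The erased call returns Object[]; the checkcast to String[] fails here.
      val a: Array[String] = toArray(Iterator("a", "b"))(fake)
      println("no exception: " + a.length)
    } catch {
      case _: ClassCastException => println("ClassCastException")
    }
  }
}
```

This is the same mechanism behind the SPARK-2197 failures: collect() allocates its result array from whatever ClassTag the RDD carries.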

@JoshRosen
Contributor Author

/cc @mengxr @jkbradley @mateiz

@SparkQA

SparkQA commented Jul 29, 2014

QA tests have started for PR 1639. This patch merges cleanly.
View progress: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/17381/consoleFull

@SparkQA

SparkQA commented Jul 29, 2014

QA results for PR 1639:
- This patch FAILED unit tests.
- This patch merges cleanly
- This patch adds no public classes

For more information see test output:
https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/17381/consoleFull

@SparkQA

SparkQA commented Jul 29, 2014

QA tests have started for PR 1639. This patch merges cleanly.
View progress: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/17383/consoleFull

@SparkQA

SparkQA commented Jul 29, 2014

QA results for PR 1639:
- This patch PASSES unit tests.
- This patch merges cleanly
- This patch adds no public classes

For more information see test output:
https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/17383/consoleFull

@marmbrus
Contributor

Another option would be to add collectSeq or something similar that returns a type with reasonable variance semantics.

override protected def getPartitions: Array[Partition] = oldRDD.getPartitions
override def compute(split: Partition, context: TaskContext): Iterator[T] =
oldRDD.compute(split, context)
}
Contributor

You also need to preserve the Partitioner and such. It would be better to do this via this.mapPartitions with the preservesPartitioning option set to true.

Contributor Author

Would there be any performance impact of running mapPartitions(identity, preservesPartitioning = true)(classTag)? If we have an RDD that's persisted in a serialized format, wouldn't this extra map force an unnecessary deserialization?

Contributor

Sure, the fix with just passing the partitioner also works.

Contributor

Actually compute just works at the iterator level, so I don't think mapPartitions would hurt. All you do is pass through the parent's iterator. When you call compute() you're already deserializing the RDD, this won't create extra work.
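The iterator-level argument can be checked without Spark: wrapping an iterator in an identity pass-through forces no elements, so an identity mapPartitions introduces no extra deserialization on its own. A minimal sketch (names illustrative):

```scala
object LazyIterDemo {
  def main(args: Array[String]): Unit = {
    var produced = 0
    // Stand-in for a partition's compute(): elements are produced lazily.
    val parent = Iterator.tabulate(3) { i => produced += 1; i }
    // Identity "mapPartitions": just hand the parent's iterator through.
    val wrapped: Iterator[Int] = identity(parent)
    println(produced)    // nothing has been produced yet
    println(wrapped.sum) // elements are produced only when consumed
    println(produced)
  }
}
```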

@mateiz
Contributor

mateiz commented Jul 30, 2014

I'm okay with either this or collectSeq actually.

@JoshRosen
Contributor Author

I'm going to take another pass on this to see if I can implicitly grab the ClassTag from the caller's scope, so hold off on merging this for a bit.

@SparkQA

SparkQA commented Jul 30, 2014

QA tests have started for PR 1639. This patch merges cleanly.
View progress: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/17418/consoleFull

@JoshRosen
Contributor Author

This method is intended to be called by Scala classes that implement Java-friendly wrappers for the Spark Scala API. For instance, MLlib has APIs that accept RDD[LabeledPoint]. Ideally, the Java wrapper code can simply call the underlying Scala methods without having to worry about how they're implemented. Therefore, I think we should prefer the retag()-based approach, since collectSeq would require us to modify the Scala consumer of the RDD.

Since this is a private, internal API, we should be able to revisit this decision if we change our minds later.

@JoshRosen
Contributor Author

My last commit made classTag implicit in the retag() method, so in many cases the Scala code can be written as someJavaRDD.rdd.retag.[...].collect().
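As a rough illustration of the implicit-tag shape (a hypothetical TaggedSeq stand-in, not the actual RDD code): retag resolves the correct ClassTag from the caller's scope and swaps it in, so a later collect-style allocation uses the right array class.

```scala
import scala.reflect.ClassTag

// Hypothetical stand-in for an RDD carrying a possibly-fake ClassTag.
final class TaggedSeq[T](elems: Seq[T])(implicit tag: ClassTag[T]) {
  // Like retag() here: swap in a tag resolved implicitly at the call site.
  def retag(implicit newTag: ClassTag[T]): TaggedSeq[T] =
    new TaggedSeq[T](elems)(newTag)
  // Runtime element class of the array a collect() would allocate.
  def collectedElementClass: String =
    elems.toArray(tag).getClass.getComponentType.getName
}

object RetagDemo {
  def main(args: Array[String]): Unit = {
    val fake = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[String]]
    val broken = new TaggedSeq[String](Seq("a", "b"))(fake)
    println(broken.collectedElementClass)       // wrong array element type
    println(broken.retag.collectedElementClass) // repaired from the caller's scope
  }
}
```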

@mateiz
Contributor

mateiz commented Jul 30, 2014

Sure, sounds good. Did you see my comments on preserving partitions too though?

@SparkQA

SparkQA commented Jul 30, 2014

QA results for PR 1639:
- This patch PASSES unit tests.
- This patch merges cleanly
- This patch adds no public classes

For more information see test output:
https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/17418/consoleFull

@SparkQA

SparkQA commented Jul 30, 2014

QA tests have started for PR 1639. This patch merges cleanly.
View progress: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/17426/consoleFull

@SparkQA

SparkQA commented Jul 30, 2014

QA results for PR 1639:
- This patch FAILED unit tests.
- This patch merges cleanly
- This patch adds no public classes

For more information see test output:
https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/17426/consoleFull

@mateiz
Contributor

mateiz commented Jul 30, 2014

In case you don't see the hidden comment above: I don't think mapPartitions would hurt performance here. All you do is pass through the parent's iterator. When you call compute() you're already deserializing the RDD, so this won't create extra work in that case.

@mateiz
Contributor

mateiz commented Jul 30, 2014

Basically it's a shorter way of writing what you wrote. Take a look at MapPartitionsRDD.

@JoshRosen
Contributor Author

I've updated this to use mapPartitions().

@SparkQA

SparkQA commented Jul 30, 2014

QA tests have started for PR 1639. This patch merges cleanly.
View progress: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/17458/consoleFull

@mateiz
Contributor

mateiz commented Jul 30, 2014

LGTM, feel free to merge it when it passes tests

@JoshRosen
Contributor Author

Jenkins, retest this please.

@SparkQA

SparkQA commented Jul 30, 2014

QA tests have started for PR 1639. This patch merges cleanly.
View progress: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/17474/consoleFull

@SparkQA

SparkQA commented Jul 30, 2014

QA results for PR 1639:
- This patch FAILED unit tests.
- This patch merges cleanly
- This patch adds no public classes

For more information see test output:
https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/17474/consoleFull

@JoshRosen
Contributor Author

Jenkins, retest this please.

@SparkQA

SparkQA commented Jul 31, 2014

QA tests have started for PR 1639. This patch merges cleanly.
View progress: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/17550/consoleFull

@SparkQA

SparkQA commented Jul 31, 2014

QA results for PR 1639:
- This patch PASSES unit tests.
- This patch merges cleanly
- This patch adds no public classes

For more information see test output:
https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/17550/consoleFull

@JoshRosen
Contributor Author

Alright, I've merged this. Thanks for the review!

@asfgit asfgit closed this in 4fb2593 Jul 31, 2014
xiliu82 pushed a commit to xiliu82/spark that referenced this pull request Sep 4, 2014

Author: Josh Rosen <joshrosen@apache.org>

Closes apache#1639 from JoshRosen/SPARK-2737 and squashes the following commits:

572b4c8 [Josh Rosen] Replace newRDD[T] with mapPartitions().
469d941 [Josh Rosen] Preserve partitioner in retag().
af78816 [Josh Rosen] Allow retag() to get classTag implicitly.
d1d54e6 [Josh Rosen] [SPARK-2737] Add retag() method for changing RDDs' ClassTags.
sunchao added a commit to sunchao/spark that referenced this pull request Jun 2, 2023