Skip to content

Commit

Permalink
[SPARK-30294][SS][FOLLOW-UP] Directly override RDD methods
Browse files Browse the repository at this point in the history
### Why are the changes needed?
Follow the comment: #26935 (comment)

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Existing test and Mima test.

Closes #30344 from xuanyuanking/SPARK-30294-follow.

Authored-by: Yuanjian Li <yuanjian.li@databricks.com>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
  • Loading branch information
xuanyuanking authored and HyukjinKwon committed Nov 12, 2020
1 parent 6244407 commit 9f983a6
Showing 1 changed file with 3 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ import org.apache.spark.sql.internal.SessionState
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.SerializableConfiguration

// This doesn't directly override RDD methods as MiMa complains it.
abstract class BaseStateStoreRDD[T: ClassTag, U: ClassTag](
dataRDD: RDD[T],
checkpointLocation: String,
Expand All @@ -45,16 +44,13 @@ abstract class BaseStateStoreRDD[T: ClassTag, U: ClassTag](
protected val hadoopConfBroadcast = dataRDD.context.broadcast(
new SerializableConfiguration(sessionState.newHadoopConf()))

/** Implementations can simply call this method in getPreferredLocations. */
protected def _getPartitions: Array[Partition] = dataRDD.partitions

/**
* Set the preferred location of each partition using the executor that has the related
* [[StateStoreProvider]] already loaded.
*
* Implementations can simply call this method in getPreferredLocations.
*/
protected def _getPreferredLocations(partition: Partition): Seq[String] = {
override def getPreferredLocations(partition: Partition): Seq[String] = {
val stateStoreProviderId = getStateProviderId(partition)
storeCoordinator.flatMap(_.getLocation(stateStoreProviderId)).toSeq
}
Expand Down Expand Up @@ -87,10 +83,7 @@ class ReadStateStoreRDD[T: ClassTag, U: ClassTag](
extends BaseStateStoreRDD[T, U](dataRDD, checkpointLocation, queryRunId, operatorId,
sessionState, storeCoordinator, extraOptions) {

override protected def getPartitions: Array[Partition] = _getPartitions

override def getPreferredLocations(partition: Partition): Seq[String] =
_getPreferredLocations(partition)
override protected def getPartitions: Array[Partition] = dataRDD.partitions

override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = {
val storeProviderId = getStateProviderId(partition)
Expand Down Expand Up @@ -124,10 +117,7 @@ class StateStoreRDD[T: ClassTag, U: ClassTag](
extends BaseStateStoreRDD[T, U](dataRDD, checkpointLocation, queryRunId, operatorId,
sessionState, storeCoordinator, extraOptions) {

override protected def getPartitions: Array[Partition] = _getPartitions

override def getPreferredLocations(partition: Partition): Seq[String] =
_getPreferredLocations(partition)
override protected def getPartitions: Array[Partition] = dataRDD.partitions

override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = {
val storeProviderId = getStateProviderId(partition)
Expand Down

0 comments on commit 9f983a6

Please sign in to comment.