SPARK-2028: Expose mapPartitionsWithInputSplit in HadoopRDD
This allows users to access the InputSplit that backs each partition.

An alternative solution would have been a .withInputSplit() method that returns a new RDD[(InputSplit, (K, V))], but this would be confusing because such an RDD could not be cached or shuffled, as InputSplit is not inherently serializable.
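
For illustration, a caller can use the new method to recover the file that backs each partition, roughly as follows (a Scala sketch mirroring the FileSuite test added below; the input path, app name, and local master are placeholder assumptions):

    import org.apache.hadoop.io.{LongWritable, Text}
    import org.apache.hadoop.mapred.{FileSplit, TextInputFormat}
    import org.apache.spark.SparkContext
    import org.apache.spark.rdd.HadoopRDD

    val sc = new SparkContext("local", "input-split-example")

    // hadoopFile is backed by a HadoopRDD, so the cast exposes the new method.
    val lines = sc.hadoopFile("/tmp/some-input", classOf[TextInputFormat],
        classOf[LongWritable], classOf[Text])
      .asInstanceOf[HadoopRDD[LongWritable, Text]]

    // For each partition, emit the path of the InputSplit (here a FileSplit) it was read from.
    val inputFiles = lines.mapPartitionsWithInputSplit { (split, iter) =>
      Iterator(split.asInstanceOf[FileSplit].getPath.toUri.getPath)
    }.collect()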

Author: Aaron Davidson <aaron@databricks.com>

Closes #973 from aarondav/hadoop and squashes the following commits:

9c9112b [Aaron Davidson] Add JavaAPISuite test
9942cd7 [Aaron Davidson] Add Java API
1284a3a [Aaron Davidson] SPARK-2028: Expose mapPartitionsWithInputSplit in HadoopRDD
aarondav authored and mateiz committed Jul 31, 2014
1 parent 72cfb13 commit f193312
Showing 7 changed files with 222 additions and 11 deletions.
43 changes: 43 additions & 0 deletions core/src/main/scala/org/apache/spark/api/java/JavaHadoopRDD.scala
@@ -0,0 +1,43 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.api.java

import scala.collection.JavaConversions._
import scala.reflect.ClassTag

import org.apache.hadoop.mapred.InputSplit

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.java.JavaSparkContext._
import org.apache.spark.api.java.function.{Function2 => JFunction2}
import org.apache.spark.rdd.HadoopRDD

@DeveloperApi
class JavaHadoopRDD[K, V](rdd: HadoopRDD[K, V])
(implicit override val kClassTag: ClassTag[K], implicit override val vClassTag: ClassTag[V])
extends JavaPairRDD[K, V](rdd) {

/** Maps over a partition, providing the InputSplit that was used as the base of the partition. */
@DeveloperApi
def mapPartitionsWithInputSplit[R](
f: JFunction2[InputSplit, java.util.Iterator[(K, V)], java.util.Iterator[R]],
preservesPartitioning: Boolean = false): JavaRDD[R] = {
new JavaRDD(rdd.mapPartitionsWithInputSplit((a, b) => f.call(a, asJavaIterator(b)),
preservesPartitioning)(fakeClassTag))(fakeClassTag)
}
}
43 changes: 43 additions & 0 deletions core/src/main/scala/org/apache/spark/api/java/JavaNewHadoopRDD.scala
@@ -0,0 +1,43 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.api.java

import scala.collection.JavaConversions._
import scala.reflect.ClassTag

import org.apache.hadoop.mapreduce.InputSplit

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.java.JavaSparkContext._
import org.apache.spark.api.java.function.{Function2 => JFunction2}
import org.apache.spark.rdd.NewHadoopRDD

@DeveloperApi
class JavaNewHadoopRDD[K, V](rdd: NewHadoopRDD[K, V])
(implicit override val kClassTag: ClassTag[K], implicit override val vClassTag: ClassTag[V])
extends JavaPairRDD[K, V](rdd) {

/** Maps over a partition, providing the InputSplit that was used as the base of the partition. */
@DeveloperApi
def mapPartitionsWithInputSplit[R](
f: JFunction2[InputSplit, java.util.Iterator[(K, V)], java.util.Iterator[R]],
preservesPartitioning: Boolean = false): JavaRDD[R] = {
new JavaRDD(rdd.mapPartitionsWithInputSplit((a, b) => f.call(a, asJavaIterator(b)),
preservesPartitioning)(fakeClassTag))(fakeClassTag)
}
}
21 changes: 13 additions & 8 deletions core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
@@ -34,7 +34,7 @@ import org.apache.spark._
import org.apache.spark.SparkContext.{DoubleAccumulatorParam, IntAccumulatorParam}
import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.{EmptyRDD, RDD}
import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, NewHadoopRDD, RDD}

/**
* A Java-friendly version of [[org.apache.spark.SparkContext]] that returns
@@ -294,7 +294,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
): JavaPairRDD[K, V] = {
implicit val ctagK: ClassTag[K] = ClassTag(keyClass)
implicit val ctagV: ClassTag[V] = ClassTag(valueClass)
new JavaPairRDD(sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass, minPartitions))
val rdd = sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass, minPartitions)
new JavaHadoopRDD(rdd.asInstanceOf[HadoopRDD[K, V]])
}

/**
@@ -314,7 +315,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
): JavaPairRDD[K, V] = {
implicit val ctagK: ClassTag[K] = ClassTag(keyClass)
implicit val ctagV: ClassTag[V] = ClassTag(valueClass)
new JavaPairRDD(sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass))
val rdd = sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass)
new JavaHadoopRDD(rdd.asInstanceOf[HadoopRDD[K, V]])
}

/** Get an RDD for a Hadoop file with an arbitrary InputFormat.
@@ -333,7 +335,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
): JavaPairRDD[K, V] = {
implicit val ctagK: ClassTag[K] = ClassTag(keyClass)
implicit val ctagV: ClassTag[V] = ClassTag(valueClass)
new JavaPairRDD(sc.hadoopFile(path, inputFormatClass, keyClass, valueClass, minPartitions))
val rdd = sc.hadoopFile(path, inputFormatClass, keyClass, valueClass, minPartitions)
new JavaHadoopRDD(rdd.asInstanceOf[HadoopRDD[K, V]])
}

/** Get an RDD for a Hadoop file with an arbitrary InputFormat
@@ -351,8 +354,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
): JavaPairRDD[K, V] = {
implicit val ctagK: ClassTag[K] = ClassTag(keyClass)
implicit val ctagV: ClassTag[V] = ClassTag(valueClass)
new JavaPairRDD(sc.hadoopFile(path,
inputFormatClass, keyClass, valueClass))
val rdd = sc.hadoopFile(path, inputFormatClass, keyClass, valueClass)
new JavaHadoopRDD(rdd.asInstanceOf[HadoopRDD[K, V]])
}

/**
@@ -372,7 +375,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
conf: Configuration): JavaPairRDD[K, V] = {
implicit val ctagK: ClassTag[K] = ClassTag(kClass)
implicit val ctagV: ClassTag[V] = ClassTag(vClass)
new JavaPairRDD(sc.newAPIHadoopFile(path, fClass, kClass, vClass, conf))
val rdd = sc.newAPIHadoopFile(path, fClass, kClass, vClass, conf)
new JavaNewHadoopRDD(rdd.asInstanceOf[NewHadoopRDD[K, V]])
}

/**
@@ -391,7 +395,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
vClass: Class[V]): JavaPairRDD[K, V] = {
implicit val ctagK: ClassTag[K] = ClassTag(kClass)
implicit val ctagV: ClassTag[V] = ClassTag(vClass)
new JavaPairRDD(sc.newAPIHadoopRDD(conf, fClass, kClass, vClass))
val rdd = sc.newAPIHadoopRDD(conf, fClass, kClass, vClass)
new JavaNewHadoopRDD(rdd.asInstanceOf[NewHadoopRDD[K, V]])
}

/** Build the union of two or more RDDs. */
32 changes: 32 additions & 0 deletions core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -20,7 +20,9 @@ package org.apache.spark.rdd
import java.text.SimpleDateFormat
import java.util.Date
import java.io.EOFException

import scala.collection.immutable.Map
import scala.reflect.ClassTag

import org.apache.hadoop.conf.{Configurable, Configuration}
import org.apache.hadoop.mapred.FileSplit
@@ -39,6 +41,7 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.executor.{DataReadMethod, InputMetrics}
import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD
import org.apache.spark.util.NextIterator

/**
@@ -232,6 +235,14 @@ class HadoopRDD[K, V](
new InterruptibleIterator[(K, V)](context, iter)
}

/** Maps over a partition, providing the InputSplit that was used as the base of the partition. */
@DeveloperApi
def mapPartitionsWithInputSplit[U: ClassTag](
f: (InputSplit, Iterator[(K, V)]) => Iterator[U],
preservesPartitioning: Boolean = false): RDD[U] = {
new HadoopMapPartitionsWithSplitRDD(this, f, preservesPartitioning)
}

override def getPreferredLocations(split: Partition): Seq[String] = {
// TODO: Filtering out "localhost" in case of file:// URLs
val hadoopSplit = split.asInstanceOf[HadoopPartition]
@@ -272,4 +283,25 @@ private[spark] object HadoopRDD {
conf.setInt("mapred.task.partition", splitId)
conf.set("mapred.job.id", jobID.toString)
}

/**
* Analogous to [[org.apache.spark.rdd.MapPartitionsRDD]], but passes in an InputSplit to
* the given function rather than the index of the partition.
*/
private[spark] class HadoopMapPartitionsWithSplitRDD[U: ClassTag, T: ClassTag](
prev: RDD[T],
f: (InputSplit, Iterator[T]) => Iterator[U],
preservesPartitioning: Boolean = false)
extends RDD[U](prev) {

override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None

override def getPartitions: Array[Partition] = firstParent[T].partitions

override def compute(split: Partition, context: TaskContext) = {
val partition = split.asInstanceOf[HadoopPartition]
val inputSplit = partition.inputSplit.value
f(inputSplit, firstParent[T].iterator(split, context))
}
}
}
34 changes: 34 additions & 0 deletions core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -20,6 +20,8 @@ package org.apache.spark.rdd
import java.text.SimpleDateFormat
import java.util.Date

import scala.reflect.ClassTag

import org.apache.hadoop.conf.{Configurable, Configuration}
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapreduce._
@@ -32,6 +34,7 @@ import org.apache.spark.Partition
import org.apache.spark.SerializableWritable
import org.apache.spark.{SparkContext, TaskContext}
import org.apache.spark.executor.{DataReadMethod, InputMetrics}
import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD

private[spark] class NewHadoopPartition(
rddId: Int,
@@ -157,6 +160,14 @@ class NewHadoopRDD[K, V](
new InterruptibleIterator(context, iter)
}

/** Maps over a partition, providing the InputSplit that was used as the base of the partition. */
@DeveloperApi
def mapPartitionsWithInputSplit[U: ClassTag](
f: (InputSplit, Iterator[(K, V)]) => Iterator[U],
preservesPartitioning: Boolean = false): RDD[U] = {
new NewHadoopMapPartitionsWithSplitRDD(this, f, preservesPartitioning)
}

override def getPreferredLocations(split: Partition): Seq[String] = {
val theSplit = split.asInstanceOf[NewHadoopPartition]
theSplit.serializableHadoopSplit.value.getLocations.filter(_ != "localhost")
@@ -165,6 +176,29 @@
def getConf: Configuration = confBroadcast.value.value
}

private[spark] object NewHadoopRDD {
/**
* Analogous to [[org.apache.spark.rdd.MapPartitionsRDD]], but passes in an InputSplit to
* the given function rather than the index of the partition.
*/
private[spark] class NewHadoopMapPartitionsWithSplitRDD[U: ClassTag, T: ClassTag](
prev: RDD[T],
f: (InputSplit, Iterator[T]) => Iterator[U],
preservesPartitioning: Boolean = false)
extends RDD[U](prev) {

override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None

override def getPartitions: Array[Partition] = firstParent[T].partitions

override def compute(split: Partition, context: TaskContext) = {
val partition = split.asInstanceOf[NewHadoopPartition]
val inputSplit = partition.serializableHadoopSplit.value
f(inputSplit, firstParent[T].iterator(split, context))
}
}
}

private[spark] class WholeTextFileRDD(
sc : SparkContext,
inputFormatClass: Class[_ <: WholeTextFileInputFormat],
26 changes: 25 additions & 1 deletion core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -25,26 +25,31 @@
import scala.Tuple3;
import scala.Tuple4;


import com.google.common.collect.Iterables;
import com.google.common.collect.Iterators;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.google.common.base.Optional;
import com.google.common.base.Charsets;
import com.google.common.io.Files;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.compress.DefaultCodec;
import org.apache.hadoop.mapred.FileSplit;
import org.apache.hadoop.mapred.InputSplit;
import org.apache.hadoop.mapred.SequenceFileInputFormat;
import org.apache.hadoop.mapred.SequenceFileOutputFormat;
import org.apache.hadoop.mapred.TextInputFormat;
import org.apache.hadoop.mapreduce.Job;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaHadoopRDD;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
@@ -1262,4 +1267,23 @@ public void collectUnderlyingScalaRDD() {
SomeCustomClass[] collected = (SomeCustomClass[]) rdd.rdd().retag(SomeCustomClass.class).collect();
Assert.assertEquals(data.size(), collected.length);
}

@Test
public void getHadoopInputSplits() {
String outDir = new File(tempDir, "output").getAbsolutePath();
sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2).saveAsTextFile(outDir);

JavaHadoopRDD<LongWritable, Text> hadoopRDD = (JavaHadoopRDD<LongWritable, Text>)
sc.hadoopFile(outDir, TextInputFormat.class, LongWritable.class, Text.class);
List<String> inputPaths = hadoopRDD.mapPartitionsWithInputSplit(
new Function2<InputSplit, Iterator<Tuple2<LongWritable, Text>>, Iterator<String>>() {
@Override
public Iterator<String> call(InputSplit split, Iterator<Tuple2<LongWritable, Text>> it)
throws Exception {
FileSplit fileSplit = (FileSplit) split;
return Lists.newArrayList(fileSplit.getPath().toUri().getPath()).iterator();
}
}, true).collect();
Assert.assertEquals(Sets.newHashSet(inputPaths),
Sets.newHashSet(outDir + "/part-00000", outDir + "/part-00001"));
}
}
34 changes: 32 additions & 2 deletions core/src/test/scala/org/apache/spark/FileSuite.scala
@@ -24,12 +24,14 @@ import scala.io.Source
import com.google.common.io.Files
import org.apache.hadoop.io._
import org.apache.hadoop.io.compress.DefaultCodec
import org.apache.hadoop.mapred.{JobConf, FileAlreadyExistsException, TextOutputFormat}
import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat}
import org.apache.hadoop.mapred.{JobConf, FileAlreadyExistsException, FileSplit, TextInputFormat, TextOutputFormat}
import org.apache.hadoop.mapreduce.Job
import org.apache.hadoop.mapreduce.lib.input.{FileSplit => NewFileSplit, TextInputFormat => NewTextInputFormat}
import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat}
import org.scalatest.FunSuite

import org.apache.spark.SparkContext._
import org.apache.spark.rdd.{NewHadoopRDD, HadoopRDD}
import org.apache.spark.util.Utils

class FileSuite extends FunSuite with LocalSparkContext {
@@ -318,4 +320,32 @@ class FileSuite extends FunSuite with LocalSparkContext {
randomRDD.saveAsNewAPIHadoopDataset(job.getConfiguration)
assert(new File(tempDir.getPath + "/outputDataset_new/part-r-00000").exists() === true)
}

test("Get input files via old Hadoop API") {
sc = new SparkContext("local", "test")
val outDir = new File(tempDir, "output").getAbsolutePath
sc.makeRDD(1 to 4, 2).saveAsTextFile(outDir)

val inputPaths =
sc.hadoopFile(outDir, classOf[TextInputFormat], classOf[LongWritable], classOf[Text])
.asInstanceOf[HadoopRDD[_, _]]
.mapPartitionsWithInputSplit { (split, part) =>
Iterator(split.asInstanceOf[FileSplit].getPath.toUri.getPath)
}.collect()
assert(inputPaths.toSet === Set(s"$outDir/part-00000", s"$outDir/part-00001"))
}

test("Get input files via new Hadoop API") {
sc = new SparkContext("local", "test")
val outDir = new File(tempDir, "output").getAbsolutePath
sc.makeRDD(1 to 4, 2).saveAsTextFile(outDir)

val inputPaths =
sc.newAPIHadoopFile(outDir, classOf[NewTextInputFormat], classOf[LongWritable], classOf[Text])
.asInstanceOf[NewHadoopRDD[_, _]]
.mapPartitionsWithInputSplit { (split, part) =>
Iterator(split.asInstanceOf[NewFileSplit].getPath.toUri.getPath)
}.collect()
assert(inputPaths.toSet === Set(s"$outDir/part-00000", s"$outDir/part-00001"))
}
}
