making classes that needn't be public private, adding automatic file closure, adding new tests
kmader committed Aug 14, 2014
1 parent edf5829 commit 9a313d5
Showing 8 changed files with 107 additions and 19 deletions.
10 changes: 5 additions & 5 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -523,7 +523,7 @@ class SparkContext(config: SparkConf) extends Logging {
val job = new NewHadoopJob(hadoopConfiguration)
NewFileInputFormat.addInputPath(job, new Path(path))
val updateConf = job.getConfiguration
new RawFileRDD(
new BinaryFileRDD(
this,
classOf[ByteInputFormat],
classOf[String],
@@ -548,7 +548,7 @@ class SparkContext(config: SparkConf) extends Logging {
val job = new NewHadoopJob(hadoopConfiguration)
NewFileInputFormat.addInputPath(job, new Path(path))
val updateConf = job.getConfiguration
new RawFileRDD(
new BinaryFileRDD(
this,
classOf[StreamInputFormat],
classOf[String],
@@ -565,9 +565,9 @@ class SparkContext(config: SparkConf) extends Logging {
* @param path Directory to the input data files
* @return An RDD of data with values, RDD[(Array[Byte])]
*/
def fixedLengthBinaryFiles(path: String): RDD[Array[Byte]] = {
val lines = newAPIHadoopFile[LongWritable, BytesWritable, FixedLengthBinaryInputFormat](path)
val data = lines.map{ case (k, v) => v.getBytes}
def binaryRecords(path: String): RDD[Array[Byte]] = {
val br = newAPIHadoopFile[LongWritable, BytesWritable, FixedLengthBinaryInputFormat](path)
val data = br.map{ case (k, v) => v.getBytes}
data
}

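For orientation, a minimal sketch of how the renamed binaryFiles entry point is driven from Scala, mirroring the new FileSuite test further down in this commit; the master, app name, and input path are placeholders:

import org.apache.spark.SparkContext

val sc = new SparkContext("local", "binaryFilesExample")  // placeholders, as in the tests

// binaryFiles yields one (fileName, fileContents) pair per file under the directory
val files = sc.binaryFiles("/tmp/binary-input")
val (name, bytes) = files.first()
println(name + " contains " + bytes.length + " bytes")

sc.stop()
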
@@ -289,7 +289,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
* @param minPartitions A suggestion value of the minimal splitting number for input data.
*/
def binaryFiles(path: String, minPartitions: Int = defaultMinPartitions):
JavaPairRDD[String,Array[Byte]] = new JavaPairRDD(sc.binaryFiles(path,minPartitions))
JavaPairRDD[String, Array[Byte]] = new JavaPairRDD(sc.binaryFiles(path,minPartitions))

/**
* Load data from a flat binary file, assuming each record is a set of numbers
@@ -299,8 +299,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
* @param path Directory to the input data files
* @return An RDD of data with values, JavaRDD[(Array[Byte])]
*/
def fixedLengthBinaryFiles(path: String): JavaRDD[Array[Byte]] = {
new JavaRDD(sc.fixedLengthBinaryFiles(path))
def binaryRecords(path: String): JavaRDD[Array[Byte]] = {
new JavaRDD(sc.binaryRecords(path))
}

/** Get an RDD for a Hadoop SequenceFile with given key and value types.
@@ -28,7 +28,7 @@ import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAt
* a parameter recordLength in the Hadoop configuration.
*/

object FixedLengthBinaryInputFormat {
private[spark] object FixedLengthBinaryInputFormat {

/**
* This function retrieves the recordLength by checking the configuration parameter
@@ -42,7 +42,7 @@ object FixedLengthBinaryInputFormat {

}

class FixedLengthBinaryInputFormat extends FileInputFormat[LongWritable, BytesWritable] {
private[spark] class FixedLengthBinaryInputFormat extends FileInputFormat[LongWritable, BytesWritable] {


/**
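FixedLengthBinaryInputFormat reads its record length from the Hadoop configuration rather than from a method argument, so recordLength must be set before binaryRecords is called. A minimal sketch, with the 6-byte record length and the input path as placeholders (it mirrors the fixed-length FileSuite test below):

import org.apache.spark.SparkContext

val sc = new SparkContext("local", "binaryRecordsExample")

// the input format looks up "recordLength" in the Hadoop configuration
sc.hadoopConfiguration.setInt("recordLength", 6)

// one Array[Byte] per fixed-length record in the file
val records = sc.binaryRecords("/tmp/fixed-records.bin")
println("first record: " + records.first().mkString(","))

sc.stop()
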
@@ -37,7 +37,7 @@ import org.apache.hadoop.mapreduce.lib.input.FileSplit
* VALUE = the record itself (BytesWritable)
*
*/
class FixedLengthBinaryRecordReader extends RecordReader[LongWritable, BytesWritable] {
private[spark] class FixedLengthBinaryRecordReader extends RecordReader[LongWritable, BytesWritable] {

override def close() {
if (fileInputStream != null) {
23 changes: 17 additions & 6 deletions core/src/main/scala/org/apache/spark/input/RawFileInput.scala
@@ -73,18 +73,29 @@ abstract class StreamBasedRecordReader[T](

private val key = path.toString
private var value: T = null.asInstanceOf[T]
// the file to be read when nextkeyvalue is called
private lazy val fileIn: FSDataInputStream = fs.open(path)

override def initialize(split: InputSplit, context: TaskAttemptContext) = {}
override def close() = {}
override def close() = {
// make sure the file is closed
try {
fileIn.close()
} catch {
case ioe: java.io.IOException => // do nothing
}
}

override def getProgress = if (processed) 1.0f else 0.0f

override def getCurrentKey = key

override def getCurrentValue = value


override def nextKeyValue = {
if (!processed) {
val fileIn: FSDataInputStream = fs.open(path)

value = parseStream(fileIn)
processed = true
true
@@ -104,7 +115,7 @@ abstract class StreamBasedRecordReader[T](
/**
* Reads the record in directly as a stream for other objects to manipulate and handle
*/
class StreamRecordReader(
private[spark] class StreamRecordReader(
split: CombineFileSplit,
context: TaskAttemptContext,
index: Integer)
@@ -117,7 +128,7 @@ class StreamRecordReader(
* A class for extracting the information from the file using the
* BinaryRecordReader (as Byte array)
*/
class StreamInputFormat extends StreamFileInputFormat[DataInputStream] {
private[spark] class StreamInputFormat extends StreamFileInputFormat[DataInputStream] {
override def createRecordReader(split: InputSplit, taContext: TaskAttemptContext)=
{
new CombineFileRecordReader[String,DataInputStream](
@@ -146,7 +157,7 @@ abstract class BinaryRecordReader[T](
}


class ByteRecordReader(
private[spark] class ByteRecordReader(
split: CombineFileSplit,
context: TaskAttemptContext,
index: Integer)
@@ -158,7 +169,7 @@ class ByteRecordReader(
/**
* A class for reading the file using the BinaryRecordReader (as Byte array)
*/
class ByteInputFormat extends StreamFileInputFormat[Array[Byte]] {
private[spark] class ByteInputFormat extends StreamFileInputFormat[Array[Byte]] {
override def createRecordReader(split: InputSplit, taContext: TaskAttemptContext)=
{
new CombineFileRecordReader[String,Array[Byte]](
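The file-closure change above boils down to two ingredients: the stream is opened lazily, and close() tolerates a failed close. A standalone sketch of that pattern follows; the LazyReader class and its path argument are hypothetical and not part of this commit:

import java.io.{FileInputStream, IOException, InputStream}

class LazyReader(path: String) {
  // opened on first use, so merely constructing the reader never touches the file system
  private lazy val in: InputStream = new FileInputStream(path)

  def readAll(): Array[Byte] = {
    val out = new java.io.ByteArrayOutputStream()
    val buf = new Array[Byte](4096)
    var n = in.read(buf)
    while (n != -1) {
      out.write(buf, 0, n)
      n = in.read(buf)
    }
    out.toByteArray
  }

  def close(): Unit = {
    try {
      in.close()
    } catch {
      case _: IOException => // mirror the record reader: ignore failures when closing
    }
  }
}

One consequence of using a lazy val is that calling close() on a reader that never read anything still forces the stream open before closing it.
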
5 changes: 3 additions & 2 deletions core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala
@@ -23,10 +23,10 @@ package org.apache.spark.rdd
import org.apache.hadoop.conf.{Configurable, Configuration}
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapreduce._
import org.apache.spark.{Partition, SparkContext}
import org.apache.spark.{InterruptibleIterator, TaskContext, Partition, SparkContext}
import org.apache.spark.input.StreamFileInputFormat

private[spark] class RawFileRDD[T](
private[spark] class BinaryFileRDD[T](
sc : SparkContext,
inputFormatClass: Class[_ <: StreamFileInputFormat[T]],
keyClass: Class[String],
@@ -35,6 +35,7 @@ private[spark] class RawFileRDD[
minPartitions: Int)
extends NewHadoopRDD[String, T](sc, inputFormatClass, keyClass, valueClass, conf) {


override def getPartitions: Array[Partition] = {
val inputFormat = inputFormatClass.newInstance
inputFormat match {
22 changes: 22 additions & 0 deletions core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -836,6 +836,28 @@ public Tuple2<Integer, String> call(Tuple2<IntWritable, Text> pair) {
Assert.assertEquals(pairs, readRDD.collect());
}

@Test
public void binaryFiles() throws Exception {
// Reusing the wholeText files example
byte[] content1 = "spark is easy to use.\n".getBytes("utf-8");
byte[] content2 = "spark is also easy to use.\n".getBytes("utf-8");

String tempDirName = tempDir.getAbsolutePath();
File file1 = new File(tempDirName + "/part-00000");
Files.write(content1, file1);
File file2 = new File(tempDirName + "/part-00001");
Files.write(content2, file2);

JavaPairRDD<String, byte[]> readRDD = sc.binaryFiles(tempDirName,3);
List<Tuple2<String, byte[]>> result = readRDD.collect();
for (Tuple2<String, byte[]> res : result) {
if (res._1()==file1.toString())
Assert.assertArrayEquals(content1,res._2());
else
Assert.assertArrayEquals(content2,res._2());
}
}

@SuppressWarnings("unchecked")
@Test
public void writeWithNewAPIHadoopFile() {
54 changes: 54 additions & 0 deletions core/src/test/scala/org/apache/spark/FileSuite.scala
@@ -224,6 +224,60 @@ class FileSuite extends FunSuite with LocalSparkContext {
assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)"))
}

test("byte stream input") {
sc = new SparkContext("local", "test")
val outputDir = new File(tempDir, "output").getAbsolutePath
val outFile = new File(outputDir, "part-00000.bin")
val outFileName = outFile.toPath().toString()

// create file
val testOutput = Array[Byte](1,2,3,4,5,6)
val bbuf = java.nio.ByteBuffer.wrap(testOutput)
// write data to file
val file = new java.io.FileOutputStream(outFile)
val channel = file.getChannel
channel.write(bbuf)
channel.close()
file.close()

val inRdd = sc.binaryFiles(outFileName)
val (infile: String, indata: Array[Byte]) = inRdd.first

// Try reading the output back as an object file
assert(infile === outFileName)
assert(indata === testOutput)
}

test("fixed length byte stream input") {
// a fixed length of 6 bytes

sc = new SparkContext("local", "test")

val outputDir = new File(tempDir, "output").getAbsolutePath
val outFile = new File(outputDir, "part-00000.bin")
val outFileName = outFile.toPath().toString()

// create file
val testOutput = Array[Byte](1,2,3,4,5,6)
val testOutputCopies = 10
val bbuf = java.nio.ByteBuffer.wrap(testOutput)
// write data to file
val file = new java.io.FileOutputStream(outFile)
val channel = file.getChannel
for(i <- 1 to testOutputCopies) channel.write(bbuf)
channel.close()
file.close()
sc.hadoopConfiguration.setInt("recordLength",testOutput.length)

val inRdd = sc.binaryRecords(outFileName)
// make sure there are enough elements
assert(inRdd.count== testOutputCopies)

// now just compare the first one
val indata: Array[Byte] = inRdd.first
assert(indata === testOutput)
}

test("file caching") {
sc = new SparkContext("local", "test")
val out = new FileWriter(tempDir + "/input")