Skip to content

Commit

Permalink
[SPARK-25100][TEST][FOLLOWUP] Refactor test cases in FileSuite and …
Browse files Browse the repository at this point in the history
…`KryoSerializerSuite`

### What changes were proposed in this pull request?

Refactor the test cases added by #26714 to make the code more compact.

### How was this patch tested?

Tested locally.

Closes #26916 from jiangxb1987/SPARK-25100.

Authored-by: Xingbo Jiang <xingbo.jiang@databricks.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
  • Loading branch information
jiangxb1987 authored and dongjoon-hyun committed Dec 17, 2019
1 parent 1da7e82 commit 1c714be
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 25 deletions.
41 changes: 24 additions & 17 deletions core/src/test/scala/org/apache/spark/FileSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -702,32 +702,39 @@ class FileSuite extends SparkFunSuite with LocalSparkContext {
assert(collectRDDAndDeleteFileBeforeCompute(true).isEmpty)
}

test("SPARK-25100: Using KryoSerializer and" +
"setting registrationRequired true can lead job failed") {
val inputFile = new File(tempDir, "/input").getAbsolutePath
val textFileOutputDir = new File(tempDir, "/out1").getAbsolutePath
val dataSetDir = new File(tempDir, "/out2").getAbsolutePath

Utils.tryWithResource(new PrintWriter(new File(inputFile))) { writer =>
for (i <- 1 to 100) {
test("SPARK-25100: Support commit tasks when Kyro registration is required") {
// Prepare the input file
val inputFilePath = new File(tempDir, "/input").getAbsolutePath
Utils.tryWithResource(new PrintWriter(new File(inputFilePath))) { writer =>
for (i <- 1 to 3) {
writer.print(i)
writer.write('\n')
}
}

val conf = new SparkConf(false).setMaster("local").
set("spark.kryo.registrationRequired", "true").setAppName("test")
conf.set("spark.serializer", classOf[KryoSerializer].getName)
// Start a new SparkContext
val conf = new SparkConf(false)
.setMaster("local")
.setAppName("test")
.set("spark.kryo.registrationRequired", "true")
.set("spark.serializer", classOf[KryoSerializer].getName)
sc = new SparkContext(conf)

// Prepare the input RDD
val pairRDD = sc.textFile(inputFilePath).map(x => (x, x))

// Test saveAsTextFile()
val outputFilePath1 = new File(tempDir, "/out1").getAbsolutePath
pairRDD.saveAsTextFile(outputFilePath1)
assert(sc.textFile(outputFilePath1).collect() === Array("(1,1)", "(2,2)", "(3,3)"))

// Test saveAsNewAPIHadoopDataset()
val outputFilePath2 = new File(tempDir, "/out2").getAbsolutePath
val jobConf = new JobConf()
jobConf.setOutputKeyClass(classOf[IntWritable])
jobConf.setOutputValueClass(classOf[IntWritable])
jobConf.set("mapred.output.dir", dataSetDir)

sc = new SparkContext(conf)
val pairRDD = sc.textFile(inputFile).map(x => (x, 1))

pairRDD.saveAsTextFile(textFileOutputDir)
jobConf.set("mapred.output.dir", outputFilePath2)
pairRDD.saveAsNewAPIHadoopDataset(jobConf)
assert(sc.textFile(outputFilePath2).collect() === Array("1\t1", "2\t2", "3\t3"))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -363,16 +363,14 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext {
val conf = new SparkConf(false)
conf.set(KRYO_REGISTRATION_REQUIRED, true)

val ser = new KryoSerializer(conf).newInstance()
// In HadoopMapReduceCommitProtocol#commitTask
val addedAbsPathFiles: mutable.Map[String, String] = mutable.Map()
addedAbsPathFiles.put("test1", "test1")
addedAbsPathFiles.put("test2", "test2")
// HadoopMapReduceCommitProtocol.commitTask() returns a TaskCommitMessage containing a complex
// structure.

val partitionPaths: mutable.Set[String] = mutable.Set()
partitionPaths.add("test3")
val ser = new KryoSerializer(conf).newInstance()
val addedAbsPathFiles = Map("test1" -> "test1", "test2" -> "test2")
val partitionPaths = Set("test3")

val taskCommitMessage1 = new TaskCommitMessage(addedAbsPathFiles.toMap -> partitionPaths.toSet)
val taskCommitMessage1 = new TaskCommitMessage(addedAbsPathFiles -> partitionPaths)
val taskCommitMessage2 = new TaskCommitMessage(Map.empty -> Set.empty)
Seq(taskCommitMessage1, taskCommitMessage2).foreach { taskCommitMessage =>
val obj1 = ser.deserialize[TaskCommitMessage](ser.serialize(taskCommitMessage)).obj
Expand Down

0 comments on commit 1c714be

Please sign in to comment.