Skip to content

Commit

Permalink
Add unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
jerryshao committed Jun 25, 2018
1 parent 6d0a2b7 commit f3c99f8
Showing 1 changed file with 19 additions and 10 deletions.
29 changes: 19 additions & 10 deletions core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -995,20 +995,24 @@ class SparkSubmitSuite
}

test("download remote resource if it is not supported by yarn service") {
testRemoteResources(enableHttpFs = false, blacklistHttpFs = false)
testRemoteResources(enableHttpFs = false)
}

test("avoid downloading remote resource if it is supported by yarn service") {
testRemoteResources(enableHttpFs = true, blacklistHttpFs = false)
testRemoteResources(enableHttpFs = true)
}

test("force download from blacklisted schemes") {
testRemoteResources(enableHttpFs = true, blacklistHttpFs = true)
testRemoteResources(enableHttpFs = true, blacklistSchemes = Seq("http"))
}

test("force download for all the schemes") {
testRemoteResources(enableHttpFs = true, blacklistSchemes = Seq("*"))
}

private def testRemoteResources(
enableHttpFs: Boolean,
blacklistHttpFs: Boolean): Unit = {
blacklistSchemes: Seq[String] = Nil): Unit = {
val hadoopConf = new Configuration()
updateConfWithFakeS3Fs(hadoopConf)
if (enableHttpFs) {
Expand All @@ -1025,8 +1029,8 @@ class SparkSubmitSuite
val tmpHttpJar = TestUtils.createJarWithFiles(Map("test.resource" -> "USER"), tmpDir)
val tmpHttpJarPath = s"http://${new File(tmpHttpJar.toURI).getAbsolutePath}"

val forceDownloadArgs = if (blacklistHttpFs) {
Seq("--conf", "spark.yarn.dist.forceDownloadSchemes=http")
val forceDownloadArgs = if (blacklistSchemes.nonEmpty) {
Seq("--conf", s"spark.yarn.dist.forceDownloadSchemes=${blacklistSchemes.mkString(",")}")
} else {
Nil
}
Expand All @@ -1044,14 +1048,19 @@ class SparkSubmitSuite

val jars = conf.get("spark.yarn.dist.jars").split(",").toSet

// The URI of remote S3 resource should still be remote.
assert(jars.contains(tmpS3JarPath))
def isSchemeBlacklisted(scheme: String) = {
blacklistSchemes.contains("*") || blacklistSchemes.contains(scheme)
}

if (!isSchemeBlacklisted("s3")) {
assert(jars.contains(tmpS3JarPath))
}

if (enableHttpFs && !blacklistHttpFs) {
if (enableHttpFs && blacklistSchemes.isEmpty) {
// If Http FS is supported by yarn service, the URI of remote http resource should
// still be remote.
assert(jars.contains(tmpHttpJarPath))
} else {
} else if (!enableHttpFs || isSchemeBlacklisted("http")) {
// If Http FS is not supported by yarn service, or http scheme is configured to be force
// downloading, the URI of remote http resource should be changed to a local one.
val jarName = new File(tmpHttpJar.toURI).getName
Expand Down

0 comments on commit f3c99f8

Please sign in to comment.