
addressed code review
brkyvz committed Jan 31, 2015
1 parent 2cd6562 commit 231f72f
Showing 4 changed files with 192 additions and 138 deletions.
bin/windows-utils.cmd (2 changes: 1 addition & 1 deletion)
@@ -32,7 +32,7 @@ SET opts="\<--master\> \<--deploy-mode\> \<--class\> \<--name\> \<--jars\> \<--p
SET opts="%opts:~1,-1% \<--conf\> \<--properties-file\> \<--driver-memory\> \<--driver-java-options\>"
SET opts="%opts:~1,-1% \<--driver-library-path\> \<--driver-class-path\> \<--executor-memory\>"
SET opts="%opts:~1,-1% \<--driver-cores\> \<--total-executor-cores\> \<--executor-cores\> \<--queue\>"
SET opts="%opts:~1,-1% \<--num-executors\> \<--archives\> \<--packages\> \<--repositories\>
SET opts="%opts:~1,-1% \<--num-executors\> \<--archives\> \<--packages\> \<--repositories\>"

echo %1 | findstr %opts% >nul
if %ERRORLEVEL% equ 0 (
core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala (215 changes: 128 additions & 87 deletions)
@@ -184,7 +184,8 @@ object SparkSubmit {

// Resolve maven dependencies if there are any and add classpath to jars
val resolvedMavenCoordinates =
SparkSubmitUtils.resolveMavenCoordinates(args.packages, args.repositories, args.ivyRepoPath)
SparkSubmitUtils.resolveMavenCoordinates(
args.packages, Option(args.repositories), Option(args.ivyRepoPath))
if (!resolvedMavenCoordinates.trim.isEmpty) {
if (args.jars == null || args.jars.trim.isEmpty) {
args.jars = resolvedMavenCoordinates
@@ -461,10 +462,6 @@
/** Provides utility functions to be used inside SparkSubmit. */
private[spark] object SparkSubmitUtils extends Logging {

// Directories for caching downloads through ivy and storing the jars when maven coordinates are
// supplied to spark-submit
private var PACKAGES_DIRECTORY: File = null

/**
* Represents a Maven Coordinate
* @param groupId the groupId of the coordinate
@@ -473,6 +470,95 @@ private[spark] object SparkSubmitUtils extends Logging {
*/
private[spark] case class MavenCoordinate(groupId: String, artifactId: String, version: String)

/**
* Extracts maven coordinates from a comma-delimited string
* @param coordinates Comma-delimited string of maven coordinates
* @return Sequence of Maven coordinates
*/
private[spark] def extractMavenCoordinates(coordinates: String): Seq[MavenCoordinate] = {
coordinates.split(",").map { p =>
val splits = p.split(":")
require(splits.length == 3, s"Provided Maven Coordinates must be in the form " +
s"'groupId:artifactId:version'. The coordinate provided is: $p")
require(splits(0) != null && splits(0).trim.nonEmpty, s"The groupId cannot be null or " +
s"be whitespace. The groupId provided is: ${splits(0)}")
require(splits(1) != null && splits(1).trim.nonEmpty, s"The artifactId cannot be null or " +
s"be whitespace. The artifactId provided is: ${splits(1)}")
require(splits(2) != null && splits(2).trim.nonEmpty, s"The version cannot be null or " +
s"be whitespace. The version provided is: ${splits(2)}")
new MavenCoordinate(splits(0), splits(1), splits(2))
}
}
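// Usage sketch (illustrative; not part of this patch): the call
//   extractMavenCoordinates("com.databricks:spark-csv_2.10:0.1,com.databricks:spark-avro_2.10:0.1")
// yields MavenCoordinate("com.databricks", "spark-csv_2.10", "0.1") and
// MavenCoordinate("com.databricks", "spark-avro_2.10", "0.1"), while an entry
// missing one of the three ':'-separated parts fails the length requirement.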

/**
* Creates a ChainResolver from a comma-delimited string of remote repositories
* @param remoteRepos Comma-delimited string of remote repositories
* @return A ChainResolver used by Ivy to search for and resolve dependencies.
*/
private[spark] def createRepoResolvers(remoteRepos: Option[String]): ChainResolver = {
// We need a chain resolver if we want to check multiple repositories
val cr = new ChainResolver
cr.setName("list")

// the biblio resolver resolves POM declared dependencies
val br: IBiblioResolver = new IBiblioResolver
br.setM2compatible(true)
br.setUsepoms(true)
br.setName("central")
cr.add(br)

val repositoryList = remoteRepos.getOrElse("")
// add any remote repositories other than Maven Central
if (repositoryList.trim.nonEmpty) {
repositoryList.split(",").zipWithIndex.foreach { case (repo, i) =>
val brr: IBiblioResolver = new IBiblioResolver
brr.setM2compatible(true)
brr.setUsepoms(true)
brr.setRoot(repo)
brr.setName(s"repo-${i + 1}")
cr.add(brr)
logInfo(s"$repo added as a remote repository with the name: ${brr.getName}")
}
}
cr
}
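// Usage sketch (illustrative; not part of this patch, and the URL is made up):
//   createRepoResolvers(Some("http://repo.example.com/maven"))
// returns a chain that consults Maven Central ("central") first and then the
// user-supplied resolver ("repo-1"), in the order they were added to the chain.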

/**
* Outputs a comma-delimited list of paths for the downloaded jars to be added to the classpath
* (these will be appended to the jars in SparkSubmit). Ivy gives the name of the jar after
* a '!', and sometimes appends '(bundle)' after '.jar'; both are stripped here.
* @param artifacts Sequence of dependencies that were resolved and retrieved
* @param cacheDirectory directory where jars are cached
* @return a comma-delimited list of paths for the dependencies
*/
private[spark] def resolveDependencyPaths(
artifacts: Array[AnyRef],
cacheDirectory: File): String = {
artifacts.map { case artifactInfo: MDArtifact =>
val artifactString = artifactInfo.toString
val jarName = artifactString.drop(artifactString.lastIndexOf("!") + 1)
cacheDirectory.getAbsolutePath + "/" + jarName.substring(0, jarName.lastIndexOf(".jar") + 4)
}.mkString(",")
}
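// Worked example (illustrative; not part of this patch): Ivy renders an
// MDArtifact roughly as "...!spark-csv_2.10.jar(bundle)". Dropping everything
// up to and including the last '!' and cutting just after ".jar" leaves
// "spark-csv_2.10.jar", so the returned path is
// "<cacheDirectory>/spark-csv_2.10.jar".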

/** Adds the given maven coordinates to Ivy's module descriptor. */
private[spark] def addDependenciesToIvy(
md: DefaultModuleDescriptor,
artifacts: Seq[MavenCoordinate],
ivyConfName: String): Unit = {
artifacts.foreach { mvn =>
val ri = ModuleRevisionId.newInstance(mvn.groupId, mvn.artifactId, mvn.version)
val dd = new DefaultDependencyDescriptor(ri, false, false)
dd.addDependencyConfiguration(ivyConfName, ivyConfName)
logInfo(s"${dd.getDependencyId} added as a dependency")
md.addDependency(dd)
}
}
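// Sketch (illustrative; not part of this patch): for the coordinate
// com.databricks:spark-csv_2.10:0.1 this registers a dependency on revision
// "0.1" under the given ivy configuration, so the ivy.resolve(...) call in
// resolveMavenCoordinates below fetches it and its transitive dependencies.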

/** Creates a default module descriptor. Exposed for use in tests as well; values are dummy strings. */
private[spark] def getModuleDescriptor = DefaultModuleDescriptor.newDefaultInstance(
ModuleRevisionId.newInstance("org.apache.spark", "spark-submit-envelope", "1.0"))

/**
* Resolves any dependencies that were supplied through maven coordinates
* @param coordinates Comma-delimited string of maven coordinates
@@ -483,73 +569,36 @@ private[spark] object SparkSubmitUtils extends Logging {
*/
private[spark] def resolveMavenCoordinates(
coordinates: String,
remoteRepos: String,
ivyPath: String,
remoteRepos: Option[String],
ivyPath: Option[String],
isTest: Boolean = false): String = {
if (coordinates == null || coordinates.trim.isEmpty) {
""
} else {
val artifacts = coordinates.split(",").map { p =>
val splits = p.split(":")
require(splits.length == 3, s"Provided Maven Coordinates must be in the form " +
s"'groupId:artifactId:version'. The coordinate provided is: $p")
require(splits(0) != null && splits(0).trim.nonEmpty, s"The groupId cannot be null or " +
s"be whitespace. The groupId provided is: ${splits(0)}")
require(splits(1) != null && splits(1).trim.nonEmpty, s"The artifactId cannot be null or " +
s"be whitespace. The artifactId provided is: ${splits(1)}")
require(splits(2) != null && splits(2).trim.nonEmpty, s"The version cannot be null or " +
s"be whitespace. The version provided is: ${splits(2)}")
new MavenCoordinate(splits(0), splits(1), splits(2))
}
val artifacts = extractMavenCoordinates(coordinates)
// Default configuration name for ivy
val conf = "default"
val ivyConfName = "default"
// set ivy settings for location of cache
val ivySettings: IvySettings = new IvySettings
if (ivyPath == null || ivyPath.trim.isEmpty) {
PACKAGES_DIRECTORY = new File(ivySettings.getDefaultIvyUserDir, "jars")
} else {
ivySettings.setDefaultCache(new File(ivyPath, "cache"))
PACKAGES_DIRECTORY = new File(ivyPath, "jars")
}
// Directories for caching downloads through ivy and storing the jars when maven coordinates are
// supplied to spark-submit
val alternateIvyCache = ivyPath.getOrElse("")
val packagesDirectory: File =
if (alternateIvyCache.trim.isEmpty) {
new File(ivySettings.getDefaultIvyUserDir, "jars")
} else {
ivySettings.setDefaultCache(new File(alternateIvyCache, "cache"))
new File(alternateIvyCache, "jars")
}
logInfo(s"Ivy Default Cache set to: ${ivySettings.getDefaultCache.getAbsolutePath}")
logInfo(s"The jars for the packages stored in: $PACKAGES_DIRECTORY")

logInfo(s"The jars for the packages stored in: $packagesDirectory")
// create a pattern matcher
ivySettings.addMatcher(new GlobPatternMatcher)
// create the dependency resolvers
val repoResolver = createRepoResolvers(remoteRepos)
ivySettings.addResolver(repoResolver)
ivySettings.setDefaultResolver(repoResolver.getName)

// the biblio resolver resolves POM declared dependencies
val br: IBiblioResolver = new IBiblioResolver
br.setM2compatible(true)
br.setUsepoms(true)
br.setName("central")

// We need a chain resolver if we want to check multiple repositories
val cr = new ChainResolver
cr.setName("list")
cr.add(br)

// Add an exclusion rule for Spark
val sparkArtifacts = new ArtifactId(new ModuleId("org.apache.spark", "*"), "*", "*", "*")
val sparkDependencyExcludeRule =
new DefaultExcludeRule(sparkArtifacts, ivySettings.getMatcher("glob"), null)
sparkDependencyExcludeRule.addConfiguration(conf)

// add any other remote repositories other than maven central
if (remoteRepos != null && remoteRepos.trim.nonEmpty) {
var i = 1
remoteRepos.split(",").foreach { repo =>
val brr: IBiblioResolver = new IBiblioResolver
brr.setM2compatible(true)
brr.setUsepoms(true)
brr.setRoot(repo)
brr.setName(s"repo-$i")
cr.add(brr)
logInfo(s"$repo added as a remote repository with the name: ${brr.getName}")
i += 1
}
}
ivySettings.addResolver(cr)
ivySettings.setDefaultResolver(cr.getName)
val ivy = Ivy.newInstance(ivySettings)
// Set resolve options to download transitive dependencies as well
val resolveOptions = new ResolveOptions
@@ -565,19 +614,18 @@ private[spark] object SparkSubmitUtils extends Logging {
}

// A Module descriptor must be specified. Entries are dummy strings
val md = DefaultModuleDescriptor.newDefaultInstance(
ModuleRevisionId.newInstance("org.apache.spark", "spark-submit-envelope", "1.0"))
md.setDefaultConf(conf)
val md = getModuleDescriptor
md.setDefaultConf(ivyConfName)

md.addExcludeRule(sparkDependencyExcludeRule)
// Add an exclusion rule for Spark
val sparkArtifacts = new ArtifactId(new ModuleId("org.apache.spark", "*"), "*", "*", "*")
val sparkDependencyExcludeRule =
new DefaultExcludeRule(sparkArtifacts, ivySettings.getMatcher("glob"), null)
sparkDependencyExcludeRule.addConfiguration(ivyConfName)

artifacts.foreach { mvn =>
val ri = ModuleRevisionId.newInstance(mvn.groupId, mvn.artifactId, mvn.version)
val dd = new DefaultDependencyDescriptor(ri, false, false)
dd.addDependencyConfiguration(conf, conf)
logInfo(s"${dd.getDependencyId} added as a dependency")
md.addDependency(dd)
}
// Exclude any Spark dependencies, and add all supplied maven artifacts as dependencies
md.addExcludeRule(sparkDependencyExcludeRule)
addDependenciesToIvy(md, artifacts, ivyConfName)

// resolve dependencies
val rr: ResolveReport = ivy.resolve(md, resolveOptions)
@@ -586,23 +634,16 @@ private[spark] object SparkSubmitUtils extends Logging {
}
// Log the callers for each dependency
rr.getDependencies.toArray.foreach { case dependency: IvyNode =>
logInfo(s"$dependency will be retrieved as a dependency for:")
dependency.getAllCallers.foreach (caller => logInfo(s"\t$caller"))
var logMsg = s"$dependency will be retrieved as a dependency for:"
dependency.getAllCallers.foreach (caller => logMsg += s"\n\t$caller")
logInfo(logMsg)
}
// retrieve all resolved dependencies
val m = rr.getModuleDescriptor
ivy.retrieve(m.getModuleRevisionId,
PACKAGES_DIRECTORY.getAbsolutePath + "/[artifact](-[classifier]).[ext]",
retrieveOptions.setConfs(Array(conf)))

// output downloaded jars to classpath (will append to jars). The name of the jar is given
// after a '!' by Ivy. It also sometimes contains (bundle) after '.jar'. Remove that as well.
rr.getArtifacts.toArray.map { case artifactInfo: MDArtifact =>
val artifactString = artifactInfo.toString
val jarName = artifactString.drop(artifactString.lastIndexOf("!") + 1)
PACKAGES_DIRECTORY.getAbsolutePath + "/" +
jarName.substring(0, jarName.lastIndexOf(".jar") + 4)
}.mkString(",")
ivy.retrieve(rr.getModuleDescriptor.getModuleRevisionId,
packagesDirectory.getAbsolutePath + "/[artifact](-[classifier]).[ext]",
retrieveOptions.setConfs(Array(ivyConfName)))

resolveDependencyPaths(rr.getArtifacts.toArray, packagesDirectory)
}
}
}
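
Taken together, the refactoring above splits coordinate parsing, resolver setup, dependency registration, and jar-path assembly into standalone helpers, and the two nullable string parameters become Option values wrapped at the call site. A minimal usage sketch of the resulting entry point (illustrative only, not from the patch; the coordinate is the one exercised by the test suite below):

    // Resolve one package plus its transitive dependencies into the default
    // ivy "jars" directory, consulting only Maven Central (no extra repositories):
    val classpath = SparkSubmitUtils.resolveMavenCoordinates(
      "com.databricks:spark-csv_2.10:0.1",
      remoteRepos = None,
      ivyPath = None)
    // `classpath` is a comma-delimited list of absolute jar paths, which
    // SparkSubmit appends to args.jars before launching the application.
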
core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala (36 changes: 6 additions & 30 deletions)
@@ -307,20 +307,21 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties
"--name", "testApp",
"--master", "local-cluster[2,1,512]",
"--jars", jarsString,
unusedJar.toString)
unusedJar.toString, "SparkSubmitClassA", "SparkSubmitClassB")
runSparkSubmit(args)
}

test("includes jars passed in through --packages") {
val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
val packagesString = "com.databricks:spark-csv_2.10:0.1,com.databricks:spark-avro_2.10:0.1"
val args = Seq(
"--class", MavenArtifactDownloadTest.getClass.getName.stripSuffix("$"),
"--class", JarCreationTest.getClass.getName.stripSuffix("$"),
"--name", "testApp",
"--master", "local-cluster[2,1,512]",
"--packages", packagesString,
"--conf", "spark.ui.enabled=false",
unusedJar.toString)
unusedJar.toString,
"com.databricks.spark.csv.DefaultSource", "com.databricks.spark.avro.DefaultSource")
runSparkSubmit(args)
}

@@ -480,33 +481,8 @@ object JarCreationTest extends Logging {
val result = sc.makeRDD(1 to 100, 10).mapPartitions { x =>
var exception: String = null
try {
Class.forName("SparkSubmitClassA", true, Thread.currentThread().getContextClassLoader)
Class.forName("SparkSubmitClassB", true, Thread.currentThread().getContextClassLoader)
} catch {
case t: Throwable =>
exception = t + "\n" + t.getStackTraceString
exception = exception.replaceAll("\n", "\n\t")
}
Option(exception).toSeq.iterator
}.collect()
if (result.nonEmpty) {
throw new Exception("Could not load user class from jar:\n" + result(0))
}
}
}

object MavenArtifactDownloadTest extends Logging {
def main(args: Array[String]) {
Utils.configTestLog4j("INFO")
val conf = new SparkConf()
val sc = new SparkContext(conf)
val result = sc.makeRDD(1 to 100, 10).mapPartitions { x =>
var exception: String = null
try {
Class.forName("com.databricks.spark.csv.DefaultSource",
true, Thread.currentThread().getContextClassLoader)
Class.forName("com.databricks.spark.avro.DefaultSource",
true, Thread.currentThread().getContextClassLoader)
Class.forName(args(0), true, Thread.currentThread().getContextClassLoader)
Class.forName(args(1), true, Thread.currentThread().getContextClassLoader)
} catch {
case t: Throwable =>
exception = t + "\n" + t.getStackTraceString
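
With this change JarCreationTest receives the fully-qualified names of the classes to probe as program arguments, so one driver serves both the --jars test and the new --packages test, and the dedicated MavenArtifactDownloadTest object is removed. Concretely, in the --packages test above the driver runs with (values copied from the test arguments):

    // args(0) == "com.databricks.spark.csv.DefaultSource"
    // args(1) == "com.databricks.spark.avro.DefaultSource"
    // Each executor partition must be able to load both classes; otherwise the
    // collected exception string causes the test job to fail.
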
(The diff of the fourth changed file did not load.)