Skip to content

Commit

Permalink
[SPARK-529][CORE][YARN] Add type-safe config keys to SparkConf.
Browse files Browse the repository at this point in the history
This is, in a way, the basics to enable SPARK-529 (which was closed as
won't fix but I think is still valuable). In fact, Spark SQL created
something for that, and this change basically factors out that code
and inserts it into SparkConf, with some extra bells and whistles.

To showcase the usage of this pattern, I modified the YARN backend
to use the new config keys (defined in the new `config` package object
under `o.a.s.deploy.yarn`). Most of the changes are mechanical, although
the logic had to be slightly modified in a handful of places.

Author: Marcelo Vanzin <vanzin@cloudera.com>

Closes #10205 from vanzin/conf-opts.
  • Loading branch information
Marcelo Vanzin committed Mar 7, 2016
1 parent e9e67b3 commit e1fb857
Show file tree
Hide file tree
Showing 20 changed files with 1,019 additions and 255 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,10 @@ private static boolean isSymlink(File file) throws IOException {
.build();

/**
* Convert a passed time string (e.g. 50s, 100ms, or 250us) to a time count for
* internal use. If no suffix is provided a direct conversion is attempted.
* Convert a passed time string (e.g. 50s, 100ms, or 250us) to a time count in the given unit.
* The unit is also considered the default if the given string does not specify a unit.
*/
private static long parseTimeString(String str, TimeUnit unit) {
public static long timeStringAs(String str, TimeUnit unit) {
String lower = str.toLowerCase().trim();

try {
Expand Down Expand Up @@ -195,23 +195,22 @@ private static long parseTimeString(String str, TimeUnit unit) {
* no suffix is provided, the passed number is assumed to be in ms.
*/
public static long timeStringAsMs(String str) {
return parseTimeString(str, TimeUnit.MILLISECONDS);
return timeStringAs(str, TimeUnit.MILLISECONDS);
}

/**
* Convert a time parameter such as (50s, 100ms, or 250us) to seconds for internal use. If
* no suffix is provided, the passed number is assumed to be in seconds.
*/
public static long timeStringAsSec(String str) {
return parseTimeString(str, TimeUnit.SECONDS);
return timeStringAs(str, TimeUnit.SECONDS);
}

/**
* Convert a passed byte string (e.g. 50b, 100kb, or 250mb) to a ByteUnit for
* internal use. If no suffix is provided a direct conversion of the provided default is
* attempted.
* Convert a passed byte string (e.g. 50b, 100kb, or 250mb) to the given unit. If no suffix is
* provided, a direct conversion to the provided unit is attempted.
*/
private static long parseByteString(String str, ByteUnit unit) {
public static long byteStringAs(String str, ByteUnit unit) {
String lower = str.toLowerCase().trim();

try {
Expand Down Expand Up @@ -252,7 +251,7 @@ private static long parseByteString(String str, ByteUnit unit) {
* If no suffix is provided, the passed number is assumed to be in bytes.
*/
public static long byteStringAsBytes(String str) {
return parseByteString(str, ByteUnit.BYTE);
return byteStringAs(str, ByteUnit.BYTE);
}

/**
Expand All @@ -262,7 +261,7 @@ public static long byteStringAsBytes(String str) {
* If no suffix is provided, the passed number is assumed to be in kibibytes.
*/
public static long byteStringAsKb(String str) {
return parseByteString(str, ByteUnit.KiB);
return byteStringAs(str, ByteUnit.KiB);
}

/**
Expand All @@ -272,7 +271,7 @@ public static long byteStringAsKb(String str) {
* If no suffix is provided, the passed number is assumed to be in mebibytes.
*/
public static long byteStringAsMb(String str) {
return parseByteString(str, ByteUnit.MiB);
return byteStringAs(str, ByteUnit.MiB);
}

/**
Expand All @@ -282,7 +281,7 @@ public static long byteStringAsMb(String str) {
* If no suffix is provided, the passed number is assumed to be in gibibytes.
*/
public static long byteStringAsGb(String str) {
return parseByteString(str, ByteUnit.GiB);
return byteStringAs(str, ByteUnit.GiB);
}

/**
Expand Down
39 changes: 38 additions & 1 deletion core/src/main/scala/org/apache/spark/SparkConf.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@

package org.apache.spark

import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.{ConcurrentHashMap, TimeUnit}

import scala.collection.JavaConverters._
import scala.collection.mutable.LinkedHashSet

import org.apache.avro.{Schema, SchemaNormalization}

import org.apache.spark.internal.config.{ConfigEntry, OptionalConfigEntry}
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.util.Utils

Expand Down Expand Up @@ -74,6 +76,16 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
this
}

/**
 * Sets the value of a typed config entry. The value is rendered to a string with the entry's
 * `stringConverter` and stored under the entry's key through the string-based `set`.
 */
private[spark] def set[T](entry: ConfigEntry[T], value: T): SparkConf = {
  val name = entry.key
  val rendered = entry.stringConverter(value)
  set(name, rendered)
  this
}

/**
 * Sets the value of an optional typed config entry. The value is rendered to a string with the
 * entry's `rawStringConverter` and stored under the entry's key through the string-based `set`.
 */
private[spark] def set[T](entry: OptionalConfigEntry[T], value: T): SparkConf = {
  val name = entry.key
  val rendered = entry.rawStringConverter(value)
  set(name, rendered)
  this
}

/**
* The master URL to connect to, such as "local" to run locally with one thread, "local[4]" to
* run locally with 4 cores, or "spark://master:7077" to run on a Spark standalone cluster.
Expand Down Expand Up @@ -148,6 +160,20 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
this
}

/**
 * Stores the value of a typed config entry only if `settings.putIfAbsent` reports that the
 * entry's key had no previous mapping. The value is rendered with the entry's `stringConverter`.
 * The deprecation check for the key runs only when the value was actually inserted.
 */
private[spark] def setIfMissing[T](entry: ConfigEntry[T], value: T): SparkConf = {
  val previous = settings.putIfAbsent(entry.key, entry.stringConverter(value))
  if (previous == null) {
    logDeprecationWarning(entry.key)
  }
  this
}

/**
 * Stores the value of an optional typed config entry only if `settings.putIfAbsent` reports
 * that the entry's key had no previous mapping. The value is rendered with the entry's
 * `rawStringConverter`. The deprecation check for the key runs only when the value was
 * actually inserted.
 */
private[spark] def setIfMissing[T](entry: OptionalConfigEntry[T], value: T): SparkConf = {
  val previous = settings.putIfAbsent(entry.key, entry.rawStringConverter(value))
  if (previous == null) {
    logDeprecationWarning(entry.key)
  }
  this
}

/**
* Use Kryo serialization and register the given set of classes with Kryo.
* If called multiple times, this will append the classes from all calls together.
Expand Down Expand Up @@ -198,6 +224,17 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
getOption(key).getOrElse(defaultValue)
}

/**
 * Retrieves the value of a pre-defined configuration entry by delegating to the entry's
 * `readFrom` with this configuration.
 *
 * - This is an internal Spark API.
 * - The return type is defined by the configuration entry.
 * - This will throw an exception if the config is not optional and the value is not set
 *   (that behavior is determined by the entry's `readFrom`, which is not visible here).
 */
private[spark] def get[T](entry: ConfigEntry[T]): T = {
entry.readFrom(this)
}

/**
* Get a time parameter as seconds; throws a NoSuchElementException if it's not set. If no
* suffix is provided then seconds are assumed.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.internal.config

import java.util.concurrent.TimeUnit

import org.apache.spark.network.util.{ByteUnit, JavaUtils}

/**
 * Conversion helpers between raw config strings and typed values, shared by the config
 * builders defined in this file.
 */
private object ConfigHelpers {

  /**
   * Applies `converter` to `s`. A `NumberFormatException` raised by the converter is rethrown as
   * an `IllegalArgumentException` that names the config key, the expected type and the bad value.
   */
  def toNumber[T](s: String, converter: String => T, key: String, configType: String): T = {
    try {
      converter(s)
    } catch {
      case _: NumberFormatException =>
        throw new IllegalArgumentException(s"$key should be $configType, but was $s")
    }
  }

  /**
   * Parses `s` as a boolean. A failure is rethrown as an `IllegalArgumentException` that names
   * the config key and the bad value.
   */
  def toBoolean(s: String, key: String): Boolean = {
    try {
      s.toBoolean
    } catch {
      case _: IllegalArgumentException =>
        throw new IllegalArgumentException(s"$key should be boolean, but was $s")
    }
  }

  /**
   * Splits `str` on commas, trims every element, drops the empty ones and converts the rest
   * with `converter`.
   */
  def stringToSeq[T](str: String, converter: String => T): Seq[T] = {
    str.split(",").map(_.trim()).filter(_.nonEmpty).map(converter)
  }

  /** Renders every element with `stringConverter` and joins the results with commas. */
  def seqToString[T](v: Seq[T], stringConverter: T => String): String = {
    v.map(stringConverter).mkString(",")
  }

  /** Parses a time string into `unit`; `unit` is passed through to `JavaUtils.timeStringAs`. */
  def timeFromString(str: String, unit: TimeUnit): Long = JavaUtils.timeStringAs(str, unit)

  /**
   * Renders a duration expressed in `unit` as a time string with an explicit suffix.
   *
   * `TimeUnit.convert` truncates when going from a finer to a coarser unit, so rendering a
   * microsecond or nanosecond value in milliseconds would silently drop it (e.g. 500us would
   * become "0ms"). Sub-millisecond units are therefore rendered with the "us" suffix; every
   * other unit keeps the "ms" rendering.
   */
  def timeToString(v: Long, unit: TimeUnit): String = unit match {
    case TimeUnit.NANOSECONDS | TimeUnit.MICROSECONDS =>
      TimeUnit.MICROSECONDS.convert(v, unit) + "us"
    case _ =>
      TimeUnit.MILLISECONDS.convert(v, unit) + "ms"
  }

  /**
   * Parses a byte-size string into `unit`. A leading '-' is stripped before delegating to
   * `JavaUtils.byteStringAs` and re-applied as the sign of the result.
   */
  def byteFromString(str: String, unit: ByteUnit): Long = {
    // NOTE(review): presumably JavaUtils.byteStringAs rejects negative input, hence the manual
    // sign handling -- confirm against JavaUtils.
    val (input, multiplier) =
      if (str.length() > 0 && str.charAt(0) == '-') {
        (str.substring(1), -1)
      } else {
        (str, 1)
      }
    multiplier * JavaUtils.byteStringAs(input, unit)
  }

  /** Renders a size expressed in `unit` as a byte count with the "b" suffix. */
  def byteToString(v: Long, unit: ByteUnit): String = unit.convertTo(v, ByteUnit.BYTE) + "b"

}

/**
 * Builder for a config entry whose value type has already been chosen. It carries the
 * string-to-value and value-to-string conversions, lets callers layer extra processing on the
 * parsed value (for instance validation), and finally produces the config entry.
 *
 * A builder only yields something usable with [[SparkConf]] once one of the methods returning a
 * [[ConfigEntry]] has been called.
 */
private[spark] class TypedConfigBuilder[T](
    val parent: ConfigBuilder,
    val converter: String => T,
    val stringConverter: T => String) {

  import ConfigHelpers._

  /** Uses `toString` as the string conversion; a null value is rendered as null. */
  def this(parent: ConfigBuilder, converter: String => T) = {
    this(parent, converter, (value: T) => Option(value).map(_.toString).orNull)
  }

  /** Returns a builder whose converter applies `fn` to every value parsed by this builder. */
  def transform(fn: T => T): TypedConfigBuilder[T] = {
    val chained: String => T = converter.andThen(fn)
    new TypedConfigBuilder[T](parent, chained, stringConverter)
  }

  /**
   * Returns a builder that rejects, with an `IllegalArgumentException`, any parsed value that
   * is not a member of `validValues`.
   */
  def checkValues(validValues: Set[T]): TypedConfigBuilder[T] = {
    transform { v =>
      if (validValues.contains(v)) {
        v
      } else {
        throw new IllegalArgumentException(
          s"The value of ${parent.key} should be one of ${validValues.mkString(", ")}, but was $v")
      }
    }
  }

  /** Returns a builder for a comma-separated sequence of values of this builder's type. */
  def toSequence: TypedConfigBuilder[Seq[T]] = {
    val parseAll: String => Seq[T] = raw => stringToSeq(raw, converter)
    val renderAll: Seq[T] => String = values => seqToString(values, stringConverter)
    new TypedConfigBuilder[Seq[T]](parent, parseAll, renderAll)
  }

  /** Creates a [[ConfigEntry]] that does not require a default value. */
  def optional: OptionalConfigEntry[T] = {
    new OptionalConfigEntry[T](parent.key, converter, stringConverter, parent._doc, parent._public)
  }

  /** Creates a [[ConfigEntry]] that has a default value. */
  def withDefault(default: T): ConfigEntry[T] = {
    // Round-trip the default through its string form so that it passes through the same
    // converter (including any transform/checkValues step) as a user-supplied value.
    val asString = stringConverter(default)
    val normalized = converter(asString)
    new ConfigEntryWithDefault[T](parent.key, normalized, converter, stringConverter,
      parent._doc, parent._public)
  }

  /**
   * Creates a [[ConfigEntry]] that has a default value. The default is given in its [[String]]
   * form and is parsed with this builder's converter, so it must be a valid value for the entry.
   */
  def withDefaultString(default: String): ConfigEntry[T] = {
    val parsed = converter(default)
    new ConfigEntryWithDefault[T](parent.key, parsed, converter, stringConverter, parent._doc,
      parent._public)
  }

}

/**
 * Entry point for declaring a Spark configuration. Holds the key plus optional metadata
 * (documentation text and public/internal visibility) and hands out type-specific builders.
 *
 * @see TypedConfigBuilder
 */
private[spark] case class ConfigBuilder(key: String) {

  import ConfigHelpers._

  // Mutable metadata; read by TypedConfigBuilder when the final entry is created.
  var _public: Boolean = true
  var _doc: String = ""

  /** Marks the config as internal (not public) and returns this builder. */
  def internal: ConfigBuilder = {
    _public = false
    this
  }

  /** Records the documentation text and returns this builder. */
  def doc(s: String): ConfigBuilder = {
    _doc = s
    this
  }

  /** Builder for an `Int` config; parse failures are reported with type name "int". */
  def intConf: TypedConfigBuilder[Int] = {
    new TypedConfigBuilder[Int](this, (s: String) => toNumber[Int](s, _.toInt, key, "int"))
  }

  /** Builder for a `Long` config; parse failures are reported with type name "long". */
  def longConf: TypedConfigBuilder[Long] = {
    new TypedConfigBuilder[Long](this, (s: String) => toNumber[Long](s, _.toLong, key, "long"))
  }

  /** Builder for a `Double` config; parse failures are reported with type name "double". */
  def doubleConf: TypedConfigBuilder[Double] = {
    new TypedConfigBuilder[Double](
      this, (s: String) => toNumber[Double](s, _.toDouble, key, "double"))
  }

  /** Builder for a `Boolean` config. */
  def booleanConf: TypedConfigBuilder[Boolean] = {
    new TypedConfigBuilder[Boolean](this, (s: String) => toBoolean(s, key))
  }

  /** Builder for a `String` config; the raw value is used unchanged. */
  def stringConf: TypedConfigBuilder[String] = {
    new TypedConfigBuilder[String](this, (s: String) => s)
  }

  /** Builder for a time config whose value is a `Long` expressed in `unit`. */
  def timeConf(unit: TimeUnit): TypedConfigBuilder[Long] = {
    new TypedConfigBuilder[Long](
      this,
      (s: String) => timeFromString(s, unit),
      (v: Long) => timeToString(v, unit))
  }

  /** Builder for a byte-size config whose value is a `Long` expressed in `unit`. */
  def bytesConf(unit: ByteUnit): TypedConfigBuilder[Long] = {
    new TypedConfigBuilder[Long](
      this,
      (s: String) => byteFromString(s, unit),
      (v: Long) => byteToString(v, unit))
  }

  /** Creates an entry for this key that is tied to the given `fallback` entry. */
  def fallbackConf[T](fallback: ConfigEntry[T]): ConfigEntry[T] = {
    new FallbackConfigEntry(key, _doc, _public, fallback)
  }

}
Loading

0 comments on commit e1fb857

Please sign in to comment.