diff --git a/LICENSE b/LICENSE
index 8672be55eca3e..f9e412cade345 100644
--- a/LICENSE
+++ b/LICENSE
@@ -948,6 +948,6 @@ The following components are provided under the MIT License. See project link fo
(MIT License) SLF4J LOG4J-12 Binding (org.slf4j:slf4j-log4j12:1.7.5 - http://www.slf4j.org)
(MIT License) pyrolite (org.spark-project:pyrolite:2.0.1 - http://pythonhosted.org/Pyro4/)
(MIT License) scopt (com.github.scopt:scopt_2.10:3.2.0 - https://github.com/scopt/scopt)
- (The MIT License) Mockito (org.mockito:mockito-core:1.8.5 - http://www.mockito.org)
+ (The MIT License) Mockito (org.mockito:mockito-core:1.9.5 - http://www.mockito.org)
(MIT License) jquery (https://jquery.org/license/)
(MIT License) AnchorJS (https://github.com/bryanbraun/anchorjs)
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 6feabf4189c2d..60702824acb46 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -169,8 +169,8 @@ setMethod("isLocal",
#'}
setMethod("showDF",
signature(x = "DataFrame"),
- function(x, numRows = 20) {
- s <- callJMethod(x@sdf, "showString", numToInt(numRows))
+ function(x, numRows = 20, truncate = TRUE) {
+ s <- callJMethod(x@sdf, "showString", numToInt(numRows), truncate)
cat(s)
})
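[Usage sketch, assuming the patched SparkR API above: showDF(df, numRows = 10, truncate = FALSE) should print column values in full, while the default truncate = TRUE lets the JVM-side showString shorten long cell values, mirroring DataFrame.show on the Scala side.]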
diff --git a/core/pom.xml b/core/pom.xml
index 565437c4861a4..aee0d92620606 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -69,16 +69,6 @@
<dependency>
  <groupId>org.apache.hadoop</groupId>
  <artifactId>hadoop-client</artifactId>
- <exclusions>
-   <exclusion>
-     <groupId>javax.servlet</groupId>
-     <artifactId>servlet-api</artifactId>
-   </exclusion>
-   <exclusion>
-     <groupId>org.codehaus.jackson</groupId>
-     <artifactId>jackson-mapper-asl</artifactId>
-   </exclusion>
- </exclusions>
</dependency>
<dependency>
  <groupId>org.apache.spark</groupId>
diff --git a/core/src/main/scala/org/apache/spark/SSLOptions.scala b/core/src/main/scala/org/apache/spark/SSLOptions.scala
index 2cdc167f85af0..32df42d57dbd6 100644
--- a/core/src/main/scala/org/apache/spark/SSLOptions.scala
+++ b/core/src/main/scala/org/apache/spark/SSLOptions.scala
@@ -17,7 +17,9 @@
package org.apache.spark
-import java.io.File
+import java.io.{File, FileInputStream}
+import java.security.{KeyStore, NoSuchAlgorithmException}
+import javax.net.ssl.{KeyManager, KeyManagerFactory, SSLContext, TrustManager, TrustManagerFactory}
import com.typesafe.config.{Config, ConfigFactory, ConfigValueFactory}
import org.eclipse.jetty.util.ssl.SslContextFactory
@@ -38,7 +40,7 @@ import org.eclipse.jetty.util.ssl.SslContextFactory
* @param trustStore a path to the trust-store file
* @param trustStorePassword a password to access the trust-store file
* @param protocol SSL protocol (remember that SSLv3 was compromised) supported by Java
- * @param enabledAlgorithms a set of encryption algorithms to use
+ * @param enabledAlgorithms a set of encryption algorithms that may be used
*/
private[spark] case class SSLOptions(
enabled: Boolean = false,
@@ -48,7 +50,8 @@ private[spark] case class SSLOptions(
trustStore: Option[File] = None,
trustStorePassword: Option[String] = None,
protocol: Option[String] = None,
- enabledAlgorithms: Set[String] = Set.empty) {
+ enabledAlgorithms: Set[String] = Set.empty)
+ extends Logging {
/**
* Creates a Jetty SSL context factory according to the SSL settings represented by this object.
@@ -63,7 +66,7 @@ private[spark] case class SSLOptions(
trustStorePassword.foreach(sslContextFactory.setTrustStorePassword)
keyPassword.foreach(sslContextFactory.setKeyManagerPassword)
protocol.foreach(sslContextFactory.setProtocol)
- sslContextFactory.setIncludeCipherSuites(enabledAlgorithms.toSeq: _*)
+ sslContextFactory.setIncludeCipherSuites(supportedAlgorithms.toSeq: _*)
Some(sslContextFactory)
} else {
@@ -94,7 +97,7 @@ private[spark] case class SSLOptions(
.withValue("akka.remote.netty.tcp.security.protocol",
ConfigValueFactory.fromAnyRef(protocol.getOrElse("")))
.withValue("akka.remote.netty.tcp.security.enabled-algorithms",
- ConfigValueFactory.fromIterable(enabledAlgorithms.toSeq))
+ ConfigValueFactory.fromIterable(supportedAlgorithms.toSeq))
.withValue("akka.remote.netty.tcp.enable-ssl",
ConfigValueFactory.fromAnyRef(true)))
} else {
@@ -102,6 +105,36 @@ private[spark] case class SSLOptions(
}
}
+ /*
+ * The supportedAlgorithms set is a subset of the enabledAlgorithms that
+ * are supported by the current Java security provider for this protocol.
+ */
+ private val supportedAlgorithms: Set[String] = {
+ var context: SSLContext = null
+ try {
+ context = SSLContext.getInstance(protocol.orNull)
+ /* The set of supported algorithms does not depend upon the keys, trust, or
+ rng, although they will influence which algorithms are eventually used. */
+ context.init(null, null, null)
+ } catch {
+ case npe: NullPointerException =>
+ logDebug("No SSL protocol specified")
+ context = SSLContext.getDefault
+ case nsa: NoSuchAlgorithmException =>
+ logDebug(s"No support for requested SSL protocol ${protocol.get}")
+ context = SSLContext.getDefault
+ }
+
+ val providerAlgorithms = context.getServerSocketFactory.getSupportedCipherSuites.toSet
+
+ // Log which algorithms we are discarding
+ (enabledAlgorithms &~ providerAlgorithms).foreach { cipher =>
+ logDebug(s"Discarding unsupported cipher $cipher")
+ }
+
+ enabledAlgorithms & providerAlgorithms
+ }
+
/** Returns a string representation of this SSLOptions with all the passwords masked. */
override def toString: String = s"SSLOptions{enabled=$enabled, " +
s"keyStore=$keyStore, keyStorePassword=${keyStorePassword.map(_ => "xxx")}, " +
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index c7a7436462083..b3c3bf3746e18 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -315,6 +315,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
_dagScheduler = ds
}
+ /**
+ * A unique identifier for the Spark application.
+ * Its format depends on the scheduler implementation.
+ * (e.g. something like 'local-1433865536131' for a local app,
+ * or 'application_1433865536131_34483' when running on YARN)
+ */
def applicationId: String = _applicationId
def applicationAttemptId: Option[String] = _applicationAttemptId
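[For illustration, a hedged sketch of reading this identifier from a driver program; the object name and master URL are assumptions, not part of the patch.

import org.apache.spark.{SparkConf, SparkContext}

object AppIdSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("app-id-demo"))
    // Prints e.g. 'local-1433865536131' here; on YARN the same call
    // would yield something like 'application_1433865536131_34483'.
    println(sc.applicationId)
    sc.stop()
  }
}]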
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
index 4dfa7325934ff..524676544d6f5 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
@@ -391,7 +391,7 @@ private[r] object RRDD {
}
private def createRProcess(rLibDir: String, port: Int, script: String): BufferedStreamThread = {
- val rCommand = "Rscript"
+ val rCommand = SparkEnv.get.conf.get("spark.sparkr.r.command", "Rscript")
val rOptions = "--vanilla"
val rExecScript = rLibDir + "/SparkR/worker/" + script
val pb = new ProcessBuilder(List(rCommand, rOptions, rExecScript))
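[A short sketch of the configuration hook added above; the interpreter path is an assumption for illustration.

import org.apache.spark.SparkConf

object RCommandSketch {
  // With the change above, SparkR worker processes would be launched with
  // this interpreter instead of the bare "Rscript" found on the PATH.
  val conf = new SparkConf().set("spark.sparkr.r.command", "/opt/R/3.2/bin/Rscript")
}]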
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index abf222757a95b..b1d6ec209d62b 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -756,6 +756,20 @@ private[spark] object SparkSubmitUtils {
val cr = new ChainResolver
cr.setName("list")
+ val repositoryList = remoteRepos.getOrElse("")
+ // add any remote repositories other than Maven Central
+ if (repositoryList.trim.nonEmpty) {
+ repositoryList.split(",").zipWithIndex.foreach { case (repo, i) =>
+ val brr: IBiblioResolver = new IBiblioResolver
+ brr.setM2compatible(true)
+ brr.setUsepoms(true)
+ brr.setRoot(repo)
+ brr.setName(s"repo-${i + 1}")
+ cr.add(brr)
+ printStream.println(s"$repo added as a remote repository with the name: ${brr.getName}")
+ }
+ }
+
val localM2 = new IBiblioResolver
localM2.setM2compatible(true)
localM2.setRoot(m2Path.toURI.toString)
@@ -786,20 +800,6 @@ private[spark] object SparkSubmitUtils {
sp.setRoot("http://dl.bintray.com/spark-packages/maven")
sp.setName("spark-packages")
cr.add(sp)
-
- val repositoryList = remoteRepos.getOrElse("")
- // add any other remote repositories other than maven central
- if (repositoryList.trim.nonEmpty) {
- repositoryList.split(",").zipWithIndex.foreach { case (repo, i) =>
- val brr: IBiblioResolver = new IBiblioResolver
- brr.setM2compatible(true)
- brr.setUsepoms(true)
- brr.setRoot(repo)
- brr.setName(s"repo-${i + 1}")
- cr.add(brr)
- printStream.println(s"$repo added as a remote repository with the name: ${brr.getName}")
- }
- }
cr
}
@@ -922,6 +922,15 @@ private[spark] object SparkSubmitUtils {
// A Module descriptor must be specified. Entries are dummy strings
val md = getModuleDescriptor
+ // Clear the Ivy resolution from previous launches. The resolution file is usually at
+ // ~/.ivy2/org.apache.spark-spark-submit-parent-default.xml. Between runs, a stale file
+ // confuses Ivy when the artifacts can no longer be found at the repositories
+ // declared in that file.
+ val mdId = md.getModuleRevisionId
+ val previousResolution = new File(ivySettings.getDefaultCache,
+ s"${mdId.getOrganisation}-${mdId.getName}-$ivyConfName.xml")
+ if (previousResolution.exists) previousResolution.delete
+
md.setDefaultConf(ivyConfName)
// Add exclusion rules for Spark and Scala Library
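[To make the reordering above concrete, a self-contained sketch (using the same Ivy classes SparkSubmit already imports) of how a comma-separated --repositories value becomes a chain of named resolvers, with user repositories added ahead of the built-in ones; the URLs are placeholders.

import org.apache.ivy.plugins.resolver.{ChainResolver, IBiblioResolver}

object RepoChainSketch {
  def main(args: Array[String]): Unit = {
    val cr = new ChainResolver
    cr.setName("list")
    val repos = "https://repo1.example.com/maven,https://repo2.example.com/maven"
    repos.split(",").zipWithIndex.foreach { case (repo, i) =>
      val brr = new IBiblioResolver
      brr.setM2compatible(true)
      brr.setUsepoms(true)
      brr.setRoot(repo)
      brr.setName(s"repo-${i + 1}")
      cr.add(brr)
    }
    // The chain now consults repo-1, then repo-2, before any resolvers added
    // later (local-m2, local-ivy, central in the patched method).
  }
}]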
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 19157af5b6f4d..a7fc749a2b0c6 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -2333,3 +2333,36 @@ private[spark] class RedirectThread(
}
}
}
+
+/**
+ * An [[OutputStream]] that will store the last 10 kilobytes (by default) written to it
+ * in a circular buffer. The current contents of the buffer can be accessed using
+ * the toString method.
+ */
+private[spark] class CircularBuffer(sizeInBytes: Int = 10240) extends java.io.OutputStream {
+ var pos: Int = 0
+ var buffer = new Array[Int](sizeInBytes)
+
+ def write(i: Int): Unit = {
+ buffer(pos) = i
+ pos = (pos + 1) % buffer.length
+ }
+
+ override def toString: String = {
+ val (end, start) = buffer.splitAt(pos)
+ val input = new java.io.InputStream {
+ val iterator = (start ++ end).iterator
+
+ def read(): Int = if (iterator.hasNext) iterator.next() else -1
+ }
+ val reader = new BufferedReader(new InputStreamReader(input))
+ val stringBuilder = new StringBuilder
+ var line = reader.readLine()
+ while (line != null) {
+ stringBuilder.append(line)
+ stringBuilder.append("\n")
+ line = reader.readLine()
+ }
+ stringBuilder.toString()
+ }
+}
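[A usage sketch for the buffer above; since CircularBuffer is private[spark], such code would have to live under the org.apache.spark package, and the capacity and input string are illustrative (cf. the test added in UtilsSuite further down).

package org.apache.spark.util

object CircularBufferSketch {
  def main(args: Array[String]): Unit = {
    val buffer = new CircularBuffer(16)
    val out = new java.io.PrintStream(buffer, true, "UTF-8")
    out.println("0123456789abcdefghij") // 21 bytes including the newline
    println(buffer.toString)            // only the trailing 16 bytes survive
  }
}]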
diff --git a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala
index 376481ba541fa..25b79bce6ab98 100644
--- a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark
import java.io.File
+import javax.net.ssl.SSLContext
import com.google.common.io.Files
import org.apache.spark.util.Utils
@@ -29,6 +30,15 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll {
val keyStorePath = new File(this.getClass.getResource("/keystore").toURI).getAbsolutePath
val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath
+ // Pick two cipher suites that the provider knows about
+ val sslContext = SSLContext.getInstance("TLSv1.2")
+ sslContext.init(null, null, null)
+ val algorithms = sslContext
+ .getServerSocketFactory
+ .getDefaultCipherSuites
+ .take(2)
+ .toSet
+
val conf = new SparkConf
conf.set("spark.ssl.enabled", "true")
conf.set("spark.ssl.keyStore", keyStorePath)
@@ -36,9 +46,8 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll {
conf.set("spark.ssl.keyPassword", "password")
conf.set("spark.ssl.trustStore", trustStorePath)
conf.set("spark.ssl.trustStorePassword", "password")
- conf.set("spark.ssl.enabledAlgorithms",
- "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA")
- conf.set("spark.ssl.protocol", "SSLv3")
+ conf.set("spark.ssl.enabledAlgorithms", algorithms.mkString(","))
+ conf.set("spark.ssl.protocol", "TLSv1.2")
val opts = SSLOptions.parse(conf, "spark.ssl")
@@ -52,9 +61,8 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll {
assert(opts.trustStorePassword === Some("password"))
assert(opts.keyStorePassword === Some("password"))
assert(opts.keyPassword === Some("password"))
- assert(opts.protocol === Some("SSLv3"))
- assert(opts.enabledAlgorithms ===
- Set("TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_256_CBC_SHA"))
+ assert(opts.protocol === Some("TLSv1.2"))
+ assert(opts.enabledAlgorithms === algorithms)
}
test("test resolving property with defaults specified ") {
diff --git a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala
index 1a099da2c6c8e..33270bec6247c 100644
--- a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala
+++ b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala
@@ -25,6 +25,20 @@ object SSLSampleConfigs {
this.getClass.getResource("/untrusted-keystore").toURI).getAbsolutePath
val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath
+ val enabledAlgorithms =
+ // A reasonable set of TLSv1.2 Oracle security provider suites
+ "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384, " +
+ "TLS_RSA_WITH_AES_256_CBC_SHA256, " +
+ "TLS_DHE_RSA_WITH_AES_256_CBC_SHA256, " +
+ "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, " +
+ "TLS_DHE_RSA_WITH_AES_128_CBC_SHA256, " +
+ // and their equivalent names in the IBM Security provider
+ "SSL_ECDHE_RSA_WITH_AES_256_CBC_SHA384, " +
+ "SSL_RSA_WITH_AES_256_CBC_SHA256, " +
+ "SSL_DHE_RSA_WITH_AES_256_CBC_SHA256, " +
+ "SSL_ECDHE_RSA_WITH_AES_128_CBC_SHA256, " +
+ "SSL_DHE_RSA_WITH_AES_128_CBC_SHA256"
+
def sparkSSLConfig(): SparkConf = {
val conf = new SparkConf(loadDefaults = false)
conf.set("spark.ssl.enabled", "true")
@@ -33,9 +47,8 @@ object SSLSampleConfigs {
conf.set("spark.ssl.keyPassword", "password")
conf.set("spark.ssl.trustStore", trustStorePath)
conf.set("spark.ssl.trustStorePassword", "password")
- conf.set("spark.ssl.enabledAlgorithms",
- "SSL_RSA_WITH_RC4_128_SHA, SSL_RSA_WITH_DES_CBC_SHA")
- conf.set("spark.ssl.protocol", "TLSv1")
+ conf.set("spark.ssl.enabledAlgorithms", enabledAlgorithms)
+ conf.set("spark.ssl.protocol", "TLSv1.2")
conf
}
@@ -47,9 +60,8 @@ object SSLSampleConfigs {
conf.set("spark.ssl.keyPassword", "password")
conf.set("spark.ssl.trustStore", trustStorePath)
conf.set("spark.ssl.trustStorePassword", "password")
- conf.set("spark.ssl.enabledAlgorithms",
- "SSL_RSA_WITH_RC4_128_SHA, SSL_RSA_WITH_DES_CBC_SHA")
- conf.set("spark.ssl.protocol", "TLSv1")
+ conf.set("spark.ssl.enabledAlgorithms", enabledAlgorithms)
+ conf.set("spark.ssl.protocol", "TLSv1.2")
conf
}
diff --git a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala
index e9b64aa82a17a..f34aefca4eb18 100644
--- a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala
@@ -127,6 +127,17 @@ class SecurityManagerSuite extends SparkFunSuite {
test("ssl on setup") {
val conf = SSLSampleConfigs.sparkSSLConfig()
+ val expectedAlgorithms = Set(
+ "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384",
+ "TLS_RSA_WITH_AES_256_CBC_SHA256",
+ "TLS_DHE_RSA_WITH_AES_256_CBC_SHA256",
+ "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256",
+ "TLS_DHE_RSA_WITH_AES_128_CBC_SHA256",
+ "SSL_ECDHE_RSA_WITH_AES_256_CBC_SHA384",
+ "SSL_RSA_WITH_AES_256_CBC_SHA256",
+ "SSL_DHE_RSA_WITH_AES_256_CBC_SHA256",
+ "SSL_ECDHE_RSA_WITH_AES_128_CBC_SHA256",
+ "SSL_DHE_RSA_WITH_AES_128_CBC_SHA256")
val securityManager = new SecurityManager(conf)
@@ -143,9 +154,8 @@ class SecurityManagerSuite extends SparkFunSuite {
assert(securityManager.fileServerSSLOptions.trustStorePassword === Some("password"))
assert(securityManager.fileServerSSLOptions.keyStorePassword === Some("password"))
assert(securityManager.fileServerSSLOptions.keyPassword === Some("password"))
- assert(securityManager.fileServerSSLOptions.protocol === Some("TLSv1"))
- assert(securityManager.fileServerSSLOptions.enabledAlgorithms ===
- Set("SSL_RSA_WITH_RC4_128_SHA", "SSL_RSA_WITH_DES_CBC_SHA"))
+ assert(securityManager.fileServerSSLOptions.protocol === Some("TLSv1.2"))
+ assert(securityManager.fileServerSSLOptions.enabledAlgorithms === expectedAlgorithms)
assert(securityManager.akkaSSLOptions.trustStore.isDefined === true)
assert(securityManager.akkaSSLOptions.trustStore.get.getName === "truststore")
@@ -154,9 +164,8 @@ class SecurityManagerSuite extends SparkFunSuite {
assert(securityManager.akkaSSLOptions.trustStorePassword === Some("password"))
assert(securityManager.akkaSSLOptions.keyStorePassword === Some("password"))
assert(securityManager.akkaSSLOptions.keyPassword === Some("password"))
- assert(securityManager.akkaSSLOptions.protocol === Some("TLSv1"))
- assert(securityManager.akkaSSLOptions.enabledAlgorithms ===
- Set("SSL_RSA_WITH_RC4_128_SHA", "SSL_RSA_WITH_DES_CBC_SHA"))
+ assert(securityManager.akkaSSLOptions.protocol === Some("TLSv1.2"))
+ assert(securityManager.akkaSSLOptions.enabledAlgorithms === expectedAlgorithms)
}
test("ssl off setup") {
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
index 357ed90be3f5c..2e05dec99b6bf 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -548,6 +548,7 @@ object JarCreationTest extends Logging {
if (result.nonEmpty) {
throw new Exception("Could not load user class from jar:\n" + result(0))
}
+ sc.stop()
}
}
@@ -573,6 +574,7 @@ object SimpleApplicationTest {
s"Master had $config=$masterValue but executor had $config=$executorValue")
}
}
+ sc.stop()
}
}
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala
index 12c40f0b7d658..c9b435a9228d3 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala
@@ -77,9 +77,9 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll {
assert(resolver2.getResolvers.size() === 7)
val expected = repos.split(",").map(r => s"$r/")
resolver2.getResolvers.toArray.zipWithIndex.foreach { case (resolver: AbstractResolver, i) =>
- if (i > 3) {
- assert(resolver.getName === s"repo-${i - 3}")
- assert(resolver.asInstanceOf[IBiblioResolver].getRoot === expected(i - 4))
+ if (i < 3) {
+ assert(resolver.getName === s"repo-${i + 1}")
+ assert(resolver.asInstanceOf[IBiblioResolver].getRoot === expected(i))
}
}
}
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
index a61ea3918f46a..baa4c661cc21e 100644
--- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
@@ -673,4 +673,12 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging {
assert(!Utils.isInDirectory(nullFile, parentDir))
assert(!Utils.isInDirectory(nullFile, childFile3))
}
+
+ test("circular buffer") {
+ val buffer = new CircularBuffer(25)
+ val stream = new java.io.PrintStream(buffer, true, "UTF-8")
+
+ stream.println("test circular test circular test circular test circular test circular")
+ assert(buffer.toString === "t circular test circular\n")
+ }
}
diff --git a/dev/run-tests b/dev/run-tests
index a00d9f0c27639..257d1e8d50bb4 100755
--- a/dev/run-tests
+++ b/dev/run-tests
@@ -20,4 +20,4 @@
FWDIR="$(cd "`dirname $0`"/..; pwd)"
cd "$FWDIR"
-exec python -u ./dev/run-tests.py
+exec python -u ./dev/run-tests.py "$@"
diff --git a/dev/run-tests.py b/dev/run-tests.py
index e5c897b94d167..4596e07014733 100755
--- a/dev/run-tests.py
+++ b/dev/run-tests.py
@@ -19,6 +19,7 @@
from __future__ import print_function
import itertools
+from optparse import OptionParser
import os
import re
import sys
@@ -360,12 +361,13 @@ def run_scala_tests(build_tool, hadoop_version, test_modules):
run_scala_tests_sbt(test_modules, test_profiles)
-def run_python_tests(test_modules):
+def run_python_tests(test_modules, parallelism):
set_title_and_block("Running PySpark tests", "BLOCK_PYSPARK_UNIT_TESTS")
command = [os.path.join(SPARK_HOME, "python", "run-tests")]
if test_modules != [modules.root]:
command.append("--modules=%s" % ','.join(m.name for m in test_modules))
+ command.append("--parallelism=%i" % parallelism)
run_cmd(command)
@@ -379,7 +381,25 @@ def run_sparkr_tests():
print("Ignoring SparkR tests as R was not found in PATH")
+def parse_opts():
+ parser = OptionParser(
+ prog="run-tests"
+ )
+ parser.add_option(
+ "-p", "--parallelism", type="int", default=4,
+ help="The number of suites to test in parallel (default %default)"
+ )
+
+ (opts, args) = parser.parse_args()
+ if args:
+ parser.error("Unsupported arguments: %s" % ' '.join(args))
+ if opts.parallelism < 1:
+ parser.error("Parallelism cannot be less than 1")
+ return opts
+
+
def main():
+ opts = parse_opts()
# Ensure the user home directory (HOME) is valid and is an absolute directory
if not USER_HOME or not os.path.isabs(USER_HOME):
print("[error] Cannot determine your home directory as an absolute path;",
@@ -461,7 +481,7 @@ def main():
modules_with_python_tests = [m for m in test_modules if m.python_test_goals]
if modules_with_python_tests:
- run_python_tests(modules_with_python_tests)
+ run_python_tests(modules_with_python_tests, opts.parallelism)
if any(m.should_run_r_tests for m in test_modules):
run_sparkr_tests()
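[With the shell wrapper above now forwarding its arguments, an invocation like dev/run-tests --parallelism=8 would run up to eight Python test suites concurrently; the default remains 4.]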
diff --git a/dev/sparktestsupport/shellutils.py b/dev/sparktestsupport/shellutils.py
index ad9b0cc89e4ab..12bd0bf3a4fe9 100644
--- a/dev/sparktestsupport/shellutils.py
+++ b/dev/sparktestsupport/shellutils.py
@@ -15,6 +15,7 @@
# limitations under the License.
#
+from __future__ import print_function
import os
import shutil
import subprocess
diff --git a/examples/src/main/python/ml/logistic_regression.py b/examples/src/main/python/ml/logistic_regression.py
new file mode 100644
index 0000000000000..55afe1b207fe0
--- /dev/null
+++ b/examples/src/main/python/ml/logistic_regression.py
@@ -0,0 +1,67 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import print_function
+
+import sys
+
+from pyspark import SparkContext
+from pyspark.ml.classification import LogisticRegression
+from pyspark.mllib.evaluation import MulticlassMetrics
+from pyspark.ml.feature import StringIndexer
+from pyspark.mllib.util import MLUtils
+from pyspark.sql import SQLContext
+
+"""
+A simple example demonstrating a logistic regression with elastic net regularization Pipeline.
+Run with:
+ bin/spark-submit examples/src/main/python/ml/logistic_regression.py
+"""
+
+if __name__ == "__main__":
+
+ if len(sys.argv) > 1:
+ print("Usage: logistic_regression", file=sys.stderr)
+ exit(-1)
+
+ sc = SparkContext(appName="PythonLogisticRegressionExample")
+ sqlContext = SQLContext(sc)
+
+ # Load and parse the data file into a dataframe.
+ df = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF()
+
+ # Map labels into an indexed column of labels in [0, numLabels)
+ stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel")
+ si_model = stringIndexer.fit(df)
+ td = si_model.transform(df)
+ [training, test] = td.randomSplit([0.7, 0.3])
+
+ lr = LogisticRegression(maxIter=100, regParam=0.3).setLabelCol("indexedLabel")
+ lr.setElasticNetParam(0.8)
+
+ # Fit the model
+ lrModel = lr.fit(training)
+
+ predictionAndLabels = lrModel.transform(test).select("prediction", "indexedLabel") \
+ .map(lambda x: (x.prediction, x.indexedLabel))
+
+ metrics = MulticlassMetrics(predictionAndLabels)
+ print("weighted f-measure %.3f" % metrics.weightedFMeasure())
+ print("precision %s" % metrics.precision())
+ print("recall %s" % metrics.recall())
+
+ sc.stop()
diff --git a/launcher/pom.xml b/launcher/pom.xml
index a853e67f5cf78..2fd768d8119c4 100644
--- a/launcher/pom.xml
+++ b/launcher/pom.xml
@@ -68,12 +68,6 @@
<dependency>
  <groupId>org.apache.hadoop</groupId>
  <artifactId>hadoop-client</artifactId>
  <scope>test</scope>
- <exclusions>
-   <exclusion>
-     <groupId>org.codehaus.jackson</groupId>
-     <artifactId>jackson-mapper-asl</artifactId>
-   </exclusion>
- </exclusions>
</dependency>
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 5a6265ea992c6..bc6eeac1db5da 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -36,19 +36,19 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
dataset = sqlContext.createDataFrame(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42))
- /**
- * Here is the instruction describing how to export the test data into CSV format
- * so we can validate the training accuracy compared with R's glmnet package.
- *
- * import org.apache.spark.mllib.classification.LogisticRegressionSuite
- * val nPoints = 10000
- * val weights = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191)
- * val xMean = Array(5.843, 3.057, 3.758, 1.199)
- * val xVariance = Array(0.6856, 0.1899, 3.116, 0.581)
- * val data = sc.parallelize(LogisticRegressionSuite.generateMultinomialLogisticInput(
- * weights, xMean, xVariance, true, nPoints, 42), 1)
- * data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1) + ", "
- * + x.features(2) + ", " + x.features(3)).saveAsTextFile("path")
+ /*
+ Here is the instruction describing how to export the test data into CSV format
+ so we can validate the training accuracy compared with R's glmnet package.
+
+ import org.apache.spark.mllib.classification.LogisticRegressionSuite
+ val nPoints = 10000
+ val weights = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191)
+ val xMean = Array(5.843, 3.057, 3.758, 1.199)
+ val xVariance = Array(0.6856, 0.1899, 3.116, 0.581)
+ val data = sc.parallelize(LogisticRegressionSuite.generateMultinomialLogisticInput(
+ weights, xMean, xVariance, true, nPoints, 42), 1)
+ data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1) + ", "
+ + x.features(2) + ", " + x.features(3)).saveAsTextFile("path")
*/
binaryDataset = {
val nPoints = 10000
@@ -211,22 +211,22 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
val trainer = (new LogisticRegression).setFitIntercept(true)
val model = trainer.fit(binaryDataset)
- /**
- * Using the following R code to load the data and train the model using glmnet package.
- *
- * > library("glmnet")
- * > data <- read.csv("path", header=FALSE)
- * > label = factor(data$V1)
- * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
- * > weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0))
- * > weights
- * 5 x 1 sparse Matrix of class "dgCMatrix"
- * s0
- * (Intercept) 2.8366423
- * data.V2 -0.5895848
- * data.V3 0.8931147
- * data.V4 -0.3925051
- * data.V5 -0.7996864
+ /*
+ Using the following R code to load the data and train the model using glmnet package.
+
+ > library("glmnet")
+ > data <- read.csv("path", header=FALSE)
+ > label = factor(data$V1)
+ > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
+ > weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0))
+ > weights
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) 2.8366423
+ data.V2 -0.5895848
+ data.V3 0.8931147
+ data.V4 -0.3925051
+ data.V5 -0.7996864
*/
val interceptR = 2.8366423
val weightsR = Array(-0.5895848, 0.8931147, -0.3925051, -0.7996864)
@@ -242,23 +242,23 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
val trainer = (new LogisticRegression).setFitIntercept(false)
val model = trainer.fit(binaryDataset)
- /**
- * Using the following R code to load the data and train the model using glmnet package.
- *
- * > library("glmnet")
- * > data <- read.csv("path", header=FALSE)
- * > label = factor(data$V1)
- * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
- * > weights =
- * coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0, intercept=FALSE))
- * > weights
- * 5 x 1 sparse Matrix of class "dgCMatrix"
- * s0
- * (Intercept) .
- * data.V2 -0.3534996
- * data.V3 1.2964482
- * data.V4 -0.3571741
- * data.V5 -0.7407946
+ /*
+ Using the following R code to load the data and train the model using glmnet package.
+
+ > library("glmnet")
+ > data <- read.csv("path", header=FALSE)
+ > label = factor(data$V1)
+ > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
+ > weights =
+ coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0, intercept=FALSE))
+ > weights
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) .
+ data.V2 -0.3534996
+ data.V3 1.2964482
+ data.V4 -0.3571741
+ data.V5 -0.7407946
*/
val interceptR = 0.0
val weightsR = Array(-0.3534996, 1.2964482, -0.3571741, -0.7407946)
@@ -275,22 +275,22 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
.setElasticNetParam(1.0).setRegParam(0.12)
val model = trainer.fit(binaryDataset)
- /**
- * Using the following R code to load the data and train the model using glmnet package.
- *
- * > library("glmnet")
- * > data <- read.csv("path", header=FALSE)
- * > label = factor(data$V1)
- * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
- * > weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12))
- * > weights
- * 5 x 1 sparse Matrix of class "dgCMatrix"
- * s0
- * (Intercept) -0.05627428
- * data.V2 .
- * data.V3 .
- * data.V4 -0.04325749
- * data.V5 -0.02481551
+ /*
+ Using the following R code to load the data and train the model using glmnet package.
+
+ > library("glmnet")
+ > data <- read.csv("path", header=FALSE)
+ > label = factor(data$V1)
+ > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
+ > weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12))
+ > weights
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) -0.05627428
+ data.V2 .
+ data.V3 .
+ data.V4 -0.04325749
+ data.V5 -0.02481551
*/
val interceptR = -0.05627428
val weightsR = Array(0.0, 0.0, -0.04325749, -0.02481551)
@@ -307,23 +307,23 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
.setElasticNetParam(1.0).setRegParam(0.12)
val model = trainer.fit(binaryDataset)
- /**
- * Using the following R code to load the data and train the model using glmnet package.
- *
- * > library("glmnet")
- * > data <- read.csv("path", header=FALSE)
- * > label = factor(data$V1)
- * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
- * > weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12,
- * intercept=FALSE))
- * > weights
- * 5 x 1 sparse Matrix of class "dgCMatrix"
- * s0
- * (Intercept) .
- * data.V2 .
- * data.V3 .
- * data.V4 -0.05189203
- * data.V5 -0.03891782
+ /*
+ Using the following R code to load the data and train the model using glmnet package.
+
+ > library("glmnet")
+ > data <- read.csv("path", header=FALSE)
+ > label = factor(data$V1)
+ > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
+ > weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12,
+ intercept=FALSE))
+ > weights
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) .
+ data.V2 .
+ data.V3 .
+ data.V4 -0.05189203
+ data.V5 -0.03891782
*/
val interceptR = 0.0
val weightsR = Array(0.0, 0.0, -0.05189203, -0.03891782)
@@ -340,22 +340,22 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
.setElasticNetParam(0.0).setRegParam(1.37)
val model = trainer.fit(binaryDataset)
- /**
- * Using the following R code to load the data and train the model using glmnet package.
- *
- * > library("glmnet")
- * > data <- read.csv("path", header=FALSE)
- * > label = factor(data$V1)
- * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
- * > weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37))
- * > weights
- * 5 x 1 sparse Matrix of class "dgCMatrix"
- * s0
- * (Intercept) 0.15021751
- * data.V2 -0.07251837
- * data.V3 0.10724191
- * data.V4 -0.04865309
- * data.V5 -0.10062872
+ /*
+ Using the following R code to load the data and train the model using glmnet package.
+
+ > library("glmnet")
+ > data <- read.csv("path", header=FALSE)
+ > label = factor(data$V1)
+ > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
+ > weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37))
+ > weights
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) 0.15021751
+ data.V2 -0.07251837
+ data.V3 0.10724191
+ data.V4 -0.04865309
+ data.V5 -0.10062872
*/
val interceptR = 0.15021751
val weightsR = Array(-0.07251837, 0.10724191, -0.04865309, -0.10062872)
@@ -372,23 +372,23 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
.setElasticNetParam(0.0).setRegParam(1.37)
val model = trainer.fit(binaryDataset)
- /**
- * Using the following R code to load the data and train the model using glmnet package.
- *
- * > library("glmnet")
- * > data <- read.csv("path", header=FALSE)
- * > label = factor(data$V1)
- * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
- * > weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37,
- * intercept=FALSE))
- * > weights
- * 5 x 1 sparse Matrix of class "dgCMatrix"
- * s0
- * (Intercept) .
- * data.V2 -0.06099165
- * data.V3 0.12857058
- * data.V4 -0.04708770
- * data.V5 -0.09799775
+ /*
+ Using the following R code to load the data and train the model using glmnet package.
+
+ > library("glmnet")
+ > data <- read.csv("path", header=FALSE)
+ > label = factor(data$V1)
+ > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
+ > weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37,
+ intercept=FALSE))
+ > weights
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) .
+ data.V2 -0.06099165
+ data.V3 0.12857058
+ data.V4 -0.04708770
+ data.V5 -0.09799775
*/
val interceptR = 0.0
val weightsR = Array(-0.06099165, 0.12857058, -0.04708770, -0.09799775)
@@ -405,22 +405,22 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
.setElasticNetParam(0.38).setRegParam(0.21)
val model = trainer.fit(binaryDataset)
- /**
- * Using the following R code to load the data and train the model using glmnet package.
- *
- * > library("glmnet")
- * > data <- read.csv("path", header=FALSE)
- * > label = factor(data$V1)
- * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
- * > weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21))
- * > weights
- * 5 x 1 sparse Matrix of class "dgCMatrix"
- * s0
- * (Intercept) 0.57734851
- * data.V2 -0.05310287
- * data.V3 .
- * data.V4 -0.08849250
- * data.V5 -0.15458796
+ /*
+ Using the following R code to load the data and train the model using glmnet package.
+
+ > library("glmnet")
+ > data <- read.csv("path", header=FALSE)
+ > label = factor(data$V1)
+ > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
+ > weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21))
+ > weights
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) 0.57734851
+ data.V2 -0.05310287
+ data.V3 .
+ data.V4 -0.08849250
+ data.V5 -0.15458796
*/
val interceptR = 0.57734851
val weightsR = Array(-0.05310287, 0.0, -0.08849250, -0.15458796)
@@ -437,23 +437,23 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
.setElasticNetParam(0.38).setRegParam(0.21)
val model = trainer.fit(binaryDataset)
- /**
- * Using the following R code to load the data and train the model using glmnet package.
- *
- * > library("glmnet")
- * > data <- read.csv("path", header=FALSE)
- * > label = factor(data$V1)
- * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
- * > weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21,
- * intercept=FALSE))
- * > weights
- * 5 x 1 sparse Matrix of class "dgCMatrix"
- * s0
- * (Intercept) .
- * data.V2 -0.001005743
- * data.V3 0.072577857
- * data.V4 -0.081203769
- * data.V5 -0.142534158
+ /*
+ Using the following R code to load the data and train the model using glmnet package.
+
+ > library("glmnet")
+ > data <- read.csv("path", header=FALSE)
+ > label = factor(data$V1)
+ > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
+ > weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21,
+ intercept=FALSE))
+ > weights
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) .
+ data.V2 -0.001005743
+ data.V3 0.072577857
+ data.V4 -0.081203769
+ data.V5 -0.142534158
*/
val interceptR = 0.0
val weightsR = Array(-0.001005743, 0.072577857, -0.081203769, -0.142534158)
@@ -480,16 +480,16 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
classSummarizer1.merge(classSummarizer2)
}).histogram
- /**
- * For binary logistic regression with strong L1 regularization, all the weights will be zeros.
- * As a result,
- * {{{
- * P(0) = 1 / (1 + \exp(b)), and
- * P(1) = \exp(b) / (1 + \exp(b))
- * }}}, hence
- * {{{
- * b = \log{P(1) / P(0)} = \log{count_1 / count_0}
- * }}}
+ /*
+ For binary logistic regression with strong L1 regularization, all the weights will be zeros.
+ As a result,
+ {{{
+ P(0) = 1 / (1 + \exp(b)), and
+ P(1) = \exp(b) / (1 + \exp(b))
+ }}}, hence
+ {{{
+ b = \log{P(1) / P(0)} = \log{count_1 / count_0}
+ }}}
*/
val interceptTheory = math.log(histogram(1).toDouble / histogram(0).toDouble)
val weightsTheory = Array(0.0, 0.0, 0.0, 0.0)
@@ -500,22 +500,22 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(model.weights(2) ~== weightsTheory(2) absTol 1E-6)
assert(model.weights(3) ~== weightsTheory(3) absTol 1E-6)
- /**
- * Using the following R code to load the data and train the model using glmnet package.
- *
- * > library("glmnet")
- * > data <- read.csv("path", header=FALSE)
- * > label = factor(data$V1)
- * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
- * > weights = coef(glmnet(features,label, family="binomial", alpha = 1.0, lambda = 6.0))
- * > weights
- * 5 x 1 sparse Matrix of class "dgCMatrix"
- * s0
- * (Intercept) -0.2480643
- * data.V2 0.0000000
- * data.V3 .
- * data.V4 .
- * data.V5 .
+ /*
+ Using the following R code to load the data and train the model using glmnet package.
+
+ > library("glmnet")
+ > data <- read.csv("path", header=FALSE)
+ > label = factor(data$V1)
+ > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
+ > weights = coef(glmnet(features,label, family="binomial", alpha = 1.0, lambda = 6.0))
+ > weights
+ 5 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) -0.2480643
+ data.V2 0.0000000
+ data.V3 .
+ data.V4 .
+ data.V5 .
*/
val interceptR = -0.248065
val weightsR = Array(0.0, 0.0, 0.0, 0.0)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index ad1e9da692ee2..5f39d44f37352 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -28,26 +28,26 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
@transient var dataset: DataFrame = _
@transient var datasetWithoutIntercept: DataFrame = _
- /**
- * In `LinearRegressionSuite`, we will make sure that the model trained by SparkML
- * is the same as the one trained by R's glmnet package. The following instruction
- * describes how to reproduce the data in R.
- *
- * import org.apache.spark.mllib.util.LinearDataGenerator
- * val data =
- * sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, Array(4.7, 7.2),
- * Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2)
- * data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)).coalesce(1)
- * .saveAsTextFile("path")
+ /*
+ In `LinearRegressionSuite`, we will make sure that the model trained by SparkML
+ is the same as the one trained by R's glmnet package. The following instruction
+ describes how to reproduce the data in R.
+
+ import org.apache.spark.mllib.util.LinearDataGenerator
+ val data =
+ sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, Array(4.7, 7.2),
+ Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2)
+ data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)).coalesce(1)
+ .saveAsTextFile("path")
*/
override def beforeAll(): Unit = {
super.beforeAll()
dataset = sqlContext.createDataFrame(
sc.parallelize(LinearDataGenerator.generateLinearInput(
6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2))
- /**
- * datasetWithoutIntercept is not needed for correctness testing but is useful for illustrating
- * training model without intercept
+ /*
+ datasetWithoutIntercept is not needed for correctness testing but is useful for illustrating
+ training model without intercept
*/
datasetWithoutIntercept = sqlContext.createDataFrame(
sc.parallelize(LinearDataGenerator.generateLinearInput(
@@ -59,20 +59,20 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
val trainer = new LinearRegression
val model = trainer.fit(dataset)
- /**
- * Using the following R code to load the data and train the model using glmnet package.
- *
- * library("glmnet")
- * data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE)
- * features <- as.matrix(data.frame(as.numeric(data$V2), as.numeric(data$V3)))
- * label <- as.numeric(data$V1)
- * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0))
- * > weights
- * 3 x 1 sparse Matrix of class "dgCMatrix"
- * s0
- * (Intercept) 6.300528
- * as.numeric.data.V2. 4.701024
- * as.numeric.data.V3. 7.198257
+ /*
+ Using the following R code to load the data and train the model using glmnet package.
+
+ library("glmnet")
+ data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE)
+ features <- as.matrix(data.frame(as.numeric(data$V2), as.numeric(data$V3)))
+ label <- as.numeric(data$V1)
+ weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0))
+ > weights
+ 3 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) 6.300528
+ as.numeric.data.V2. 4.701024
+ as.numeric.data.V3. 7.198257
*/
val interceptR = 6.298698
val weightsR = Array(4.700706, 7.199082)
@@ -94,29 +94,29 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
val model = trainer.fit(dataset)
val modelWithoutIntercept = trainer.fit(datasetWithoutIntercept)
- /**
- * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0,
- * intercept = FALSE))
- * > weights
- * 3 x 1 sparse Matrix of class "dgCMatrix"
- * s0
- * (Intercept) .
- * as.numeric.data.V2. 6.995908
- * as.numeric.data.V3. 5.275131
+ /*
+ weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0,
+ intercept = FALSE))
+ > weights
+ 3 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) .
+ as.numeric.data.V2. 6.995908
+ as.numeric.data.V3. 5.275131
*/
val weightsR = Array(6.995908, 5.275131)
assert(model.intercept ~== 0 relTol 1E-3)
assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
- /**
- * Then again with the data with no intercept:
- * > weightsWithoutIntercept
- * 3 x 1 sparse Matrix of class "dgCMatrix"
- * s0
- * (Intercept) .
- * as.numeric.data3.V2. 4.70011
- * as.numeric.data3.V3. 7.19943
+ /*
+ Then again with the data with no intercept:
+ > weightsWithoutIntercept
+ 3 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) .
+ as.numeric.data3.V2. 4.70011
+ as.numeric.data3.V3. 7.19943
*/
val weightsWithoutInterceptR = Array(4.70011, 7.19943)
@@ -129,14 +129,14 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
val trainer = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57)
val model = trainer.fit(dataset)
- /**
- * weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57))
- * > weights
- * 3 x 1 sparse Matrix of class "dgCMatrix"
- * s0
- * (Intercept) 6.24300
- * as.numeric.data.V2. 4.024821
- * as.numeric.data.V3. 6.679841
+ /*
+ weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57))
+ > weights
+ 3 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) 6.24300
+ as.numeric.data.V2. 4.024821
+ as.numeric.data.V3. 6.679841
*/
val interceptR = 6.24300
val weightsR = Array(4.024821, 6.679841)
@@ -158,15 +158,15 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
.setFitIntercept(false)
val model = trainer.fit(dataset)
- /**
- * weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57,
- * intercept=FALSE))
- * > weights
- * 3 x 1 sparse Matrix of class "dgCMatrix"
- * s0
- * (Intercept) .
- * as.numeric.data.V2. 6.299752
- * as.numeric.data.V3. 4.772913
+ /*
+ weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57,
+ intercept=FALSE))
+ > weights
+ 3 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) .
+ as.numeric.data.V2. 6.299752
+ as.numeric.data.V3. 4.772913
*/
val interceptR = 0.0
val weightsR = Array(6.299752, 4.772913)
@@ -187,14 +187,14 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
val trainer = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3)
val model = trainer.fit(dataset)
- /**
- * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3))
- * > weights
- * 3 x 1 sparse Matrix of class "dgCMatrix"
- * s0
- * (Intercept) 6.328062
- * as.numeric.data.V2. 3.222034
- * as.numeric.data.V3. 4.926260
+ /*
+ weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3))
+ > weights
+ 3 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) 6.328062
+ as.numeric.data.V2. 3.222034
+ as.numeric.data.V3. 4.926260
*/
val interceptR = 5.269376
val weightsR = Array(3.736216, 5.712356)
@@ -216,15 +216,15 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
.setFitIntercept(false)
val model = trainer.fit(dataset)
- /**
- * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3,
- * intercept = FALSE))
- * > weights
- * 3 x 1 sparse Matrix of class "dgCMatrix"
- * s0
- * (Intercept) .
- * as.numeric.data.V2. 5.522875
- * as.numeric.data.V3. 4.214502
+ /*
+ weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3,
+ intercept = FALSE))
+ > weights
+ 3 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) .
+ as.numeric.data.V2. 5.522875
+ as.numeric.data.V3. 4.214502
*/
val interceptR = 0.0
val weightsR = Array(5.522875, 4.214502)
@@ -245,14 +245,14 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
val trainer = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6)
val model = trainer.fit(dataset)
- /**
- * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6))
- * > weights
- * 3 x 1 sparse Matrix of class "dgCMatrix"
- * s0
- * (Intercept) 6.324108
- * as.numeric.data.V2. 3.168435
- * as.numeric.data.V3. 5.200403
+ /*
+ weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6))
+ > weights
+ 3 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) 6.324108
+ as.numeric.data.V2. 3.168435
+ as.numeric.data.V3. 5.200403
*/
val interceptR = 5.696056
val weightsR = Array(3.670489, 6.001122)
@@ -274,15 +274,15 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
.setFitIntercept(false)
val model = trainer.fit(dataset)
- /**
- * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6,
- * intercept=FALSE))
- * > weights
- * 3 x 1 sparse Matrix of class "dgCMatrix"
- * s0
- * (Intercept) .
- * as.numeric.dataM.V2. 5.673348
- * as.numeric.dataM.V3. 4.322251
+ /*
+ weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6,
+ intercept=FALSE))
+ > weights
+ 3 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) .
+ as.numeric.dataM.V2. 5.673348
+ as.numeric.dataM.V3. 4.322251
*/
val interceptR = 0.0
val weightsR = Array(5.673348, 4.322251)
diff --git a/pom.xml b/pom.xml
index 4c18bd5e42c87..94dd512cfb618 100644
--- a/pom.xml
+++ b/pom.xml
@@ -747,6 +747,10 @@
  <exclusion>
    <groupId>asm</groupId>
    <artifactId>asm</artifactId>
  </exclusion>
+ <exclusion>
+   <groupId>org.codehaus.jackson</groupId>
+   <artifactId>jackson-mapper-asl</artifactId>
+ </exclusion>
  <exclusion>
    <groupId>org.ow2.asm</groupId>
    <artifactId>asm</artifactId>
  </exclusion>
@@ -759,6 +763,10 @@
  <exclusion>
    <groupId>commons-logging</groupId>
    <artifactId>commons-logging</artifactId>
  </exclusion>
+ <exclusion>
+   <groupId>org.mockito</groupId>
+   <artifactId>mockito-all</artifactId>
+ </exclusion>
  <exclusion>
    <groupId>org.mortbay.jetty</groupId>
    <artifactId>servlet-api-2.5</artifactId>
  </exclusion>
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 90b2fffbb9c7c..d7466729b8f36 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -291,6 +291,21 @@ def version(self):
"""
return self._jsc.version()
+ @property
+ @ignore_unicode_prefix
+ def applicationId(self):
+ """
+ A unique identifier for the Spark application.
+ Its format depends on the scheduler implementation.
+ (e.g. something like 'local-1433865536131' for a local app,
+ or 'application_1433865536131_34483' when running on YARN)
+ >>> sc.applicationId # doctest: +ELLIPSIS
+ u'local-...'
+ """
+ return self._jsc.sc().applicationId()
+
@property
def startTime(self):
"""Return the epoch time when the Spark Context was started."""
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index 3cee4ea6e3a35..90cd342a6cf7f 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -51,6 +51,8 @@ def launch_gateway():
on_windows = platform.system() == "Windows"
script = "./bin/spark-submit.cmd" if on_windows else "./bin/spark-submit"
submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
+ if os.environ.get("SPARK_TESTING"):
+ submit_args = "--conf spark.ui.enabled=false " + submit_args
command = [os.path.join(SPARK_HOME, script)] + shlex.split(submit_args)
# Start a socket that will be used by PythonGatewayServer to communicate its port to us
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index ddb33f427ac64..8804dace849b3 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -21,7 +21,7 @@
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer
from pyspark.mllib.common import inherit_doc
-__all__ = ['Binarizer', 'HashingTF', 'IDF', 'IDFModel', 'Normalizer', 'OneHotEncoder',
+__all__ = ['Binarizer', 'HashingTF', 'IDF', 'IDFModel', 'NGram', 'Normalizer', 'OneHotEncoder',
'PolynomialExpansion', 'RegexTokenizer', 'StandardScaler', 'StandardScalerModel',
'StringIndexer', 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer',
'Word2Vec', 'Word2VecModel']
@@ -265,6 +265,75 @@ class IDFModel(JavaModel):
"""
+@inherit_doc
+@ignore_unicode_prefix
+class NGram(JavaTransformer, HasInputCol, HasOutputCol):
+ """
+ A feature transformer that converts the input array of strings into an array of n-grams. Null
+ values in the input array are ignored.
+ It returns an array of n-grams where each n-gram is represented by a space-separated string of
+ words.
+ When the input is empty, an empty array is returned.
+ When the input array length is less than n (number of elements per n-gram), no n-grams are
+ returned.
+
+ >>> df = sqlContext.createDataFrame([Row(inputTokens=["a", "b", "c", "d", "e"])])
+ >>> ngram = NGram(n=2, inputCol="inputTokens", outputCol="nGrams")
+ >>> ngram.transform(df).head()
+ Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], nGrams=[u'a b', u'b c', u'c d', u'd e'])
+ >>> # Change n-gram length
+ >>> ngram.setParams(n=4).transform(df).head()
+ Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], nGrams=[u'a b c d', u'b c d e'])
+ >>> # Temporarily modify output column.
+ >>> ngram.transform(df, {ngram.outputCol: "output"}).head()
+ Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], output=[u'a b c d', u'b c d e'])
+ >>> ngram.transform(df).head()
+ Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], nGrams=[u'a b c d', u'b c d e'])
+ >>> # Must use keyword arguments to specify params.
+ >>> ngram.setParams("text")
+ Traceback (most recent call last):
+ ...
+ TypeError: Method setParams forces keyword arguments.
+ """
+
+ # a placeholder to make it appear in the generated doc
+ n = Param(Params._dummy(), "n", "number of elements per n-gram (>=1)")
+
+ @keyword_only
+ def __init__(self, n=2, inputCol=None, outputCol=None):
+ """
+ __init__(self, n=2, inputCol=None, outputCol=None)
+ """
+ super(NGram, self).__init__()
+ self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.NGram", self.uid)
+ self.n = Param(self, "n", "number of elements per n-gram (>=1)")
+ self._setDefault(n=2)
+ kwargs = self.__init__._input_kwargs
+ self.setParams(**kwargs)
+
+ @keyword_only
+ def setParams(self, n=2, inputCol=None, outputCol=None):
+ """
+ setParams(self, n=2, inputCol=None, outputCol=None)
+ Sets params for this NGram.
+ """
+ kwargs = self.setParams._input_kwargs
+ return self._set(**kwargs)
+
+ def setN(self, value):
+ """
+ Sets the value of :py:attr:`n`.
+ """
+ self._paramMap[self.n] = value
+ return self
+
+ def getN(self):
+ """
+ Gets the value of n or its default value.
+ """
+ return self.getOrDefault(self.n)
+
+
@inherit_doc
class Normalizer(JavaTransformer, HasInputCol, HasOutputCol):
"""
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 6adbf166f34a8..c151d21fd661a 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -252,6 +252,17 @@ def test_idf(self):
output = idf0m.transform(dataset)
self.assertIsNotNone(output.head().idf)
+ def test_ngram(self):
+ sqlContext = SQLContext(self.sc)
+ dataset = sqlContext.createDataFrame([
+ ([["a", "b", "c", "d", "e"]])], ["input"])
+ ngram0 = NGram(n=4, inputCol="input", outputCol="output")
+ self.assertEqual(ngram0.getN(), 4)
+ self.assertEqual(ngram0.getInputCol(), "input")
+ self.assertEqual(ngram0.getOutputCol(), "output")
+ transformedDF = ngram0.transform(dataset)
+ self.assertEquals(transformedDF.head().output, ["a b c d", "b c d e"])
+
if __name__ == "__main__":
unittest.main()
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py
index f00bb93b7bf40..b5138773fd61b 100644
--- a/python/pyspark/mllib/feature.py
+++ b/python/pyspark/mllib/feature.py
@@ -111,6 +111,15 @@ class JavaVectorTransformer(JavaModelWrapper, VectorTransformer):
"""
def transform(self, vector):
+ """
+ Applies transformation on a vector or an RDD[Vector].
+
+ Note: In Python, transform cannot currently be used within
+ an RDD transformation or action.
+ Call transform directly on the RDD instead.
+
+ :param vector: Vector or RDD of Vector to be transformed.
+ """
if isinstance(vector, RDD):
vector = vector.map(_convert_to_vector)
else:
@@ -191,7 +200,7 @@ def fit(self, dataset):
Computes the mean and variance and stores as a model to be used
for later scaling.
- :param data: The data used to compute the mean and variance
+ :param dataset: The data used to compute the mean and variance
to build the transformation model.
:return: a StandardScalarModel
"""
@@ -346,10 +355,6 @@ def transform(self, x):
vector
:return: an RDD of TF-IDF vectors or a TF-IDF vector
"""
- if isinstance(x, RDD):
- return JavaVectorTransformer.transform(self, x)
-
- x = _convert_to_vector(x)
return JavaVectorTransformer.transform(self, x)
def idf(self):
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 1b64be23a667e..cb20bc8b54027 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -121,10 +121,22 @@ def _parse_memory(s):
def _load_from_socket(port, serializer):
- sock = socket.socket()
- sock.settimeout(3)
+ sock = None
+ # Support for both IPv4 and IPv6.
+ # On most IPv6-ready systems, IPv6 will take precedence.
+ for res in socket.getaddrinfo("localhost", port, socket.AF_UNSPEC, socket.SOCK_STREAM):
+ af, socktype, proto, canonname, sa = res
+ try:
+ sock = socket.socket(af, socktype, proto)
+ sock.settimeout(3)
+ sock.connect(sa)
+ except socket.error:
+ sock = None
+ continue
+ break
+ if not sock:
+ raise Exception("could not open socket")
try:
- sock.connect(("localhost", port))
rf = sock.makefile("rb", 65536)
for item in serializer.load_stream(rf):
yield item
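The rewritten `_load_from_socket` follows the standard `getaddrinfo` loop for dual-stack clients: try each resolved address family in order and keep the first socket that connects. A self-contained sketch of the same pattern (assumes something is already listening on `port`):

    import socket

    def connect_localhost(port, timeout=3):
        for af, socktype, proto, _, sa in socket.getaddrinfo(
                "localhost", port, socket.AF_UNSPEC, socket.SOCK_STREAM):
            try:
                sock = socket.socket(af, socktype, proto)
            except socket.error:
                continue
            try:
                sock.settimeout(timeout)
                sock.connect(sa)
                return sock  # first family that connects wins (often IPv6)
            except socket.error:
                sock.close()
                continue
        raise Exception("could not open socket")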
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index dc239226e6d3c..4dda3b430cfbf 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -203,7 +203,37 @@ def registerFunction(self, name, f, returnType=StringType()):
self._sc._javaAccumulator,
returnType.json())
+ def _inferSchemaFromList(self, data):
+ """
+ Infer schema from list of Row or tuple.
+
+ :param data: list of Row or tuple
+ :return: StructType
+ """
+ if not data:
+ raise ValueError("can not infer schema from empty dataset")
+ first = data[0]
+ if type(first) is dict:
+ warnings.warn("inferring schema from dict is deprecated, "
+ "please use pyspark.sql.Row instead")
+ schema = _infer_schema(first)
+ if _has_nulltype(schema):
+ for r in data:
+ schema = _merge_type(schema, _infer_schema(r))
+ if not _has_nulltype(schema):
+ break
+ else:
+ raise ValueError("Some of types cannot be determined after inferring")
+ return schema
+
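`_inferSchemaFromList` infers a schema from the first element and then keeps merging in later rows until no field is stuck at NullType. That merging is what lets the first row carry missing values, as in this sketch (assumes a live `sqlContext`):

    from pyspark.sql import Row

    rows = [Row(name="Alice", age=None),  # age alone would infer as NullType
            Row(name="Bob", age=5)]       # merging this row resolves age to a long
    df = sqlContext.createDataFrame(rows)  # succeeds; schema: name string, age bigint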
def _inferSchema(self, rdd, samplingRatio=None):
+ """
+ Infer schema from an RDD of Row or tuple.
+
+ :param rdd: an RDD of Row or tuple
+ :param samplingRatio: sampling ratio, or no sampling (default)
+ :return: StructType
+ """
first = rdd.first()
if not first:
raise ValueError("The first row in RDD is empty, "
@@ -322,6 +352,8 @@ def createDataFrame(self, data, schema=None, samplingRatio=None):
data = [r.tolist() for r in data.to_records(index=False)]
if not isinstance(data, RDD):
+ if not isinstance(data, list):
+ data = list(data)
try:
# data could be list, tuple, generator ...
rdd = self._sc.parallelize(data)
@@ -330,28 +362,26 @@ def createDataFrame(self, data, schema=None, samplingRatio=None):
else:
rdd = data
- if schema is None:
- schema = self._inferSchema(rdd, samplingRatio)
+ if schema is None or isinstance(schema, (list, tuple)):
+ if isinstance(data, RDD):
+ struct = self._inferSchema(rdd, samplingRatio)
+ else:
+ struct = self._inferSchemaFromList(data)
+ if isinstance(schema, (list, tuple)):
+ for i, name in enumerate(schema):
+ struct.fields[i].name = name
+ schema = struct
converter = _create_converter(schema)
rdd = rdd.map(converter)
- if isinstance(schema, (list, tuple)):
- first = rdd.first()
- if not isinstance(first, (list, tuple)):
- raise TypeError("each row in `rdd` should be list or tuple, "
- "but got %r" % type(first))
- row_cls = Row(*schema)
- schema = self._inferSchema(rdd.map(lambda r: row_cls(*r)), samplingRatio)
-
- # take the first few rows to verify schema
- rows = rdd.take(10)
- # Row() cannot be deserialized by Pyrolite
- if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row':
- rdd = rdd.map(tuple)
+ elif isinstance(schema, StructType):
+ # take the first few rows to verify schema
rows = rdd.take(10)
+ for row in rows:
+ _verify_type(row, schema)
- for row in rows:
- _verify_type(row, schema)
+ else:
+ raise TypeError("schema should be StructType or list or None")
# convert python objects to sql data
converter = _python_to_sql_converter(schema)
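After this rewrite, `createDataFrame` accepts three shapes of `schema`: None (infer everything), a list or tuple of column names (infer types, rename fields), or a full StructType (rows are verified against it). A usage sketch, again assuming a live `sqlContext`:

    from pyspark.sql.types import StructType, StructField, StringType, LongType

    data = [("Alice", 1), ("Bob", 2)]
    df1 = sqlContext.createDataFrame(data)                   # schema=None: fully inferred
    df2 = sqlContext.createDataFrame(data, ["name", "age"])  # names given, types inferred
    schema = StructType([StructField("name", StringType(), True),
                         StructField("age", LongType(), True)])
    df3 = sqlContext.createDataFrame(data, schema)           # explicit; rows are verified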
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 152b87351db31..4b9efa0a210fb 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -247,9 +247,12 @@ def isLocal(self):
return self._jdf.isLocal()
@since(1.3)
- def show(self, n=20):
+ def show(self, n=20, truncate=True):
"""Prints the first ``n`` rows to the console.
+ :param n: Number of rows to show.
+ :param truncate: Whether to truncate long strings and align cells right.
+
>>> df
DataFrame[age: int, name: string]
>>> df.show()
@@ -260,7 +263,7 @@ def show(self, n=20):
| 5| Bob|
+---+-----+
"""
- print(self._jdf.showString(n))
+ print(self._jdf.showString(n, truncate))
def __repr__(self):
return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes))
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 7d3d0361610b7..45ecd826bd3bd 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -42,6 +42,7 @@
'monotonicallyIncreasingId',
'rand',
'randn',
+ 'sha1',
'sha2',
'sparkPartitionId',
'struct',
@@ -382,6 +383,19 @@ def sha2(col, numBits):
return Column(jc)
+@ignore_unicode_prefix
+@since(1.5)
+def sha1(col):
+ """Returns the hex string result of SHA-1.
+
+ >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(sha1('a').alias('hash')).collect()
+ [Row(hash=u'3c01bdbb26f358bab27f267924aa2c9a03fcfdb8')]
+ """
+ sc = SparkContext._active_spark_context
+ jc = sc._jvm.functions.sha1(_to_java_column(col))
+ return Column(jc)
+
+
@since(1.4)
def sparkPartitionId():
"""A column for partition ID of the Spark task.
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index ffee43a94baba..34f397d0ffef0 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -516,6 +516,35 @@ def test_between_function(self):
self.assertEqual([Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)],
df.filter(df.a.between(df.b, df.c)).collect())
+ def test_struct_type(self):
+ from pyspark.sql.types import StructType, StringType, StructField
+ struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
+ struct2 = StructType([StructField("f1", StringType(), True),
+ StructField("f2", StringType(), True, None)])
+ self.assertEqual(struct1, struct2)
+
+ struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
+ struct2 = StructType([StructField("f1", StringType(), True)])
+ self.assertNotEqual(struct1, struct2)
+
+ struct1 = (StructType().add(StructField("f1", StringType(), True))
+ .add(StructField("f2", StringType(), True, None)))
+ struct2 = StructType([StructField("f1", StringType(), True),
+ StructField("f2", StringType(), True, None)])
+ self.assertEqual(struct1, struct2)
+
+ struct1 = (StructType().add(StructField("f1", StringType(), True))
+ .add(StructField("f2", StringType(), True, None)))
+ struct2 = StructType([StructField("f1", StringType(), True)])
+ self.assertNotEqual(struct1, struct2)
+
+ # Catch exception raised during improper construction
+ try:
+ struct1 = StructType().add("name")
+ self.assertEqual(1, 0)
+ except ValueError:
+ self.assertEqual(1, 1)
+
def test_save_and_load(self):
df = self.df
tmpPath = tempfile.mkdtemp()
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 23d9adb0daea1..ae9344e6106a4 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -355,8 +355,7 @@ class StructType(DataType):
This is the data type representing a :class:`Row`.
"""
-
- def __init__(self, fields):
+ def __init__(self, fields=None):
"""
>>> struct1 = StructType([StructField("f1", StringType(), True)])
>>> struct2 = StructType([StructField("f1", StringType(), True)])
@@ -368,8 +367,53 @@ def __init__(self, fields):
>>> struct1 == struct2
False
"""
- assert all(isinstance(f, DataType) for f in fields), "fields should be a list of DataType"
- self.fields = fields
+ if not fields:
+ self.fields = []
+ else:
+ self.fields = fields
+ assert all(isinstance(f, StructField) for f in fields),\
+ "fields should be a list of StructField"
+
+ def add(self, field, data_type=None, nullable=True, metadata=None):
+ """
+ Construct a StructType by adding new elements to it to define the schema. The method accepts
+ either:
+ a) A single parameter which is a StructField object.
+ b) Between 2 and 4 parameters as (name, data_type, nullable (optional),
+ metadata (optional)). The data_type parameter may be either a String or a DataType object.
+
+ >>> struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
+ >>> struct2 = StructType([StructField("f1", StringType(), True),\
+ StructField("f2", StringType(), True, None)])
+ >>> struct1 == struct2
+ True
+ >>> struct1 = StructType().add(StructField("f1", StringType(), True))
+ >>> struct2 = StructType([StructField("f1", StringType(), True)])
+ >>> struct1 == struct2
+ True
+ >>> struct1 = StructType().add("f1", "string", True)
+ >>> struct2 = StructType([StructField("f1", StringType(), True)])
+ >>> struct1 == struct2
+ True
+
+ :param field: Either the name of the field or a StructField object
+ :param data_type: If present, the DataType of the StructField to create
+ :param nullable: Whether the field to add should be nullable (default True)
+ :param metadata: Any additional metadata (default None)
+ :return: a new updated StructType
+ """
+ if isinstance(field, StructField):
+ self.fields.append(field)
+ else:
+ if isinstance(field, str) and data_type is None:
+ raise ValueError("Must specify DataType if passing name of struct_field to create.")
+
+ if isinstance(data_type, str):
+ data_type_f = _parse_datatype_json_value(data_type)
+ else:
+ data_type_f = data_type
+ self.fields.append(StructField(field, data_type_f, nullable, metadata))
+ return self
def simpleString(self):
return 'struct<%s>' % (','.join(f.simpleString() for f in self.fields))
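Because `add` appends to `self.fields` and returns `self`, schemas can be built fluently, and a string `data_type` is parsed through `_parse_datatype_json_value`. A sketch under those assumptions:

    from pyspark.sql.types import StructType, StructField, StringType, IntegerType

    schema = (StructType()
              .add("name", StringType(), False)   # explicit DataType, non-nullable
              .add("age", "integer")              # type given by name, parsed for you
              .add(StructField("score", IntegerType(), True)))  # ready-made field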
@@ -635,7 +679,7 @@ def _need_python_to_sql_conversion(dataType):
>>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False),
... StructField("values", ArrayType(DoubleType(), False), False)])
>>> _need_python_to_sql_conversion(schema0)
- False
+ True
>>> _need_python_to_sql_conversion(ExamplePointUDT())
True
>>> schema1 = ArrayType(ExamplePointUDT(), False)
@@ -647,7 +691,8 @@ def _need_python_to_sql_conversion(dataType):
True
"""
if isinstance(dataType, StructType):
- return any([_need_python_to_sql_conversion(f.dataType) for f in dataType.fields])
+ # convert namedtuple or Row into tuple
+ return True
elif isinstance(dataType, ArrayType):
return _need_python_to_sql_conversion(dataType.elementType)
elif isinstance(dataType, MapType):
@@ -688,21 +733,25 @@ def _python_to_sql_converter(dataType):
if isinstance(dataType, StructType):
names, types = zip(*[(f.name, f.dataType) for f in dataType.fields])
- converters = [_python_to_sql_converter(t) for t in types]
-
- def converter(obj):
- if isinstance(obj, dict):
- return tuple(c(obj.get(n)) for n, c in zip(names, converters))
- elif isinstance(obj, tuple):
- if hasattr(obj, "__fields__") or hasattr(obj, "_fields"):
- return tuple(c(v) for c, v in zip(converters, obj))
- elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): # k-v pairs
- d = dict(obj)
- return tuple(c(d.get(n)) for n, c in zip(names, converters))
+ if any(_need_python_to_sql_conversion(t) for t in types):
+ converters = [_python_to_sql_converter(t) for t in types]
+
+ def converter(obj):
+ if isinstance(obj, dict):
+ return tuple(c(obj.get(n)) for n, c in zip(names, converters))
+ elif isinstance(obj, tuple):
+ if hasattr(obj, "__fields__") or hasattr(obj, "_fields"):
+ return tuple(c(v) for c, v in zip(converters, obj))
+ else:
+ return tuple(c(v) for c, v in zip(converters, obj))
+ elif obj is not None:
+ raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType))
+ else:
+ def converter(obj):
+ if isinstance(obj, dict):
+ return tuple(obj.get(n) for n in names)
else:
- return tuple(c(v) for c, v in zip(converters, obj))
- elif obj is not None:
- raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType))
+ return tuple(obj)
return converter
elif isinstance(dataType, ArrayType):
element_converter = _python_to_sql_converter(dataType.elementType)
@@ -1027,10 +1076,13 @@ def _verify_type(obj, dataType):
_type = type(dataType)
assert _type in _acceptable_types, "unknown datatype: %s" % dataType
- # subclass of them can not be deserialized in JVM
- if type(obj) not in _acceptable_types[_type]:
- raise TypeError("%s can not accept object in type %s"
- % (dataType, type(obj)))
+ if _type is StructType:
+ if not isinstance(obj, (tuple, list)):
+ raise TypeError("StructType can not accept object in type %s" % type(obj))
+ else:
+ # subclass of them can not be deserialized in JVM
+ if type(obj) not in _acceptable_types[_type]:
+ raise TypeError("%s can not accept object in type %s" % (dataType, type(obj)))
if isinstance(dataType, ArrayType):
for i in obj:
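Under the relaxed check, any tuple or list (including `Row`, a tuple subclass) now passes StructType verification, while other types still fail fast. An illustration using the internal helper, for exposition only:

    from pyspark.sql.types import StructType, StructField, StringType, _verify_type

    schema = StructType([StructField("name", StringType(), True)])
    _verify_type(("Alice",), schema)  # tuple: accepted
    _verify_type(["Alice"], schema)   # list: accepted
    try:
        _verify_type({"name": "Alice"}, schema)  # dict: rejected
    except TypeError as e:
        print(e)  # StructType can not accept object in type <type 'dict'>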
diff --git a/python/run-tests.py b/python/run-tests.py
index 7d485b500ee3a..b7737650daa54 100755
--- a/python/run-tests.py
+++ b/python/run-tests.py
@@ -18,12 +18,19 @@
#
from __future__ import print_function
+import logging
from optparse import OptionParser
import os
import re
import subprocess
import sys
+import tempfile
+from threading import Thread, Lock
import time
+if sys.version < '3':
+ import Queue
+else:
+ import queue as Queue
# Append `SPARK_HOME/dev` to the Python path so that we can import the sparktestsupport module
@@ -43,34 +50,55 @@ def print_red(text):
LOG_FILE = os.path.join(SPARK_HOME, "python/unit-tests.log")
+FAILURE_REPORTING_LOCK = Lock()
+LOGGER = logging.getLogger()
def run_individual_python_test(test_name, pyspark_python):
env = {'SPARK_TESTING': '1', 'PYSPARK_PYTHON': which(pyspark_python)}
- print(" Running test: %s ..." % test_name, end='')
+ LOGGER.debug("Starting test(%s): %s", pyspark_python, test_name)
start_time = time.time()
- with open(LOG_FILE, 'a') as log_file:
- retcode = subprocess.call(
+ try:
+ per_test_output = tempfile.TemporaryFile()
+ retcode = subprocess.Popen(
[os.path.join(SPARK_HOME, "bin/pyspark"), test_name],
- stderr=log_file, stdout=log_file, env=env)
+ stderr=per_test_output, stdout=per_test_output, env=env).wait()
+ except:
+ LOGGER.exception("Got exception while running %s with %s", test_name, pyspark_python)
+ # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if
+ # this code is invoked from a thread other than the main thread.
+ os._exit(1)
duration = time.time() - start_time
# Exit on the first failure.
if retcode != 0:
- with open(LOG_FILE, 'r') as log_file:
- for line in log_file:
- if not re.match('[0-9]+', line):
- print(line, end='')
- print_red("\nHad test failures in %s; see logs." % test_name)
- exit(-1)
+ try:
+ with FAILURE_REPORTING_LOCK:
+ with open(LOG_FILE, 'ab') as log_file:
+ per_test_output.seek(0)
+ log_file.writelines(per_test_output)
+ per_test_output.seek(0)
+ for line in per_test_output:
+ decoded_line = line.decode()
+ if not re.match('[0-9]+', decoded_line):
+ print(decoded_line, end='')
+ per_test_output.close()
+ except:
+ LOGGER.exception("Got an exception while trying to print failed test output")
+ finally:
+ print_red("\nHad test failures in %s with %s; see logs." % (test_name, pyspark_python))
+ # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if
+ # this code is invoked from a thread other than the main thread.
+ os._exit(-1)
else:
- print("ok (%is)" % duration)
+ per_test_output.close()
+ LOGGER.info("Finished test(%s): %s (%is)", pyspark_python, test_name, duration)
def get_default_python_executables():
python_execs = [x for x in ["python2.6", "python3.4", "pypy"] if which(x)]
if "python2.6" not in python_execs:
- print("WARNING: Not testing against `python2.6` because it could not be found; falling"
- " back to `python` instead")
+ LOGGER.warning("Not testing against `python2.6` because it could not be found; falling"
+ " back to `python` instead")
python_execs.insert(0, "python")
return python_execs
@@ -88,16 +116,31 @@ def parse_opts():
default=",".join(sorted(python_modules.keys())),
help="A comma-separated list of Python modules to test (default: %default)"
)
+ parser.add_option(
+ "-p", "--parallelism", type="int", default=4,
+ help="The number of suites to test in parallel (default %default)"
+ )
+ parser.add_option(
+ "--verbose", action="store_true",
+ help="Enable additional debug logging"
+ )
(opts, args) = parser.parse_args()
if args:
parser.error("Unsupported arguments: %s" % ' '.join(args))
+ if opts.parallelism < 1:
+ parser.error("Parallelism cannot be less than 1")
return opts
def main():
opts = parse_opts()
- print("Running PySpark tests. Output is in python/%s" % LOG_FILE)
+ if opts.verbose:
+ log_level = logging.DEBUG
+ else:
+ log_level = logging.INFO
+ logging.basicConfig(stream=sys.stdout, level=log_level, format="%(message)s")
+ LOGGER.info("Running PySpark tests. Output is in python/%s", LOG_FILE)
if os.path.exists(LOG_FILE):
os.remove(LOG_FILE)
python_execs = opts.python_executables.split(',')
@@ -108,24 +151,45 @@ def main():
else:
print("Error: unrecognized module %s" % module_name)
sys.exit(-1)
- print("Will test against the following Python executables: %s" % python_execs)
- print("Will test the following Python modules: %s" % [x.name for x in modules_to_test])
+ LOGGER.info("Will test against the following Python executables: %s", python_execs)
+ LOGGER.info("Will test the following Python modules: %s", [x.name for x in modules_to_test])
- start_time = time.time()
+ task_queue = Queue.Queue()
for python_exec in python_execs:
python_implementation = subprocess.check_output(
[python_exec, "-c", "import platform; print(platform.python_implementation())"],
universal_newlines=True).strip()
- print("Testing with `%s`: " % python_exec, end='')
- subprocess.call([python_exec, "--version"])
-
+ LOGGER.debug("%s python_implementation is %s", python_exec, python_implementation)
+ LOGGER.debug("%s version is: %s", python_exec, subprocess.check_output(
+ [python_exec, "--version"], stderr=subprocess.STDOUT, universal_newlines=True).strip())
for module in modules_to_test:
if python_implementation not in module.blacklisted_python_implementations:
- print("Running %s tests ..." % module.name)
for test_goal in module.python_test_goals:
- run_individual_python_test(test_goal, python_exec)
+ task_queue.put((python_exec, test_goal))
+
+ def process_queue(task_queue):
+ while True:
+ try:
+ (python_exec, test_goal) = task_queue.get_nowait()
+ except Queue.Empty:
+ break
+ try:
+ run_individual_python_test(test_goal, python_exec)
+ finally:
+ task_queue.task_done()
+
+ start_time = time.time()
+ for _ in range(opts.parallelism):
+ worker = Thread(target=process_queue, args=(task_queue,))
+ worker.daemon = True
+ worker.start()
+ try:
+ task_queue.join()
+ except (KeyboardInterrupt, SystemExit):
+ print_red("Exiting due to interrupt")
+ sys.exit(-1)
total_duration = time.time() - start_time
- print("Tests passed in %i seconds" % total_duration)
+ LOGGER.info("Tests passed in %i seconds", total_duration)
if __name__ == "__main__":
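The parallel runner is the classic fixed-pool-of-daemon-threads-over-a-queue pattern. A self-contained reduction of it (a sleep stands in for launching one test suite):

    import sys
    import time
    from threading import Thread
    if sys.version < '3':
        import Queue
    else:
        import queue as Queue

    def run_one(goal):
        time.sleep(0.1)  # stand-in for spawning `bin/pyspark <goal>`

    task_queue = Queue.Queue()
    for goal in ["pyspark.rdd", "pyspark.sql.types", "pyspark.ml.feature"]:
        task_queue.put(goal)

    def process_queue(q):
        while True:
            try:
                goal = q.get_nowait()
            except Queue.Empty:
                break
            try:
                run_one(goal)
            finally:
                q.task_done()

    for _ in range(4):  # --parallelism
        worker = Thread(target=process_queue, args=(task_queue,))
        worker.daemon = True  # let the process exit even if a worker hangs
        worker.start()
    task_queue.join()  # returns once every task_done() has been called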
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
index 83f2a312972fb..1e79f4b2e88e5 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
@@ -19,9 +19,11 @@
import java.util.Iterator;
+import scala.Function1;
+
import org.apache.spark.sql.catalyst.InternalRow;
-import org.apache.spark.sql.types.StructField;
-import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.catalyst.util.ObjectPool;
+import org.apache.spark.sql.catalyst.util.UniqueObjectPool;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.map.BytesToBytesMap;
import org.apache.spark.unsafe.memory.MemoryLocation;
@@ -38,26 +40,48 @@ public final class UnsafeFixedWidthAggregationMap {
* An empty aggregation buffer, encoded in UnsafeRow format. When inserting a new key into the
* map, we copy this buffer and use it as the value.
*/
- private final byte[] emptyAggregationBuffer;
+ private final byte[] emptyBuffer;
- private final StructType aggregationBufferSchema;
+ /**
+ * An empty row used by `initProjection`
+ */
+ private static final InternalRow emptyRow = new GenericInternalRow();
- private final StructType groupingKeySchema;
+ /**
+ * Whether the empty aggregation buffer can be reused without calling `initProjection`.
+ */
+ private final boolean reuseEmptyBuffer;
/**
- * Encodes grouping keys as UnsafeRows.
+ * The projection used to initialize the emptyBuffer
*/
- private final UnsafeRowConverter groupingKeyToUnsafeRowConverter;
+ private final Function1<InternalRow, InternalRow> initProjection;
+
+ /**
+ * Encodes grouping keys or buffers as UnsafeRows.
+ */
+ private final UnsafeRowConverter keyConverter;
+ private final UnsafeRowConverter bufferConverter;
/**
* A hashmap which maps from opaque bytearray keys to bytearray values.
*/
private final BytesToBytesMap map;
+ /**
+ * An object pool for objects that are used in grouping keys.
+ */
+ private final UniqueObjectPool keyPool;
+
+ /**
+ * An object pool for objects that are used in aggregation buffers.
+ */
+ private final ObjectPool bufferPool;
+
/**
* Re-used pointer to the current aggregation buffer
*/
- private final UnsafeRow currentAggregationBuffer = new UnsafeRow();
+ private final UnsafeRow currentBuffer = new UnsafeRow();
/**
* Scratch space that is used when encoding grouping keys into UnsafeRow format.
@@ -69,68 +93,39 @@ public final class UnsafeFixedWidthAggregationMap {
private final boolean enablePerfMetrics;
- /**
- * @return true if UnsafeFixedWidthAggregationMap supports grouping keys with the given schema,
- * false otherwise.
- */
- public static boolean supportsGroupKeySchema(StructType schema) {
- for (StructField field: schema.fields()) {
- if (!UnsafeRow.readableFieldTypes.contains(field.dataType())) {
- return false;
- }
- }
- return true;
- }
-
- /**
- * @return true if UnsafeFixedWidthAggregationMap supports aggregation buffers with the given
- * schema, false otherwise.
- */
- public static boolean supportsAggregationBufferSchema(StructType schema) {
- for (StructField field: schema.fields()) {
- if (!UnsafeRow.settableFieldTypes.contains(field.dataType())) {
- return false;
- }
- }
- return true;
- }
-
/**
* Create a new UnsafeFixedWidthAggregationMap.
*
- * @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function)
- * @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion.
- * @param groupingKeySchema the schema of the grouping key, used for row conversion.
+ * @param initProjection the default value for new keys (a "zero" of the agg. function)
+ * @param keyConverter the converter of the grouping key, used for row conversion.
+ * @param bufferConverter the converter of the aggregation buffer, used for row conversion.
* @param memoryManager the memory manager used to allocate our Unsafe memory structures.
* @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing).
* @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact)
*/
public UnsafeFixedWidthAggregationMap(
- InternalRow emptyAggregationBuffer,
- StructType aggregationBufferSchema,
- StructType groupingKeySchema,
+ Function1<InternalRow, InternalRow> initProjection,
+ UnsafeRowConverter keyConverter,
+ UnsafeRowConverter bufferConverter,
TaskMemoryManager memoryManager,
int initialCapacity,
boolean enablePerfMetrics) {
- this.emptyAggregationBuffer =
- convertToUnsafeRow(emptyAggregationBuffer, aggregationBufferSchema);
- this.aggregationBufferSchema = aggregationBufferSchema;
- this.groupingKeyToUnsafeRowConverter = new UnsafeRowConverter(groupingKeySchema);
- this.groupingKeySchema = groupingKeySchema;
- this.map = new BytesToBytesMap(memoryManager, initialCapacity, enablePerfMetrics);
+ this.initProjection = initProjection;
+ this.keyConverter = keyConverter;
+ this.bufferConverter = bufferConverter;
this.enablePerfMetrics = enablePerfMetrics;
- }
- /**
- * Convert a Java object row into an UnsafeRow, allocating it into a new byte array.
- */
- private static byte[] convertToUnsafeRow(InternalRow javaRow, StructType schema) {
- final UnsafeRowConverter converter = new UnsafeRowConverter(schema);
- final byte[] unsafeRow = new byte[converter.getSizeRequirement(javaRow)];
- final int writtenLength =
- converter.writeRow(javaRow, unsafeRow, PlatformDependent.BYTE_ARRAY_OFFSET);
- assert (writtenLength == unsafeRow.length): "Size requirement calculation was wrong!";
- return unsafeRow;
+ this.map = new BytesToBytesMap(memoryManager, initialCapacity, enablePerfMetrics);
+ this.keyPool = new UniqueObjectPool(100);
+ this.bufferPool = new ObjectPool(initialCapacity);
+
+ InternalRow initRow = initProjection.apply(emptyRow);
+ this.emptyBuffer = new byte[bufferConverter.getSizeRequirement(initRow)];
+ int writtenLength = bufferConverter.writeRow(
+ initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, bufferPool);
+ assert (writtenLength == emptyBuffer.length): "Size requirement calculation was wrong!";
+ // re-use the empty buffer only when no objects were saved in the pool.
+ reuseEmptyBuffer = bufferPool.size() == 0;
}
/**
@@ -138,15 +133,16 @@ private static byte[] convertToUnsafeRow(InternalRow javaRow, StructType schema)
* return the same object.
*/
public UnsafeRow getAggregationBuffer(InternalRow groupingKey) {
- final int groupingKeySize = groupingKeyToUnsafeRowConverter.getSizeRequirement(groupingKey);
+ final int groupingKeySize = keyConverter.getSizeRequirement(groupingKey);
// Make sure that the buffer is large enough to hold the key. If it's not, grow it:
if (groupingKeySize > groupingKeyConversionScratchSpace.length) {
groupingKeyConversionScratchSpace = new byte[groupingKeySize];
}
- final int actualGroupingKeySize = groupingKeyToUnsafeRowConverter.writeRow(
+ final int actualGroupingKeySize = keyConverter.writeRow(
groupingKey,
groupingKeyConversionScratchSpace,
- PlatformDependent.BYTE_ARRAY_OFFSET);
+ PlatformDependent.BYTE_ARRAY_OFFSET,
+ keyPool);
assert (groupingKeySize == actualGroupingKeySize) : "Size requirement calculation was wrong!";
// Probe our map using the serialized key
@@ -157,25 +153,31 @@ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) {
if (!loc.isDefined()) {
// This is the first time that we've seen this grouping key, so we'll insert a copy of the
// empty aggregation buffer into the map:
+ if (!reuseEmptyBuffer) {
+ // There are some objects referenced by emptyBuffer, so generate a new one
+ InternalRow initRow = initProjection.apply(emptyRow);
+ bufferConverter.writeRow(initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET,
+ bufferPool);
+ }
loc.putNewKey(
groupingKeyConversionScratchSpace,
PlatformDependent.BYTE_ARRAY_OFFSET,
groupingKeySize,
- emptyAggregationBuffer,
+ emptyBuffer,
PlatformDependent.BYTE_ARRAY_OFFSET,
- emptyAggregationBuffer.length
+ emptyBuffer.length
);
}
// Reset the pointer to point to the value that we just stored or looked up:
final MemoryLocation address = loc.getValueAddress();
- currentAggregationBuffer.pointTo(
+ currentBuffer.pointTo(
address.getBaseObject(),
address.getBaseOffset(),
- aggregationBufferSchema.length(),
- aggregationBufferSchema
+ bufferConverter.numFields(),
+ bufferPool
);
- return currentAggregationBuffer;
+ return currentBuffer;
}
/**
@@ -211,14 +213,14 @@ public MapEntry next() {
entry.key.pointTo(
keyAddress.getBaseObject(),
keyAddress.getBaseOffset(),
- groupingKeySchema.length(),
- groupingKeySchema
+ keyConverter.numFields(),
+ keyPool
);
entry.value.pointTo(
valueAddress.getBaseObject(),
valueAddress.getBaseOffset(),
- aggregationBufferSchema.length(),
- aggregationBufferSchema
+ bufferConverter.numFields(),
+ bufferPool
);
return entry;
}
@@ -246,6 +248,8 @@ public void printPerfMetrics() {
System.out.println("Number of hash collisions: " + map.getNumHashCollisions());
System.out.println("Time spent resizing (ns): " + map.getTimeSpentResizingNs());
System.out.println("Total memory consumption (bytes): " + map.getTotalMemoryConsumption());
+ System.out.println("Number of unique objects in keys: " + keyPool.size());
+ System.out.println("Number of objects in buffers: " + bufferPool.size());
}
}
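Stripped of the off-heap layout, `getAggregationBuffer` is "get or insert a zeroed buffer, then let the caller mutate it in place". A toy Python model of that control flow (not the real memory layout):

    class ToyAggregationMap(object):
        def __init__(self, init_projection):
            self._init = init_projection  # produces a fresh "zero" buffer
            self._map = {}

        def get_aggregation_buffer(self, key):
            buf = self._map.get(key)
            if buf is None:            # first time we've seen this grouping key
                buf = self._init()     # insert a copy of the empty buffer
                self._map[key] = buf
            return buf                 # caller updates the buffer in place

    m = ToyAggregationMap(lambda: [0])
    m.get_aggregation_buffer("a")[0] += 1
    m.get_aggregation_buffer("a")[0] += 1
    assert m.get_aggregation_buffer("a") == [2]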
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 11d51d90f1802..f077064a02ec0 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -17,20 +17,12 @@
package org.apache.spark.sql.catalyst.expressions;
-import javax.annotation.Nullable;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.HashSet;
-import java.util.Set;
-
import org.apache.spark.sql.catalyst.InternalRow;
-import org.apache.spark.sql.types.DataType;
-import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.catalyst.util.ObjectPool;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.bitset.BitSetMethods;
import org.apache.spark.unsafe.types.UTF8String;
-import static org.apache.spark.sql.types.DataTypes.*;
/**
* An Unsafe implementation of Row which is backed by raw memory instead of Java objects.
@@ -44,7 +36,20 @@
* primitive types, such as long, double, or int, we store the value directly in the word. For
* fields with non-primitive or variable-length values, we store a relative offset (w.r.t. the
* base address of the row) that points to the beginning of the variable-length field, and length
- * (they are combined into a long).
+ * (they are combined into a long). Other objects are stored in a pool, and their indexes are
+ * held in the word.
+ *
+ * In order to support fast hashing and equality checks for UnsafeRows that contain objects
+ * when used as grouping keys in BytesToBytesMap, we put the objects in a UniqueObjectPool so that
+ * equal objects always get the same index; rows can then be hashed and compared by hashing and
+ * comparing the indexes.
+ *
+ * For non-primitive types, the word of a field could be:
+ * UNION {
+ * [1] [offset: 31bits] [length: 31bits] // StringType
+ * [0] [offset: 31bits] [length: 31bits] // BinaryType
+ * - [index: 63bits] // StringType, Binary, index to object in pool
+ * }
*
* Instances of `UnsafeRow` act as pointers to row data stored in this format.
*/
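The word layout described above can be exercised in isolation. A sketch of pack/unpack under that layout (pool indexes are stored negated, so the sign distinguishes the two cases):

    OFFSET_BITS = 31

    def pack_inline(offset, length, is_string):
        # [flag:1][offset:31][length:31]; the flag bit set means StringType
        flag = 1 << (OFFSET_BITS * 2) if is_string else 0
        return flag | (offset << OFFSET_BITS) | length

    def unpack(word):
        if word <= 0:
            return ("pool", -word)  # negated index into the object pool
        is_string = (word >> (OFFSET_BITS * 2)) > 0
        offset = (word >> OFFSET_BITS) & 0x7FFFFFFF
        length = word & 0x7FFFFFFF
        return ("string" if is_string else "binary", offset, length)

    assert unpack(pack_inline(16, 5, True)) == ("string", 16, 5)
    assert unpack(-3) == ("pool", 3)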
@@ -53,8 +58,12 @@ public final class UnsafeRow extends MutableRow {
private Object baseObject;
private long baseOffset;
+ /** A pool to hold non-primitive objects */
+ private ObjectPool pool;
+
Object getBaseObject() { return baseObject; }
long getBaseOffset() { return baseOffset; }
+ ObjectPool getPool() { return pool; }
/** The number of fields in this row, used for calculating the bitset width (and in assertions) */
private int numFields;
@@ -63,15 +72,6 @@ public final class UnsafeRow extends MutableRow {
/** The width of the null tracking bit set, in bytes */
private int bitSetWidthInBytes;
- /**
- * This optional schema is required if you want to call generic get() and set() methods on
- * this UnsafeRow, but is optional if callers will only use type-specific getTYPE() and setTYPE()
- * methods. This should be removed after the planned InternalRow / Row split; right now, it's only
- * needed by the generic get() method, which is only called internally by code that accesses
- * UTF8String-typed columns.
- */
- @Nullable
- private StructType schema;
private long getFieldOffset(int ordinal) {
return baseOffset + bitSetWidthInBytes + ordinal * 8L;
@@ -81,42 +81,7 @@ public static int calculateBitSetWidthInBytes(int numFields) {
return ((numFields / 64) + (numFields % 64 == 0 ? 0 : 1)) * 8;
}
- /**
- * Field types that can be updated in place in UnsafeRows (e.g. we support set() for these types)
- */
- public static final Set<DataType> settableFieldTypes;
-
- /**
- * Field types that can be read (but not set; e.g. set() will throw UnsupportedOperationException).
- */
- public static final Set<DataType> readableFieldTypes;
-
- // TODO: support DecimalType
- static {
- settableFieldTypes = Collections.unmodifiableSet(
- new HashSet<DataType>(
- Arrays.asList(new DataType[] {
- NullType,
- BooleanType,
- ByteType,
- ShortType,
- IntegerType,
- LongType,
- FloatType,
- DoubleType,
- DateType,
- TimestampType
- })));
-
- // We support get() on a superset of the types for which we support set():
- final Set<DataType> _readableFieldTypes = new HashSet<DataType>(
- Arrays.asList(new DataType[]{
- StringType,
- BinaryType
- }));
- _readableFieldTypes.addAll(settableFieldTypes);
- readableFieldTypes = Collections.unmodifiableSet(_readableFieldTypes);
- }
+ public static final long OFFSET_BITS = 31L;
/**
* Construct a new UnsafeRow. The resulting row won't be usable until `pointTo()` has been called,
@@ -130,22 +95,15 @@ public UnsafeRow() { }
* @param baseObject the base object
* @param baseOffset the offset within the base object
* @param numFields the number of fields in this row
- * @param schema an optional schema; this is necessary if you want to call generic get() or set()
- * methods on this row, but is optional if the caller will only use type-specific
- * getTYPE() and setTYPE() methods.
+ * @param pool the object pool to hold arbitrary objects
*/
- public void pointTo(
- Object baseObject,
- long baseOffset,
- int numFields,
- @Nullable StructType schema) {
+ public void pointTo(Object baseObject, long baseOffset, int numFields, ObjectPool pool) {
assert numFields >= 0 : "numFields should >= 0";
- assert schema == null || schema.fields().length == numFields;
this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields);
this.baseObject = baseObject;
this.baseOffset = baseOffset;
this.numFields = numFields;
- this.schema = schema;
+ this.pool = pool;
}
private void assertIndexIsValid(int index) {
@@ -168,9 +126,68 @@ private void setNotNullAt(int i) {
BitSetMethods.unset(baseObject, baseOffset, i);
}
+ /**
+ * Updates column `i` with Object `value`, which must not be a primitive type.
+ */
@Override
- public void update(int ordinal, Object value) {
- throw new UnsupportedOperationException();
+ public void update(int i, Object value) {
+ if (value == null) {
+ if (!isNullAt(i)) {
+ // remove the old value from pool
+ long idx = getLong(i);
+ if (idx <= 0) {
+ // this is the index of old value in pool, remove it
+ pool.replace((int)-idx, null);
+ } else {
+ // there will be some garbage left (UTF8String or byte[])
+ }
+ setNullAt(i);
+ }
+ return;
+ }
+
+ if (isNullAt(i)) {
+ // there is no old value; put the new value into the pool
+ int idx = pool.put(value);
+ setLong(i, (long)-idx);
+ } else {
+ // there is an old value, check the type, then replace it or update it
+ long v = getLong(i);
+ if (v <= 0) {
+ // it's the index in the pool, replace old value with new one
+ int idx = (int)-v;
+ pool.replace(idx, value);
+ } else {
+ // old value is UTF8String or byte[], try to reuse the space
+ boolean isString;
+ byte[] newBytes;
+ if (value instanceof UTF8String) {
+ newBytes = ((UTF8String) value).getBytes();
+ isString = true;
+ } else {
+ newBytes = (byte[]) value;
+ isString = false;
+ }
+ int offset = (int) ((v >> OFFSET_BITS) & Integer.MAX_VALUE);
+ int oldLength = (int) (v & Integer.MAX_VALUE);
+ if (newBytes.length <= oldLength) {
+ // the new value can fit in the old buffer, re-use it
+ PlatformDependent.copyMemory(
+ newBytes,
+ PlatformDependent.BYTE_ARRAY_OFFSET,
+ baseObject,
+ baseOffset + offset,
+ newBytes.length);
+ long flag = isString ? 1L << (OFFSET_BITS * 2) : 0L;
+ setLong(i, flag | (((long) offset) << OFFSET_BITS) | (long) newBytes.length);
+ } else {
+ // Cannot fit in the buffer
+ int idx = pool.put(value);
+ setLong(i, (long) -idx);
+ }
+ }
+ }
+ setNotNullAt(i);
}
@Override
@@ -227,28 +244,38 @@ public int size() {
return numFields;
}
- @Override
- public StructType schema() {
- return schema;
- }
-
+ /**
+ * Returns the object for column `i`, which should not be a primitive type.
+ */
@Override
public Object get(int i) {
assertIndexIsValid(i);
- assert (schema != null) : "Schema must be defined when calling generic get() method";
- final DataType dataType = schema.fields()[i].dataType();
- // UnsafeRow is only designed to be invoked by internal code, which only invokes this generic
- // get() method when trying to access UTF8String-typed columns. If we refactor the codebase to
- // separate the internal and external row interfaces, then internal code can fetch strings via
- // a new getUTF8String() method and we'll be able to remove this method.
if (isNullAt(i)) {
return null;
- } else if (dataType == StringType) {
- return getUTF8String(i);
- } else if (dataType == BinaryType) {
- return getBinary(i);
+ }
+ long v = PlatformDependent.UNSAFE.getLong(baseObject, getFieldOffset(i));
+ if (v <= 0) {
+ // It's an index to object in the pool.
+ int idx = (int)-v;
+ return pool.get(idx);
} else {
- throw new UnsupportedOperationException();
+ // The column could be StringType or BinaryType
+ boolean isString = (v >> (OFFSET_BITS * 2)) > 0;
+ int offset = (int) ((v >> OFFSET_BITS) & Integer.MAX_VALUE);
+ int size = (int) (v & Integer.MAX_VALUE);
+ final byte[] bytes = new byte[size];
+ PlatformDependent.copyMemory(
+ baseObject,
+ baseOffset + offset,
+ bytes,
+ PlatformDependent.BYTE_ARRAY_OFFSET,
+ size
+ );
+ if (isString) {
+ return UTF8String.fromBytes(bytes);
+ } else {
+ return bytes;
+ }
}
}
@@ -308,31 +335,6 @@ public double getDouble(int i) {
}
}
- public UTF8String getUTF8String(int i) {
- return UTF8String.fromBytes(getBinary(i));
- }
-
- public byte[] getBinary(int i) {
- assertIndexIsValid(i);
- final long offsetAndSize = getLong(i);
- final int offset = (int)(offsetAndSize >> 32);
- final int size = (int)(offsetAndSize & ((1L << 32) - 1));
- final byte[] bytes = new byte[size];
- PlatformDependent.copyMemory(
- baseObject,
- baseOffset + offset,
- bytes,
- PlatformDependent.BYTE_ARRAY_OFFSET,
- size
- );
- return bytes;
- }
-
- @Override
- public String getString(int i) {
- return getUTF8String(i).toString();
- }
-
@Override
public InternalRow copy() {
throw new UnsupportedOperationException();
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/ObjectPool.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/ObjectPool.java
new file mode 100644
index 0000000000000..97f89a7d0b758
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/ObjectPool.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util;
+
+/**
+ * An object pool stores a collection of objects in an array, so that they can be referenced by
+ * the pool plus an index.
+ */
+public class ObjectPool {
+
+ /**
+ * An array to hold objects, which will grow as needed.
+ */
+ private Object[] objects;
+
+ /**
+ * The number of objects in the pool.
+ */
+ private int numObj;
+
+ public ObjectPool(int capacity) {
+ objects = new Object[capacity];
+ numObj = 0;
+ }
+
+ /**
+ * Returns the number of objects in the pool.
+ */
+ public int size() {
+ return numObj;
+ }
+
+ /**
+ * Returns the object at position `idx` in the array.
+ */
+ public Object get(int idx) {
+ assert (idx < numObj);
+ return objects[idx];
+ }
+
+ /**
+ * Puts an object `obj` at the end of the array and returns its index.
+ *
+ * The array will grow as needed.
+ */
+ public int put(Object obj) {
+ if (numObj >= objects.length) {
+ Object[] tmp = new Object[objects.length * 2];
+ System.arraycopy(objects, 0, tmp, 0, objects.length);
+ objects = tmp;
+ }
+ objects[numObj++] = obj;
+ return numObj - 1;
+ }
+
+ /**
+ * Replaces the object at `idx` with new one `obj`.
+ */
+ public void replace(int idx, Object obj) {
+ assert (idx < numObj);
+ objects[idx] = obj;
+ }
+}
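A Python analogue of ObjectPool's behavior, kept deliberately close to the Java (a real Python list would grow by itself; the fixed array plus doubling mirrors the original):

    class ObjectPool(object):
        """Toy analogue of ObjectPool: append-only slots addressed by index."""

        def __init__(self, capacity):
            self._objects = [None] * capacity
            self._num = 0

        def size(self):
            return self._num

        def get(self, idx):
            assert idx < self._num
            return self._objects[idx]

        def put(self, obj):
            if self._num >= len(self._objects):
                self._objects.extend([None] * len(self._objects))  # double the array
            self._objects[self._num] = obj
            self._num += 1
            return self._num - 1

        def replace(self, idx, obj):
            assert idx < self._num
            self._objects[idx] = obj

    pool = ObjectPool(2)
    i = pool.put("x")
    pool.replace(i, "y")
    assert pool.get(i) == "y" and pool.size() == 1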
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/UniqueObjectPool.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/UniqueObjectPool.java
new file mode 100644
index 0000000000000..d512392dcaacc
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/UniqueObjectPool.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util;
+
+import java.util.HashMap;
+
+/**
+ * A unique object pool stores a collection of distinct objects: equal objects share one slot.
+ */
+public class UniqueObjectPool extends ObjectPool {
+
+ /**
+ * A hash map from objects to their indexes in the array.
+ */
+ private HashMap<Object, Integer> objIndex;
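UniqueObjectPool layers that hash map over ObjectPool so that putting an equal object twice returns the same index, which is what lets UnsafeRow hash and compare pooled grouping-key values by index alone. A sketch extending the toy ObjectPool above:

    class UniqueObjectPool(ObjectPool):
        """Toy analogue: put() deduplicates via an object -> index map."""

        def __init__(self, capacity):
            ObjectPool.__init__(self, capacity)
            self._index = {}

        def put(self, obj):
            idx = self._index.get(obj)
            if idx is None:
                idx = ObjectPool.put(self, obj)
                self._index[obj] = idx
            return idx

    upool = UniqueObjectPool(2)
    assert upool.put("k") == upool.put("k")  # equal objects, same index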