From 7d427222dca4807ec55e8d9a7de6ffe861cd0d24 Mon Sep 17 00:00:00 2001 From: Dan McClary Date: Fri, 1 May 2015 11:55:43 -0700 Subject: [PATCH 01/91] [SPARK-5854] personalized page rank Here's a modification to PageRank which does personalized PageRank. The approach is basically similar to that outlined by Bahmani et al. from 2010 (http://arxiv.org/pdf/1006.2880.pdf). I'm sure this needs tuning up or other considerations, so let me know how I can improve this. Author: Dan McClary Author: dwmclary Closes #4774 from dwmclary/SPARK-5854-Personalized-PageRank and squashes the following commits: 8b907db [dwmclary] fixed scalastyle errors in PageRankSuite 2c20e5d [dwmclary] merged with upstream master d6cebac [dwmclary] updated as per style requests 7d00c23 [Dan McClary] fixed line overrun in personalizedVertexPageRank d711677 [Dan McClary] updated vertexProgram to restore binary compatibility for inner method bb8d507 [Dan McClary] Merge branch 'master' of https://github.com/apache/spark into SPARK-5854-Personalized-PageRank fba0edd [Dan McClary] fixed silly mistakes de51be2 [Dan McClary] cleaned up whitespace between comments and methods 0c30d0c [Dan McClary] updated to maintain binary compatibility aaf0b4b [Dan McClary] Merge branch 'master' of https://github.com/apache/spark into SPARK-5854-Personalized-PageRank 76773f6 [Dan McClary] Merge branch 'master' of https://github.com/apache/spark into SPARK-5854-Personalized-PageRank 44ada8e [Dan McClary] updated tolerance on chain PPR 1ffed95 [Dan McClary] updated tolerance on chain PPR b67ac69 [Dan McClary] updated tolerance on chain PPR a560942 [Dan McClary] rolled PPR into pregel code for PageRank 6dc2c29 [Dan McClary] initial implementation of personalized page rank --- .../org/apache/spark/graphx/GraphOps.scala | 25 +++++ .../apache/spark/graphx/lib/PageRank.scala | 93 +++++++++++++++++-- .../spark/graphx/lib/PageRankSuite.scala | 47 ++++++++++ 3 files changed, 159 insertions(+), 6 deletions(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala index 86f611d55aa8a..7edd627b20918 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala @@ -372,6 +372,31 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali PageRank.runUntilConvergence(graph, tol, resetProb) } + + /** + * Run personalized PageRank for a given vertex, such that all random walks + * are started relative to the source node. + * + * @see [[org.apache.spark.graphx.lib.PageRank$#runUntilConvergenceWithOptions]] + */ + def personalizedPageRank(src: VertexId, tol: Double, + resetProb: Double = 0.15) : Graph[Double, Double] = { + PageRank.runUntilConvergenceWithOptions(graph, tol, resetProb, Some(src)) + } + + /** + * Run Personalized PageRank for a fixed number of iterations with + * with all iterations originating at the source node + * returning a graph with vertex attributes + * containing the PageRank and edge attributes the normalized edge weight. + * + * @see [[org.apache.spark.graphx.lib.PageRank$#runWithOptions]] + */ + def staticPersonalizedPageRank(src: VertexId, numIter: Int, + resetProb: Double = 0.15) : Graph[Double, Double] = { + PageRank.runWithOptions(graph, numIter, resetProb, Some(src)) + } + /** * Run PageRank for a fixed number of iterations returning a graph with vertex attributes * containing the PageRank and edge attributes the normalized edge weight. diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index 042e366a29f58..bc974b2f04e70 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -18,6 +18,7 @@ package org.apache.spark.graphx.lib import scala.reflect.ClassTag +import scala.language.postfixOps import org.apache.spark.Logging import org.apache.spark.graphx._ @@ -60,6 +61,7 @@ import org.apache.spark.graphx._ */ object PageRank extends Logging { + /** * Run PageRank for a fixed number of iterations returning a graph * with vertex attributes containing the PageRank and edge @@ -74,10 +76,33 @@ object PageRank extends Logging { * * @return the graph containing with each vertex containing the PageRank and each edge * containing the normalized weight. + */ + def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED], numIter: Int, + resetProb: Double = 0.15): Graph[Double, Double] = + { + runWithOptions(graph, numIter, resetProb) + } + + /** + * Run PageRank for a fixed number of iterations returning a graph + * with vertex attributes containing the PageRank and edge + * attributes the normalized edge weight. + * + * @tparam VD the original vertex attribute (not used) + * @tparam ED the original edge attribute (not used) + * + * @param graph the graph on which to compute PageRank + * @param numIter the number of iterations of PageRank to run + * @param resetProb the random reset probability (alpha) + * @param srcId the source vertex for a Personalized Page Rank (optional) + * + * @return the graph containing with each vertex containing the PageRank and each edge + * containing the normalized weight. * */ - def run[VD: ClassTag, ED: ClassTag]( - graph: Graph[VD, ED], numIter: Int, resetProb: Double = 0.15): Graph[Double, Double] = + def runWithOptions[VD: ClassTag, ED: ClassTag]( + graph: Graph[VD, ED], numIter: Int, resetProb: Double = 0.15, + srcId: Option[VertexId] = None): Graph[Double, Double] = { // Initialize the PageRank graph with each edge attribute having // weight 1/outDegree and each vertex with attribute 1.0. @@ -89,6 +114,10 @@ object PageRank extends Logging { // Set the vertex attributes to the initial pagerank values .mapVertices( (id, attr) => resetProb ) + val personalized = srcId isDefined + val src: VertexId = srcId.getOrElse(-1L) + def delta(u: VertexId, v: VertexId):Double = { if (u == v) 1.0 else 0.0 } + var iteration = 0 var prevRankGraph: Graph[Double, Double] = null while (iteration < numIter) { @@ -103,8 +132,14 @@ object PageRank extends Logging { // that didn't receive a message. Requires a shuffle for broadcasting updated ranks to the // edge partitions. prevRankGraph = rankGraph + val rPrb = if (personalized) { + (src: VertexId ,id: VertexId) => resetProb * delta(src,id) + } else { + (src: VertexId, id: VertexId) => resetProb + } + rankGraph = rankGraph.joinVertices(rankUpdates) { - (id, oldRank, msgSum) => resetProb + (1.0 - resetProb) * msgSum + (id, oldRank, msgSum) => rPrb(src,id) + (1.0 - resetProb) * msgSum }.cache() rankGraph.edges.foreachPartition(x => {}) // also materializes rankGraph.vertices @@ -133,7 +168,29 @@ object PageRank extends Logging { * containing the normalized weight. */ def runUntilConvergence[VD: ClassTag, ED: ClassTag]( - graph: Graph[VD, ED], tol: Double, resetProb: Double = 0.15): Graph[Double, Double] = + graph: Graph[VD, ED], tol: Double, resetProb: Double = 0.15): Graph[Double, Double] = + { + runUntilConvergenceWithOptions(graph, tol, resetProb) + } + + /** + * Run a dynamic version of PageRank returning a graph with vertex attributes containing the + * PageRank and edge attributes containing the normalized edge weight. + * + * @tparam VD the original vertex attribute (not used) + * @tparam ED the original edge attribute (not used) + * + * @param graph the graph on which to compute PageRank + * @param tol the tolerance allowed at convergence (smaller => more accurate). + * @param resetProb the random reset probability (alpha) + * @param srcId the source vertex for a Personalized Page Rank (optional) + * + * @return the graph containing with each vertex containing the PageRank and each edge + * containing the normalized weight. + */ + def runUntilConvergenceWithOptions[VD: ClassTag, ED: ClassTag]( + graph: Graph[VD, ED], tol: Double, resetProb: Double = 0.15, + srcId: Option[VertexId] = None): Graph[Double, Double] = { // Initialize the pagerankGraph with each edge attribute // having weight 1/outDegree and each vertex with attribute 1.0. @@ -148,6 +205,10 @@ object PageRank extends Logging { .mapVertices( (id, attr) => (0.0, 0.0) ) .cache() + val personalized = srcId.isDefined + val src: VertexId = srcId.getOrElse(-1L) + + // Define the three functions needed to implement PageRank in the GraphX // version of Pregel def vertexProgram(id: VertexId, attr: (Double, Double), msgSum: Double): (Double, Double) = { @@ -156,7 +217,18 @@ object PageRank extends Logging { (newPR, newPR - oldPR) } - def sendMessage(edge: EdgeTriplet[(Double, Double), Double]): Iterator[(VertexId, Double)] = { + def personalizedVertexProgram(id: VertexId, attr: (Double, Double), + msgSum: Double): (Double, Double) = { + val (oldPR, lastDelta) = attr + var teleport = oldPR + val delta = if (src==id) 1.0 else 0.0 + teleport = oldPR*delta + + val newPR = teleport + (1.0 - resetProb) * msgSum + (newPR, newPR - oldPR) + } + + def sendMessage(edge: EdgeTriplet[(Double, Double), Double]) = { if (edge.srcAttr._2 > tol) { Iterator((edge.dstId, edge.srcAttr._2 * edge.attr)) } else { @@ -170,8 +242,17 @@ object PageRank extends Logging { val initialMessage = resetProb / (1.0 - resetProb) // Execute a dynamic version of Pregel. + val vp = if (personalized) { + (id: VertexId, attr: (Double, Double),msgSum: Double) => + personalizedVertexProgram(id, attr, msgSum) + } else { + (id: VertexId, attr: (Double, Double), msgSum: Double) => + vertexProgram(id, attr, msgSum) + } + Pregel(pagerankGraph, initialMessage, activeDirection = EdgeDirection.Out)( - vertexProgram, sendMessage, messageCombiner) + vp, sendMessage, messageCombiner) .mapVertices((vid, attr) => attr._1) } // end of deltaPageRank + } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala index 95804b07b1db0..3f3c9dfd7b3dd 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala @@ -92,6 +92,36 @@ class PageRankSuite extends FunSuite with LocalSparkContext { } } // end of test Star PageRank + test("Star PersonalPageRank") { + withSpark { sc => + val nVertices = 100 + val starGraph = GraphGenerators.starGraph(sc, nVertices).cache() + val resetProb = 0.15 + val errorTol = 1.0e-5 + + val staticRanks1 = starGraph.staticPersonalizedPageRank(0,numIter = 1, resetProb).vertices + val staticRanks2 = starGraph.staticPersonalizedPageRank(0,numIter = 2, resetProb) + .vertices.cache() + + // Static PageRank should only take 2 iterations to converge + val notMatching = staticRanks1.innerZipJoin(staticRanks2) { (vid, pr1, pr2) => + if (pr1 != pr2) 1 else 0 + }.map { case (vid, test) => test }.sum + assert(notMatching === 0) + + val staticErrors = staticRanks2.map { case (vid, pr) => + val correct = (vid > 0 && pr == resetProb) || + (vid == 0 && math.abs(pr - (resetProb + (1.0 - resetProb) * (resetProb * + (nVertices - 1)) )) < 1.0E-5) + if (!correct) 1 else 0 + } + assert(staticErrors.sum === 0) + + val dynamicRanks = starGraph.personalizedPageRank(0,0, resetProb).vertices.cache() + assert(compareRanks(staticRanks2, dynamicRanks) < errorTol) + } + } // end of test Star PageRank + test("Grid PageRank") { withSpark { sc => val rows = 10 @@ -128,4 +158,21 @@ class PageRankSuite extends FunSuite with LocalSparkContext { assert(compareRanks(staticRanks, dynamicRanks) < errorTol) } } + + test("Chain PersonalizedPageRank") { + withSpark { sc => + val chain1 = (0 until 9).map(x => (x, x + 1) ) + val rawEdges = sc.parallelize(chain1, 1).map { case (s,d) => (s.toLong, d.toLong) } + val chain = Graph.fromEdgeTuples(rawEdges, 1.0).cache() + val resetProb = 0.15 + val tol = 0.0001 + val numIter = 10 + val errorTol = 1.0e-1 + + val staticRanks = chain.staticPersonalizedPageRank(4, numIter, resetProb).vertices + val dynamicRanks = chain.personalizedPageRank(4, tol, resetProb).vertices + + assert(compareRanks(staticRanks, dynamicRanks) < errorTol) + } + } } From 1262e310cd294c8fd936c55c3281ed855824ea27 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Fri, 1 May 2015 19:57:37 +0100 Subject: [PATCH 02/91] [SPARK-6846] [WEBUI] [HOTFIX] return to GET for kill link in UI since YARN AM won't proxy POST Partial undoing of SPARK-6846; YARN AM proxy won't forward POSTs, so go back to GET for kill links in Spark UI. Standalone UIs are not affected. Author: Sean Owen Closes #5837 from srowen/SPARK-6846.2 and squashes the following commits: c17c386 [Sean Owen] Partial undoing of SPARK-6846; YARN AM proxy won't forward POSTs, so go back to GET for kill links in Spark UI. Standalone UIs are not affected. --- .../spark/deploy/master/ui/MasterWebUI.scala | 4 ++-- .../scala/org/apache/spark/ui/JettyUtils.scala | 16 +++++++++------- .../main/scala/org/apache/spark/ui/SparkUI.scala | 4 +++- .../org/apache/spark/ui/jobs/StageTable.scala | 7 ++++++- .../org/apache/spark/ui/UISeleniumSuite.scala | 5 +++-- 5 files changed, 23 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index aad9c87bdb987..dea0a65eeeaa6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -44,9 +44,9 @@ class MasterWebUI(val master: Master, requestedPort: Int) attachPage(masterPage) attachHandler(createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR, "/static")) attachHandler(createRedirectHandler( - "/app/kill", "/", masterPage.handleAppKillRequest, httpMethod = "POST")) + "/app/kill", "/", masterPage.handleAppKillRequest, httpMethods = Set("POST"))) attachHandler(createRedirectHandler( - "/driver/kill", "/", masterPage.handleDriverKillRequest, httpMethod = "POST")) + "/driver/kill", "/", masterPage.handleDriverKillRequest, httpMethods = Set("POST"))) } /** Attach a reconstructed UI to this Master UI. Only valid after bind(). */ diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index a091ca650c60c..dfd6fdb5e9993 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -115,19 +115,21 @@ private[spark] object JettyUtils extends Logging { destPath: String, beforeRedirect: HttpServletRequest => Unit = x => (), basePath: String = "", - httpMethod: String = "GET"): ServletContextHandler = { + httpMethods: Set[String] = Set("GET")): ServletContextHandler = { val prefixedDestPath = attachPrefix(basePath, destPath) val servlet = new HttpServlet { override def doGet(request: HttpServletRequest, response: HttpServletResponse): Unit = { - httpMethod match { - case "GET" => doRequest(request, response) - case _ => response.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED) + if (httpMethods.contains("GET")) { + doRequest(request, response) + } else { + response.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED) } } override def doPost(request: HttpServletRequest, response: HttpServletResponse): Unit = { - httpMethod match { - case "POST" => doRequest(request, response) - case _ => response.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED) + if (httpMethods.contains("POST")) { + doRequest(request, response) + } else { + response.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED) } } private def doRequest(request: HttpServletRequest, response: HttpServletResponse): Unit = { diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 580ab8b1325f8..06fce86bd38d2 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -55,8 +55,10 @@ private[spark] class SparkUI private ( attachTab(new ExecutorsTab(this)) attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) attachHandler(createRedirectHandler("/", "/jobs", basePath = basePath)) + // This should be POST only, but, the YARN AM proxy won't proxy POSTs attachHandler(createRedirectHandler( - "/stages/stage/kill", "/stages", stagesTab.handleKillRequest, httpMethod = "POST")) + "/stages/stage/kill", "/stages", stagesTab.handleKillRequest, + httpMethods = Set("GET", "POST"))) } initialize() diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 6d8c7e1fda8d8..a33243d4252bf 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -76,15 +76,20 @@ private[ui] class StageTableBase( val basePathUri = UIUtils.prependBaseUri(basePath) val killLink = if (killEnabled) { - val killLinkUri = s"$basePathUri/stages/stage/kill/" val confirm = s"if (window.confirm('Are you sure you want to kill stage ${s.stageId} ?')) " + "{ this.parentNode.submit(); return true; } else { return false; }" + // SPARK-6846 this should be POST-only but YARN AM won't proxy POST + /* + val killLinkUri = s"$basePathUri/stages/stage/kill/"
(kill)
+ */ + val killLinkUri = s"$basePathUri/stages/stage/kill/?id=${s.stageId}&terminate=true" + (kill) } val nameLinkUri = s"$basePathUri/stages/stage?id=${s.stageId}&attempt=${s.attemptId}" diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index eb9db550fd74c..d53d7f3ba5ae7 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -350,7 +350,7 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers with Before } } - test("kill stage is POST only") { + test("kill stage POST/GET response is correct") { def getResponseCode(url: URL, method: String): Int = { val connection = url.openConnection().asInstanceOf[HttpURLConnection] connection.setRequestMethod(method) @@ -365,7 +365,8 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers with Before eventually(timeout(5 seconds), interval(50 milliseconds)) { val url = new URL( sc.ui.get.appUIAddress.stripSuffix("/") + "/stages/stage/kill/?id=0&terminate=true") - getResponseCode(url, "GET") should be (405) + // SPARK-6846: should be POST only but YARN AM doesn't proxy POST + getResponseCode(url, "GET") should be (200) getResponseCode(url, "POST") should be (200) } } From 16860327286bc08b4e2283d51b4c8fe024ba5006 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 1 May 2015 11:59:12 -0700 Subject: [PATCH 03/91] [SPARK-7183] [NETWORK] Fix memory leak of TransportRequestHandler.streamIds JIRA: https://issues.apache.org/jira/browse/SPARK-7183 Author: Liang-Chi Hsieh Closes #5743 from viirya/fix_requesthandler_memory_leak and squashes the following commits: cf2c086 [Liang-Chi Hsieh] For comments. 97e205c [Liang-Chi Hsieh] Remove unused import. d35f19a [Liang-Chi Hsieh] For comments. f9a0c37 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into fix_requesthandler_memory_leak 45908b7 [Liang-Chi Hsieh] for style. 17f020f [Liang-Chi Hsieh] Remove unused import. 37a4b6c [Liang-Chi Hsieh] Remove streamIds from TransportRequestHandler. 3b3f38a [Liang-Chi Hsieh] Fix memory leak of TransportRequestHandler.streamIds. --- .../server/OneForOneStreamManager.java | 35 ++++++++++++++----- .../spark/network/server/StreamManager.java | 19 +++++++--- .../server/TransportRequestHandler.java | 14 ++------ 3 files changed, 44 insertions(+), 24 deletions(-) diff --git a/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java index a6d390e13f396..c95e64e8e2cda 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java +++ b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java @@ -20,14 +20,18 @@ import java.util.Iterator; import java.util.Map; import java.util.Random; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicLong; +import io.netty.channel.Channel; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.network.buffer.ManagedBuffer; +import com.google.common.base.Preconditions; + /** * StreamManager which allows registration of an Iterator<ManagedBuffer>, which are individually * fetched as chunks by the client. Each registered buffer is one chunk. @@ -36,18 +40,21 @@ public class OneForOneStreamManager extends StreamManager { private final Logger logger = LoggerFactory.getLogger(OneForOneStreamManager.class); private final AtomicLong nextStreamId; - private final Map streams; + private final ConcurrentHashMap streams; /** State of a single stream. */ private static class StreamState { final Iterator buffers; + // The channel associated to the stream + Channel associatedChannel = null; + // Used to keep track of the index of the buffer that the user has retrieved, just to ensure // that the caller only requests each chunk one at a time, in order. int curChunk = 0; StreamState(Iterator buffers) { - this.buffers = buffers; + this.buffers = Preconditions.checkNotNull(buffers); } } @@ -58,6 +65,13 @@ public OneForOneStreamManager() { streams = new ConcurrentHashMap(); } + @Override + public void registerChannel(Channel channel, long streamId) { + if (streams.containsKey(streamId)) { + streams.get(streamId).associatedChannel = channel; + } + } + @Override public ManagedBuffer getChunk(long streamId, int chunkIndex) { StreamState state = streams.get(streamId); @@ -80,12 +94,17 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex) { } @Override - public void connectionTerminated(long streamId) { - // Release all remaining buffers. - StreamState state = streams.remove(streamId); - if (state != null && state.buffers != null) { - while (state.buffers.hasNext()) { - state.buffers.next().release(); + public void connectionTerminated(Channel channel) { + // Close all streams which have been associated with the channel. + for (Map.Entry entry: streams.entrySet()) { + StreamState state = entry.getValue(); + if (state.associatedChannel == channel) { + streams.remove(entry.getKey()); + + // Release all remaining buffers. + while (state.buffers.hasNext()) { + state.buffers.next().release(); + } } } } diff --git a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java index 5a9a14a180c10..929f789bf9d24 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java +++ b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java @@ -17,6 +17,8 @@ package org.apache.spark.network.server; +import io.netty.channel.Channel; + import org.apache.spark.network.buffer.ManagedBuffer; /** @@ -44,9 +46,18 @@ public abstract class StreamManager { public abstract ManagedBuffer getChunk(long streamId, int chunkIndex); /** - * Indicates that the TCP connection that was tied to the given stream has been terminated. After - * this occurs, we are guaranteed not to read from the stream again, so any state can be cleaned - * up. + * Associates a stream with a single client connection, which is guaranteed to be the only reader + * of the stream. The getChunk() method will be called serially on this connection and once the + * connection is closed, the stream will never be used again, enabling cleanup. + * + * This must be called before the first getChunk() on the stream, but it may be invoked multiple + * times with the same channel and stream id. + */ + public void registerChannel(Channel channel, long streamId) { } + + /** + * Indicates that the given channel has been terminated. After this occurs, we are guaranteed not + * to read from the associated streams again, so any state can be cleaned up. */ - public void connectionTerminated(long streamId) { } + public void connectionTerminated(Channel channel) { } } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index 1580180cc17e9..e5159ab56d0d4 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -17,10 +17,7 @@ package org.apache.spark.network.server; -import java.util.Set; - import com.google.common.base.Throwables; -import com.google.common.collect.Sets; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; @@ -62,9 +59,6 @@ public class TransportRequestHandler extends MessageHandler { /** Returns each chunk part of a stream. */ private final StreamManager streamManager; - /** List of all stream ids that have been read on this handler, used for cleanup. */ - private final Set streamIds; - public TransportRequestHandler( Channel channel, TransportClient reverseClient, @@ -73,7 +67,6 @@ public TransportRequestHandler( this.reverseClient = reverseClient; this.rpcHandler = rpcHandler; this.streamManager = rpcHandler.getStreamManager(); - this.streamIds = Sets.newHashSet(); } @Override @@ -82,10 +75,7 @@ public void exceptionCaught(Throwable cause) { @Override public void channelUnregistered() { - // Inform the StreamManager that these streams will no longer be read from. - for (long streamId : streamIds) { - streamManager.connectionTerminated(streamId); - } + streamManager.connectionTerminated(channel); rpcHandler.connectionTerminated(reverseClient); } @@ -102,12 +92,12 @@ public void handle(RequestMessage request) { private void processFetchRequest(final ChunkFetchRequest req) { final String client = NettyUtils.getRemoteAddress(channel); - streamIds.add(req.streamChunkId.streamId); logger.trace("Received req from {} to fetch block {}", client, req.streamChunkId); ManagedBuffer buf; try { + streamManager.registerChannel(channel, req.streamChunkId.streamId); buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex); } catch (Exception e) { logger.error(String.format( From 37537760d19eab878a5e1a48641cc49e6cb4b989 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 1 May 2015 12:49:02 -0700 Subject: [PATCH 04/91] [SPARK-7274] [SQL] Create Column expression for array/struct creation. Author: Reynold Xin Closes #5802 from rxin/SPARK-7274 and squashes the following commits: 19aecaa [Reynold Xin] Fixed unicode tests. bfc1538 [Reynold Xin] Export all Python functions. 2517b8c [Reynold Xin] Code review. 23da335 [Reynold Xin] Fixed Python bug. 132002e [Reynold Xin] Fixed tests. 56fce26 [Reynold Xin] Added Python support. b0d591a [Reynold Xin] Fixed debug error. 86926a6 [Reynold Xin] Added test suite. 7dbb9ab [Reynold Xin] Ok one more. 470e2f5 [Reynold Xin] One more MLlib ... e2d14f0 [Reynold Xin] [SPARK-7274][SQL] Create Column expression for array/struct creation. --- .../spark/ml/feature/VectorAssembler.scala | 13 ++- python/pyspark/sql/functions.py | 80 +++++++++++++----- .../catalyst/expressions/BoundAttribute.scala | 10 ++- .../org/apache/spark/sql/functions.scala | 41 ++++++++- .../spark/sql/DataFrameFunctionsSuite.scala | 84 +++++++++++++++++++ 5 files changed, 199 insertions(+), 29 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 7b2a451ca5ee5..5e781a326d98c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -25,9 +25,7 @@ import org.apache.spark.ml.Transformer import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} -import org.apache.spark.sql.{Column, DataFrame, Row} -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, CreateStruct} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -53,13 +51,12 @@ class VectorAssembler extends Transformer with HasInputCols with HasOutputCol { val inputColNames = map(inputCols) val args = inputColNames.map { c => schema(c).dataType match { - case DoubleType => UnresolvedAttribute(c) - case t if t.isInstanceOf[VectorUDT] => UnresolvedAttribute(c) - case _: NumericType | BooleanType => - Alias(Cast(UnresolvedAttribute(c), DoubleType), s"${c}_double_$uid")() + case DoubleType => dataset(c) + case _: VectorUDT => dataset(c) + case _: NumericType | BooleanType => dataset(c).cast(DoubleType).as(s"${c}_double_$uid") } } - dataset.select(col("*"), assembleFunc(new Column(CreateStruct(args))).as(map(outputCol))) + dataset.select(col("*"), assembleFunc(struct(args : _*)).as(map(outputCol))) } override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 241f82175726f..641220a264295 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -24,13 +24,20 @@ from itertools import imap as map from pyspark import SparkContext -from pyspark.rdd import _prepare_for_python_RDD +from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix from pyspark.serializers import PickleSerializer, AutoBatchedSerializer from pyspark.sql.types import StringType from pyspark.sql.dataframe import Column, _to_java_column, _to_seq -__all__ = ['countDistinct', 'approxCountDistinct', 'udf'] +__all__ = [ + 'approxCountDistinct', + 'countDistinct', + 'monotonicallyIncreasingId', + 'rand', + 'randn', + 'sparkPartitionId', + 'udf'] def _create_function(name, doc=""): @@ -74,27 +81,21 @@ def _(col): __all__.sort() -def rand(seed=None): - """ - Generate a random column with i.i.d. samples from U[0.0, 1.0]. - """ - sc = SparkContext._active_spark_context - if seed: - jc = sc._jvm.functions.rand(seed) - else: - jc = sc._jvm.functions.rand() - return Column(jc) +def array(*cols): + """Creates a new array column. + :param cols: list of column names (string) or list of :class:`Column` expressions that have + the same data type. -def randn(seed=None): - """ - Generate a column with i.i.d. samples from the standard normal distribution. + >>> df.select(array('age', 'age').alias("arr")).collect() + [Row(arr=[2, 2]), Row(arr=[5, 5])] + >>> df.select(array([df.age, df.age]).alias("arr")).collect() + [Row(arr=[2, 2]), Row(arr=[5, 5])] """ sc = SparkContext._active_spark_context - if seed: - jc = sc._jvm.functions.randn(seed) - else: - jc = sc._jvm.functions.randn() + if len(cols) == 1 and isinstance(cols[0], (list, set)): + cols = cols[0] + jc = sc._jvm.functions.array(_to_seq(sc, cols, _to_java_column)) return Column(jc) @@ -146,6 +147,28 @@ def monotonicallyIncreasingId(): return Column(sc._jvm.functions.monotonicallyIncreasingId()) +def rand(seed=None): + """Generates a random column with i.i.d. samples from U[0.0, 1.0]. + """ + sc = SparkContext._active_spark_context + if seed: + jc = sc._jvm.functions.rand(seed) + else: + jc = sc._jvm.functions.rand() + return Column(jc) + + +def randn(seed=None): + """Generates a column with i.i.d. samples from the standard normal distribution. + """ + sc = SparkContext._active_spark_context + if seed: + jc = sc._jvm.functions.randn(seed) + else: + jc = sc._jvm.functions.randn() + return Column(jc) + + def sparkPartitionId(): """A column for partition ID of the Spark task. @@ -158,6 +181,25 @@ def sparkPartitionId(): return Column(sc._jvm.functions.sparkPartitionId()) +@ignore_unicode_prefix +def struct(*cols): + """Creates a new struct column. + + :param cols: list of column names (string) or list of :class:`Column` expressions + that are named or aliased. + + >>> df.select(struct('age', 'name').alias("struct")).collect() + [Row(struct=Row(age=2, name=u'Alice')), Row(struct=Row(age=5, name=u'Bob'))] + >>> df.select(struct([df.age, df.name]).alias("struct")).collect() + [Row(struct=Row(age=2, name=u'Alice')), Row(struct=Row(age=5, name=u'Bob'))] + """ + sc = SparkContext._active_spark_context + if len(cols) == 1 and isinstance(cols[0], (list, set)): + cols = cols[0] + jc = sc._jvm.functions.struct(_to_seq(sc, cols, _to_java_column)) + return Column(jc) + + class UserDefinedFunction(object): """ User defined function in Python diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 2225621dbaabd..c6217f07c452d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -28,13 +28,21 @@ import org.apache.spark.sql.catalyst.trees * the layout of intermediate tuples, BindReferences should be run after all such transformations. */ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) - extends Expression with trees.LeafNode[Expression] { + extends NamedExpression with trees.LeafNode[Expression] { type EvaluatedType = Any override def toString: String = s"input[$ordinal]" override def eval(input: Row): Any = input(ordinal) + + override def name: String = s"i[$ordinal]" + + override def toAttribute: Attribute = throw new UnsupportedOperationException + + override def qualifiers: Seq[String] = throw new UnsupportedOperationException + + override def exprId: ExprId = throw new UnsupportedOperationException } object BindReferences extends Logging { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 242e64d3ff881..7e283393d0563 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -22,7 +22,7 @@ import scala.reflect.runtime.universe.{TypeTag, typeTag} import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, Star} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -283,6 +283,23 @@ object functions { */ def abs(e: Column): Column = Abs(e.expr) + /** + * Creates a new array column. The input columns must all have the same data type. + * + * @group normal_funcs + */ + @scala.annotation.varargs + def array(cols: Column*): Column = CreateArray(cols.map(_.expr)) + + /** + * Creates a new array column. The input columns must all have the same data type. + * + * @group normal_funcs + */ + def array(colName: String, colNames: String*): Column = { + array((colName +: colNames).map(col) : _*) + } + /** * Returns the first column that is not null. * {{{ @@ -390,6 +407,28 @@ object functions { */ def sqrt(e: Column): Column = Sqrt(e.expr) + /** + * Creates a new struct column. The input column must be a column in a [[DataFrame]], or + * a derived column expression that is named (i.e. aliased). + * + * @group normal_funcs + */ + @scala.annotation.varargs + def struct(cols: Column*): Column = { + require(cols.forall(_.expr.isInstanceOf[NamedExpression]), + s"struct input columns must all be named or aliased ($cols)") + CreateStruct(cols.map(_.expr.asInstanceOf[NamedExpression])) + } + + /** + * Creates a new struct column that composes multiple input columns. + * + * @group normal_funcs + */ + def struct(colName: String, colNames: String*): Column = { + struct((colName +: colNames).map(col) : _*) + } + /** * Converts a string expression to upper case. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala new file mode 100644 index 0000000000000..ca03713ef4658 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.TestSQLContext.implicits._ +import org.apache.spark.sql.types._ + +/** + * Test suite for functions in [[org.apache.spark.sql.functions]]. + */ +class DataFrameFunctionsSuite extends QueryTest { + + test("array with column name") { + val df = Seq((0, 1)).toDF("a", "b") + val row = df.select(array("a", "b")).first() + + val expectedType = ArrayType(IntegerType, containsNull = false) + assert(row.schema(0).dataType === expectedType) + assert(row.getAs[Seq[Int]](0) === Seq(0, 1)) + } + + test("array with column expression") { + val df = Seq((0, 1)).toDF("a", "b") + val row = df.select(array(col("a"), col("b") + col("b"))).first() + + val expectedType = ArrayType(IntegerType, containsNull = false) + assert(row.schema(0).dataType === expectedType) + assert(row.getAs[Seq[Int]](0) === Seq(0, 2)) + } + + // Turn this on once we add a rule to the analyzer to throw a friendly exception + ignore("array: throw exception if putting columns of different types into an array") { + val df = Seq((0, "str")).toDF("a", "b") + intercept[AnalysisException] { + df.select(array("a", "b")) + } + } + + test("struct with column name") { + val df = Seq((1, "str")).toDF("a", "b") + val row = df.select(struct("a", "b")).first() + + val expectedType = StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", StringType) + )) + assert(row.schema(0).dataType === expectedType) + assert(row.getAs[Row](0) === Row(1, "str")) + } + + test("struct with column expression") { + val df = Seq((1, "str")).toDF("a", "b") + val row = df.select(struct((col("a") * 2).as("c"), col("b"))).first() + + val expectedType = StructType(Seq( + StructField("c", IntegerType, nullable = false), + StructField("b", StringType) + )) + assert(row.schema(0).dataType === expectedType) + assert(row.getAs[Row](0) === Row(2, "str")) + } + + test("struct: must use named column expression") { + intercept[IllegalArgumentException] { + struct(col("a") * 2) + } + } +} From 58d6584d349d5208a994a074b4cfa8a6ec4d1665 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Fri, 1 May 2015 13:01:14 -0700 Subject: [PATCH 05/91] Revert "[SPARK-7287] enabled fixed test" This reverts commit 7cf1eb79b1fa290aa1d867a8a1eaaea86d6b2239. --- .../test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 35382be7e0ef1..8360b94599547 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -335,7 +335,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties runSparkSubmit(args) } - test("includes jars passed in through --packages") { + ignore("includes jars passed in through --packages") { val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val main = MavenCoordinate("my.great.lib", "mylib", "0.1") val dep = MavenCoordinate("my.great.dep", "mylib", "0.1") From c6d9a429421561508e8adbb4892954381bc33a90 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Fri, 1 May 2015 13:01:43 -0700 Subject: [PATCH 06/91] Revert "[SPARK-7224] added mock repository generator for --packages tests" This reverts commit 7dacc08ab36188991a001df23880167433844767. --- .../scala/org/apache/spark/TestUtils.scala | 27 +- .../org/apache/spark/deploy/SparkSubmit.scala | 129 ++++----- .../apache/spark/deploy/IvyTestUtils.scala | 262 ------------------ .../spark/deploy/SparkSubmitSuite.scala | 25 +- .../spark/deploy/SparkSubmitUtilsSuite.scala | 57 ++-- 5 files changed, 97 insertions(+), 403 deletions(-) delete mode 100644 core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index fe6320b504e15..398ca41e16151 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -105,18 +105,23 @@ private[spark] object TestUtils { URI.create(s"string:///${name.replace(".", "/")}${SOURCE.extension}") } - private[spark] class JavaSourceFromString(val name: String, val code: String) + private class JavaSourceFromString(val name: String, val code: String) extends SimpleJavaFileObject(createURI(name), SOURCE) { override def getCharContent(ignoreEncodingErrors: Boolean): String = code } - /** Creates a compiled class with the source file. Class file will be placed in destDir. */ + /** Creates a compiled class with the given name. Class file will be placed in destDir. */ def createCompiledClass( className: String, destDir: File, - sourceFile: JavaSourceFromString, - classpathUrls: Seq[URL]): File = { + toStringValue: String = "", + baseClass: String = null, + classpathUrls: Seq[URL] = Seq()): File = { val compiler = ToolProvider.getSystemJavaCompiler + val extendsText = Option(baseClass).map { c => s" extends ${c}" }.getOrElse("") + val sourceFile = new JavaSourceFromString(className, + "public class " + className + extendsText + " implements java.io.Serializable {" + + " @Override public String toString() { return \"" + toStringValue + "\"; }}") // Calling this outputs a class file in pwd. It's easier to just rename the file than // build a custom FileManager that controls the output location. @@ -139,18 +144,4 @@ private[spark] object TestUtils { assert(out.exists(), "Destination file not moved: " + out.getAbsolutePath()) out } - - /** Creates a compiled class with the given name. Class file will be placed in destDir. */ - def createCompiledClass( - className: String, - destDir: File, - toStringValue: String = "", - baseClass: String = null, - classpathUrls: Seq[URL] = Seq()): File = { - val extendsText = Option(baseClass).map { c => s" extends ${c}" }.getOrElse("") - val sourceFile = new JavaSourceFromString(className, - "public class " + className + extendsText + " implements java.io.Serializable {" + - " @Override public String toString() { return \"" + toStringValue + "\"; }}") - createCompiledClass(className, destDir, sourceFile, classpathUrls) - } } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 0d149e703aff2..b8ae4af18d1d1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -20,7 +20,6 @@ package org.apache.spark.deploy import java.io.{File, PrintStream} import java.lang.reflect.{InvocationTargetException, Modifier, UndeclaredThrowableException} import java.net.URL -import java.nio.file.{Path => JavaPath} import java.security.PrivilegedExceptionAction import scala.collection.mutable.{ArrayBuffer, HashMap, Map} @@ -709,9 +708,7 @@ private[deploy] object SparkSubmitUtils { * @param artifactId the artifactId of the coordinate * @param version the version of the coordinate */ - private[deploy] case class MavenCoordinate(groupId: String, artifactId: String, version: String) { - override def toString: String = s"$groupId:$artifactId:$version" - } + private[deploy] case class MavenCoordinate(groupId: String, artifactId: String, version: String) /** * Extracts maven coordinates from a comma-delimited string. Coordinates should be provided @@ -734,10 +731,6 @@ private[deploy] object SparkSubmitUtils { } } - /** Path of the local Maven cache. */ - private[spark] def m2Path: JavaPath = new File(System.getProperty("user.home"), - ".m2" + File.separator + "repository" + File.separator).toPath - /** * Extracts maven coordinates from a comma-delimited string * @param remoteRepos Comma-delimited string of remote repositories @@ -751,7 +744,8 @@ private[deploy] object SparkSubmitUtils { val localM2 = new IBiblioResolver localM2.setM2compatible(true) - localM2.setRoot(m2Path.toUri.toString) + val m2Path = ".m2" + File.separator + "repository" + File.separator + localM2.setRoot(new File(System.getProperty("user.home"), m2Path).toURI.toString) localM2.setUsepoms(true) localM2.setName("local-m2-cache") cr.add(localM2) @@ -876,72 +870,69 @@ private[deploy] object SparkSubmitUtils { "" } else { val sysOut = System.out - try { - // To prevent ivy from logging to system out - System.setOut(printStream) - val artifacts = extractMavenCoordinates(coordinates) - // Default configuration name for ivy - val ivyConfName = "default" - // set ivy settings for location of cache - val ivySettings: IvySettings = new IvySettings - // Directories for caching downloads through ivy and storing the jars when maven coordinates - // are supplied to spark-submit - val alternateIvyCache = ivyPath.getOrElse("") - val packagesDirectory: File = - if (alternateIvyCache.trim.isEmpty) { - new File(ivySettings.getDefaultIvyUserDir, "jars") - } else { - ivySettings.setDefaultIvyUserDir(new File(alternateIvyCache)) - ivySettings.setDefaultCache(new File(alternateIvyCache, "cache")) - new File(alternateIvyCache, "jars") - } - printStream.println( - s"Ivy Default Cache set to: ${ivySettings.getDefaultCache.getAbsolutePath}") - printStream.println(s"The jars for the packages stored in: $packagesDirectory") - // create a pattern matcher - ivySettings.addMatcher(new GlobPatternMatcher) - // create the dependency resolvers - val repoResolver = createRepoResolvers(remoteRepos, ivySettings) - ivySettings.addResolver(repoResolver) - ivySettings.setDefaultResolver(repoResolver.getName) - - val ivy = Ivy.newInstance(ivySettings) - // Set resolve options to download transitive dependencies as well - val resolveOptions = new ResolveOptions - resolveOptions.setTransitive(true) - val retrieveOptions = new RetrieveOptions - // Turn downloading and logging off for testing - if (isTest) { - resolveOptions.setDownload(false) - resolveOptions.setLog(LogOptions.LOG_QUIET) - retrieveOptions.setLog(LogOptions.LOG_QUIET) + // To prevent ivy from logging to system out + System.setOut(printStream) + val artifacts = extractMavenCoordinates(coordinates) + // Default configuration name for ivy + val ivyConfName = "default" + // set ivy settings for location of cache + val ivySettings: IvySettings = new IvySettings + // Directories for caching downloads through ivy and storing the jars when maven coordinates + // are supplied to spark-submit + val alternateIvyCache = ivyPath.getOrElse("") + val packagesDirectory: File = + if (alternateIvyCache.trim.isEmpty) { + new File(ivySettings.getDefaultIvyUserDir, "jars") } else { - resolveOptions.setDownload(true) + ivySettings.setDefaultIvyUserDir(new File(alternateIvyCache)) + ivySettings.setDefaultCache(new File(alternateIvyCache, "cache")) + new File(alternateIvyCache, "jars") } + printStream.println( + s"Ivy Default Cache set to: ${ivySettings.getDefaultCache.getAbsolutePath}") + printStream.println(s"The jars for the packages stored in: $packagesDirectory") + // create a pattern matcher + ivySettings.addMatcher(new GlobPatternMatcher) + // create the dependency resolvers + val repoResolver = createRepoResolvers(remoteRepos, ivySettings) + ivySettings.addResolver(repoResolver) + ivySettings.setDefaultResolver(repoResolver.getName) + + val ivy = Ivy.newInstance(ivySettings) + // Set resolve options to download transitive dependencies as well + val resolveOptions = new ResolveOptions + resolveOptions.setTransitive(true) + val retrieveOptions = new RetrieveOptions + // Turn downloading and logging off for testing + if (isTest) { + resolveOptions.setDownload(false) + resolveOptions.setLog(LogOptions.LOG_QUIET) + retrieveOptions.setLog(LogOptions.LOG_QUIET) + } else { + resolveOptions.setDownload(true) + } - // A Module descriptor must be specified. Entries are dummy strings - val md = getModuleDescriptor - md.setDefaultConf(ivyConfName) + // A Module descriptor must be specified. Entries are dummy strings + val md = getModuleDescriptor + md.setDefaultConf(ivyConfName) - // Add exclusion rules for Spark and Scala Library - addExclusionRules(ivySettings, ivyConfName, md) - // add all supplied maven artifacts as dependencies - addDependenciesToIvy(md, artifacts, ivyConfName) + // Add exclusion rules for Spark and Scala Library + addExclusionRules(ivySettings, ivyConfName, md) + // add all supplied maven artifacts as dependencies + addDependenciesToIvy(md, artifacts, ivyConfName) - // resolve dependencies - val rr: ResolveReport = ivy.resolve(md, resolveOptions) - if (rr.hasError) { - throw new RuntimeException(rr.getAllProblemMessages.toString) - } - // retrieve all resolved dependencies - ivy.retrieve(rr.getModuleDescriptor.getModuleRevisionId, - packagesDirectory.getAbsolutePath + File.separator + - "[organization]_[artifact]-[revision].[ext]", - retrieveOptions.setConfs(Array(ivyConfName))) - resolveDependencyPaths(rr.getArtifacts.toArray, packagesDirectory) - } finally { - System.setOut(sysOut) + // resolve dependencies + val rr: ResolveReport = ivy.resolve(md, resolveOptions) + if (rr.hasError) { + throw new RuntimeException(rr.getAllProblemMessages.toString) } + // retrieve all resolved dependencies + ivy.retrieve(rr.getModuleDescriptor.getModuleRevisionId, + packagesDirectory.getAbsolutePath + File.separator + + "[organization]_[artifact]-[revision].[ext]", + retrieveOptions.setConfs(Array(ivyConfName))) + System.setOut(sysOut) + resolveDependencyPaths(rr.getArtifacts.toArray, packagesDirectory) } } } diff --git a/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala b/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala deleted file mode 100644 index 529f91e8eaf9e..0000000000000 --- a/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala +++ /dev/null @@ -1,262 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy - -import java.io.{File, FileInputStream, FileOutputStream} -import java.nio.file.{Files, Path} -import java.util.jar.{JarEntry, JarOutputStream} - -import org.apache.spark.TestUtils.{createCompiledClass, JavaSourceFromString} - -import com.google.common.io.ByteStreams - -import org.apache.commons.io.FileUtils - -import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate - -private[deploy] object IvyTestUtils { - - /** - * Create the path for the jar and pom from the maven coordinate. Extension should be `jar` - * or `pom`. - */ - private def pathFromCoordinate( - artifact: MavenCoordinate, - prefix: Path, - ext: String, - useIvyLayout: Boolean): Path = { - val groupDirs = artifact.groupId.replace(".", File.separator) - val artifactDirs = artifact.artifactId - val artifactPath = - if (!useIvyLayout) { - Seq(groupDirs, artifactDirs, artifact.version).mkString(File.separator) - } else { - Seq(groupDirs, artifactDirs, artifact.version, ext + "s").mkString(File.separator) - } - new File(prefix.toFile, artifactPath).toPath - } - - private def artifactName(artifact: MavenCoordinate, ext: String = ".jar"): String = { - s"${artifact.artifactId}-${artifact.version}$ext" - } - - /** Write the contents to a file to the supplied directory. */ - private def writeFile(dir: File, fileName: String, contents: String): File = { - val outputFile = new File(dir, fileName) - val outputStream = new FileOutputStream(outputFile) - outputStream.write(contents.toCharArray.map(_.toByte)) - outputStream.close() - outputFile - } - - /** Create an example Python file. */ - private def createPythonFile(dir: File): File = { - val contents = - """def myfunc(x): - | return x + 1 - """.stripMargin - writeFile(dir, "mylib.py", contents) - } - - /** Create a simple testable Class. */ - private def createJavaClass(dir: File, className: String, packageName: String): File = { - val contents = - s"""package $packageName; - | - |import java.lang.Integer; - | - |class $className implements java.io.Serializable { - | - | public $className() {} - | - | public Integer myFunc(Integer x) { - | return x + 1; - | } - |} - """.stripMargin - val sourceFile = - new JavaSourceFromString(new File(dir, className + ".java").getAbsolutePath, contents) - createCompiledClass(className, dir, sourceFile, Seq.empty) - } - - /** Helper method to write artifact information in the pom. */ - private def pomArtifactWriter(artifact: MavenCoordinate, tabCount: Int = 1): String = { - var result = "\n" + " " * tabCount + s"${artifact.groupId}" - result += "\n" + " " * tabCount + s"${artifact.artifactId}" - result += "\n" + " " * tabCount + s"${artifact.version}" - result - } - - /** Create a pom file for this artifact. */ - private def createPom( - dir: File, - artifact: MavenCoordinate, - dependencies: Option[Seq[MavenCoordinate]]): File = { - var content = """ - | - | - | 4.0.0 - """.stripMargin.trim - content += pomArtifactWriter(artifact) - content += dependencies.map { deps => - val inside = deps.map { dep => - "\t" + pomArtifactWriter(dep, 3) + "\n\t" - }.mkString("\n") - "\n \n" + inside + "\n " - }.getOrElse("") - content += "\n" - writeFile(dir, artifactName(artifact, ".pom"), content.trim) - } - - /** Create the jar for the given maven coordinate, using the supplied files. */ - private def packJar( - dir: File, - artifact: MavenCoordinate, - files: Seq[(String, File)]): File = { - val jarFile = new File(dir, artifactName(artifact)) - val jarFileStream = new FileOutputStream(jarFile) - val jarStream = new JarOutputStream(jarFileStream, new java.util.jar.Manifest()) - - for (file <- files) { - val jarEntry = new JarEntry(file._1) - jarStream.putNextEntry(jarEntry) - - val in = new FileInputStream(file._2) - ByteStreams.copy(in, jarStream) - in.close() - } - jarStream.close() - jarFileStream.close() - - jarFile - } - - /** - * Creates a jar and pom file, mocking a Maven repository. The root path can be supplied with - * `tempDir`, dependencies can be created into the same repo, and python files can also be packed - * inside the jar. - * - * @param artifact The maven coordinate to generate the jar and pom for. - * @param dependencies List of dependencies this artifact might have to also create jars and poms. - * @param tempDir The root folder of the repository - * @param useIvyLayout whether to mock the Ivy layout for local repository testing - * @param withPython Whether to pack python files inside the jar for extensive testing. - * @return Root path of the repository - */ - private def createLocalRepository( - artifact: MavenCoordinate, - dependencies: Option[Seq[MavenCoordinate]] = None, - tempDir: Option[Path] = None, - useIvyLayout: Boolean = false, - withPython: Boolean = false): Path = { - // Where the root of the repository exists, and what Ivy will search in - val tempPath = tempDir.getOrElse(Files.createTempDirectory(null)) - // Create directory if it doesn't exist - Files.createDirectories(tempPath) - // Where to create temporary class files and such - val root = Files.createTempDirectory(tempPath, null).toFile - try { - val jarPath = pathFromCoordinate(artifact, tempPath, "jar", useIvyLayout) - Files.createDirectories(jarPath) - val className = "MyLib" - - val javaClass = createJavaClass(root, className, artifact.groupId) - // A tuple of files representation in the jar, and the file - val javaFile = (artifact.groupId.replace(".", "/") + "/" + javaClass.getName, javaClass) - val allFiles = - if (withPython) { - val pythonFile = createPythonFile(root) - Seq(javaFile, (pythonFile.getName, pythonFile)) - } else { - Seq(javaFile) - } - val jarFile = packJar(jarPath.toFile, artifact, allFiles) - assert(jarFile.exists(), "Problem creating Jar file") - val pomPath = pathFromCoordinate(artifact, tempPath, "pom", useIvyLayout) - Files.createDirectories(pomPath) - val pomFile = createPom(pomPath.toFile, artifact, dependencies) - assert(pomFile.exists(), "Problem creating Pom file") - } finally { - FileUtils.deleteDirectory(root) - } - tempPath - } - - /** - * Creates a suite of jars and poms, with or without dependencies, mocking a maven repository. - * @param artifact The main maven coordinate to generate the jar and pom for. - * @param dependencies List of dependencies this artifact might have to also create jars and poms. - * @param rootDir The root folder of the repository (like `~/.m2/repositories`) - * @param useIvyLayout whether to mock the Ivy layout for local repository testing - * @param withPython Whether to pack python files inside the jar for extensive testing. - * @return Root path of the repository. Will be `rootDir` if supplied. - */ - private[deploy] def createLocalRepositoryForTests( - artifact: MavenCoordinate, - dependencies: Option[String], - rootDir: Option[Path], - useIvyLayout: Boolean = false, - withPython: Boolean = false): Path = { - val deps = dependencies.map(SparkSubmitUtils.extractMavenCoordinates) - val mainRepo = createLocalRepository(artifact, deps, rootDir, useIvyLayout, withPython) - deps.foreach { seq => seq.foreach { dep => - createLocalRepository(dep, None, Some(mainRepo), useIvyLayout, withPython = false) - }} - mainRepo - } - - /** - * Creates a repository for a test, and cleans it up afterwards. - * - * @param artifact The main maven coordinate to generate the jar and pom for. - * @param dependencies List of dependencies this artifact might have to also create jars and poms. - * @param rootDir The root folder of the repository (like `~/.m2/repositories`) - * @param useIvyLayout whether to mock the Ivy layout for local repository testing - * @param withPython Whether to pack python files inside the jar for extensive testing. - * @return Root path of the repository. Will be `rootDir` if supplied. - */ - private[deploy] def withRepository( - artifact: MavenCoordinate, - dependencies: Option[String], - rootDir: Option[Path], - useIvyLayout: Boolean = false, - withPython: Boolean = false)(f: String => Unit): Unit = { - val repo = createLocalRepositoryForTests(artifact, dependencies, rootDir, useIvyLayout, - withPython) - try { - f(repo.toUri.toString) - } finally { - // Clean up - if (repo.toString.contains(".m2") || repo.toString.contains(".ivy2")) { - FileUtils.deleteDirectory(new File(repo.toFile, - artifact.groupId.replace(".", File.separator) + File.separator + artifact.artifactId)) - dependencies.map(SparkSubmitUtils.extractMavenCoordinates).foreach { seq => - seq.foreach { dep => - FileUtils.deleteDirectory(new File(repo.toFile, - dep.artifactId.replace(".", File.separator))) - } - } - } else { - FileUtils.deleteDirectory(repo.toFile) - } - } - } -} diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 8360b94599547..029a1156fda6b 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -30,7 +30,6 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark._ import org.apache.spark.deploy.SparkSubmit._ -import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate import org.apache.spark.util.{ResetSystemProperties, Utils} // Note: this suite mixes in ResetSystemProperties because SparkSubmit.main() sets a bunch @@ -337,20 +336,16 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties ignore("includes jars passed in through --packages") { val unusedJar = TestUtils.createJarWithClasses(Seq.empty) - val main = MavenCoordinate("my.great.lib", "mylib", "0.1") - val dep = MavenCoordinate("my.great.dep", "mylib", "0.1") - IvyTestUtils.withRepository(main, Some(dep.toString), None) { repo => - val args = Seq( - "--class", JarCreationTest.getClass.getName.stripSuffix("$"), - "--name", "testApp", - "--master", "local-cluster[2,1,512]", - "--packages", Seq(main, dep).mkString(","), - "--repositories", repo, - "--conf", "spark.ui.enabled=false", - unusedJar.toString, - "my.great.lib.MyLib", "my.great.dep.MyLib") - runSparkSubmit(args) - } + val packagesString = "com.databricks:spark-csv_2.10:0.1,com.databricks:spark-avro_2.10:0.1" + val args = Seq( + "--class", JarCreationTest.getClass.getName.stripSuffix("$"), + "--name", "testApp", + "--master", "local-cluster[2,1,512]", + "--packages", packagesString, + "--conf", "spark.ui.enabled=false", + unusedJar.toString, + "com.databricks.spark.csv.DefaultSource", "com.databricks.spark.avro.DefaultSource") + runSparkSubmit(args) } test("resolves command line argument paths correctly") { diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala index cc79ee7ea20b4..1b2b699cb11e6 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala @@ -19,15 +19,15 @@ package org.apache.spark.deploy import java.io.{PrintStream, OutputStream, File} +import org.apache.ivy.core.settings.IvySettings + import scala.collection.mutable.ArrayBuffer + import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.apache.ivy.core.module.descriptor.MDArtifact -import org.apache.ivy.core.settings.IvySettings import org.apache.ivy.plugins.resolver.IBiblioResolver -import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate - class SparkSubmitUtilsSuite extends FunSuite with BeforeAndAfterAll { private val noOpOutputStream = new OutputStream { @@ -89,7 +89,7 @@ class SparkSubmitUtilsSuite extends FunSuite with BeforeAndAfterAll { } test("ivy path works correctly") { - val ivyPath = "dummy" + File.separator + "ivy" + val ivyPath = "dummy/ivy" val md = SparkSubmitUtils.getModuleDescriptor val artifacts = for (i <- 0 until 3) yield new MDArtifact(md, s"jar-$i", "jar", "jar") var jPaths = SparkSubmitUtils.resolveDependencyPaths(artifacts.toArray, new File(ivyPath)) @@ -98,38 +98,17 @@ class SparkSubmitUtilsSuite extends FunSuite with BeforeAndAfterAll { assert(index >= 0) jPaths = jPaths.substring(index + ivyPath.length) } - val main = MavenCoordinate("my.awesome.lib", "mylib", "0.1") - IvyTestUtils.withRepository(main, None, None) { repo => - // end to end - val jarPath = SparkSubmitUtils.resolveMavenCoordinates(main.toString, Option(repo), - Option(ivyPath), true) - assert(jarPath.indexOf(ivyPath) >= 0, "should use non-default ivy path") - } + // end to end + val jarPath = SparkSubmitUtils.resolveMavenCoordinates( + "com.databricks:spark-csv_2.10:0.1", None, Option(ivyPath), true) + assert(jarPath.indexOf(ivyPath) >= 0, "should use non-default ivy path") } - test("search for artifact at local repositories") { - val main = new MavenCoordinate("my.awesome.lib", "mylib", "0.1") - // Local M2 repository - IvyTestUtils.withRepository(main, None, Some(SparkSubmitUtils.m2Path)) { repo => - val jarPath = SparkSubmitUtils.resolveMavenCoordinates(main.toString, None, None, true) - assert(jarPath.indexOf("mylib") >= 0, "should find artifact") - } - // Local Ivy Repository - val settings = new IvySettings - val ivyLocal = new File(settings.getDefaultIvyUserDir, "local" + File.separator) - IvyTestUtils.withRepository(main, None, Some(ivyLocal.toPath), true) { repo => - val jarPath = SparkSubmitUtils.resolveMavenCoordinates(main.toString, None, None, true) - assert(jarPath.indexOf("mylib") >= 0, "should find artifact") - } - // Local ivy repository with modified home - val dummyIvyPath = "dummy" + File.separator + "ivy" - val dummyIvyLocal = new File(dummyIvyPath, "local" + File.separator) - IvyTestUtils.withRepository(main, None, Some(dummyIvyLocal.toPath), true) { repo => - val jarPath = SparkSubmitUtils.resolveMavenCoordinates(main.toString, None, - Some(dummyIvyPath), true) - assert(jarPath.indexOf("mylib") >= 0, "should find artifact") - assert(jarPath.indexOf(dummyIvyPath) >= 0, "should be in new ivy path") - } + test("search for artifact at other repositories") { + val path = SparkSubmitUtils.resolveMavenCoordinates("com.agimatec:agimatec-validation:0.9.3", + Option("https://oss.sonatype.org/content/repositories/agimatec/"), None, true) + assert(path.indexOf("agimatec-validation") >= 0, "should find package. If it doesn't, check" + + "if package still exists. If it has been removed, replace the example in this test.") } test("dependency not found throws RuntimeException") { @@ -148,11 +127,11 @@ class SparkSubmitUtilsSuite extends FunSuite with BeforeAndAfterAll { val path = SparkSubmitUtils.resolveMavenCoordinates(coordinates, None, None, true) assert(path === "", "should return empty path") - val main = MavenCoordinate("org.apache.spark", "spark-streaming-kafka-assembly_2.10", "1.2.0") - IvyTestUtils.withRepository(main, None, None) { repo => - val files = SparkSubmitUtils.resolveMavenCoordinates(coordinates + "," + main.toString, - Some(repo), None, true) - assert(files.indexOf(main.artifactId) >= 0, "Did not return artifact") + // Should not exclude the following dependency. Will throw an error, because it doesn't exist, + // but the fact that it is checking means that it wasn't excluded. + intercept[RuntimeException] { + SparkSubmitUtils.resolveMavenCoordinates(coordinates + + ",org.apache.spark:spark-streaming-kafka-assembly_2.10:1.2.0", None, None, true) } } } From f53a48827ef024f91b292132075e5598c9cb94bb Mon Sep 17 00:00:00 2001 From: Nishkam Ravi Date: Fri, 1 May 2015 21:14:16 +0100 Subject: [PATCH 07/91] [SPARK-7213] [YARN] Check for read permissions before copying a Hadoop config file Author: Nishkam Ravi Author: nishkamravi2 Author: nravi Closes #5760 from nishkamravi2/master_nravi and squashes the following commits: eaa13b5 [nishkamravi2] Update Client.scala 981afd2 [Nishkam Ravi] Check for read permission before initiating copy 1b81383 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 0f1abd0 [nishkamravi2] Update Utils.scala 474e3bf [nishkamravi2] Update DiskBlockManager.scala 97c383e [nishkamravi2] Update Utils.scala 8691e0c [Nishkam Ravi] Add a try/catch block around Utils.removeShutdownHook 2be1e76 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 1c13b79 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi bad4349 [nishkamravi2] Update Main.java 36a6f87 [Nishkam Ravi] Minor changes and bug fixes b7f4ae7 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 4a45d6a [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 458af39 [Nishkam Ravi] Locate the jar using getLocation, obviates the need to pass assembly path as an argument d9658d6 [Nishkam Ravi] Changes for SPARK-6406 ccdc334 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 3faa7a4 [Nishkam Ravi] Launcher library changes (SPARK-6406) 345206a [Nishkam Ravi] spark-class merge Merge branch 'master_nravi' of https://github.com/nishkamravi2/spark into master_nravi ac58975 [Nishkam Ravi] spark-class changes 06bfeb0 [nishkamravi2] Update spark-class 35af990 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 32c3ab3 [nishkamravi2] Update AbstractCommandBuilder.java 4bd4489 [nishkamravi2] Update AbstractCommandBuilder.java 746f35b [Nishkam Ravi] "hadoop" string in the assembly name should not be mandatory (everywhere else in spark we mandate spark-assembly*hadoop*.jar) bfe96e0 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi ee902fa [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi d453197 [nishkamravi2] Update NewHadoopRDD.scala 6f41a1d [nishkamravi2] Update NewHadoopRDD.scala 0ce2c32 [nishkamravi2] Update HadoopRDD.scala f7e33c2 [Nishkam Ravi] Merge branch 'master_nravi' of https://github.com/nishkamravi2/spark into master_nravi ba1eb8b [Nishkam Ravi] Try-catch block around the two occurrences of removeShutDownHook. Deletion of semi-redundant occurrences of expensive operation inShutDown. 71d0e17 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 494d8c0 [nishkamravi2] Update DiskBlockManager.scala 3c5ddba [nishkamravi2] Update DiskBlockManager.scala f0d12de [Nishkam Ravi] Workaround for IllegalStateException caused by recent changes to BlockManager.stop 79ea8b4 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi b446edc [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 5c9a4cb [nishkamravi2] Update TaskSetManagerSuite.scala 535295a [nishkamravi2] Update TaskSetManager.scala 3e1b616 [Nishkam Ravi] Modify test for maxResultSize 9f6583e [Nishkam Ravi] Changes to maxResultSize code (improve error message and add condition to check if maxResultSize > 0) 5f8f9ed [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 636a9ff [nishkamravi2] Update YarnAllocator.scala 8f76c8b [Nishkam Ravi] Doc change for yarn memory overhead 35daa64 [Nishkam Ravi] Slight change in the doc for yarn memory overhead 5ac2ec1 [Nishkam Ravi] Remove out dac1047 [Nishkam Ravi] Additional documentation for yarn memory overhead issue 42c2c3d [Nishkam Ravi] Additional changes for yarn memory overhead issue 362da5e [Nishkam Ravi] Additional changes for yarn memory overhead c726bd9 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi f00fa31 [Nishkam Ravi] Improving logging for AM memoryOverhead 1cf2d1e [nishkamravi2] Update YarnAllocator.scala ebcde10 [Nishkam Ravi] Modify default YARN memory_overhead-- from an additive constant to a multiplier (redone to resolve merge conflicts) 2e69f11 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi efd688a [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark 2b630f9 [nravi] Accept memory input as "30g", "512M" instead of an int value, to be consistent with rest of Spark 3bf8fad [nravi] Merge branch 'master' of https://github.com/apache/spark 5423a03 [nravi] Merge branch 'master' of https://github.com/apache/spark eb663ca [nravi] Merge branch 'master' of https://github.com/apache/spark df2aeb1 [nravi] Improved fix for ConcurrentModificationIssue (Spark-1097, Hadoop-10456) 6b840f0 [nravi] Undo the fix for SPARK-1758 (the problem is fixed) 5108700 [nravi] Fix in Spark for the Concurrent thread modification issue (SPARK-1097, HADOOP-10456) 681b36f [nravi] Fix for SPARK-1758: failing test org.apache.spark.JavaAPISuite.wholeTextFiles --- .../main/scala/org/apache/spark/deploy/yarn/Client.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 4abcf7307a388..b945395f24ea6 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -371,9 +371,11 @@ private[spark] class Client( try { hadoopConfStream.setLevel(0) hadoopConfFiles.foreach { case (name, file) => - hadoopConfStream.putNextEntry(new ZipEntry(name)) - Files.copy(file, hadoopConfStream) - hadoopConfStream.closeEntry() + if (file.canRead()) { + hadoopConfStream.putNextEntry(new ZipEntry(name)) + Files.copy(file, hadoopConfStream) + hadoopConfStream.closeEntry() + } } } finally { hadoopConfStream.close() From 7b5dd3e3c0030087eea5a8224789352c03717c1d Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 1 May 2015 21:20:46 +0100 Subject: [PATCH 08/91] [SPARK-7281] [YARN] Add option to set AM's lib path in client mode. Author: Marcelo Vanzin Closes #5813 from vanzin/SPARK-7281 and squashes the following commits: 1cb6f42 [Marcelo Vanzin] [SPARK-7281] [yarn] Add option to set AM's lib path in client mode. --- docs/running-on-yarn.md | 7 +++++++ .../main/scala/org/apache/spark/deploy/yarn/Client.scala | 4 ++++ 2 files changed, 11 insertions(+) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 0968fc5ad632b..b6701b64c2925 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -189,6 +189,13 @@ Most of the configs are the same for Spark on YARN as for other deployment modes In cluster mode, use spark.driver.extraJavaOptions instead. + + spark.yarn.am.extraLibraryPath + (none) + + Set a special library path to use when launching the application master in client mode. + + spark.yarn.maxAppAttempts yarn.resourcemanager.am.max-attempts in YARN diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index b945395f24ea6..31ab6b491ec2a 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -544,6 +544,10 @@ private[spark] class Client( } javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) } + + sparkConf.getOption("spark.yarn.am.extraLibraryPath").foreach { paths => + prefixEnv = Some(Utils.libraryPathEnvPrefix(Seq(paths))) + } } // For log4j configuration to reference From 4dc8d74491b101a794cf8d386d8c5ebc6019b75f Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Fri, 1 May 2015 13:29:17 -0700 Subject: [PATCH 09/91] [SPARK-7240][SQL] Single pass covariance calculation for dataframes Added the calculation of covariance between two columns to DataFrames. cc mengxr rxin Author: Burak Yavuz Closes #5825 from brkyvz/df-cov and squashes the following commits: cb18046 [Burak Yavuz] changed to sample covariance f2e862b [Burak Yavuz] fixed failed test 51e39b8 [Burak Yavuz] moved implementation 0c6a759 [Burak Yavuz] addressed math comments 8456eca [Burak Yavuz] fix pyStyle3 aa2ad29 [Burak Yavuz] fix pyStyle2 4e97a50 [Burak Yavuz] Merge branch 'master' of github.com:apache/spark into df-cov e3b0b85 [Burak Yavuz] addressed comments v0.1 a7115f1 [Burak Yavuz] fix python style 7dc6dbc [Burak Yavuz] reorder imports 408cb77 [Burak Yavuz] initial commit --- python/pyspark/sql/__init__.py | 4 +- python/pyspark/sql/dataframe.py | 36 ++++++++- python/pyspark/sql/tests.py | 5 ++ .../spark/sql/DataFrameStatFunctions.scala | 12 ++- .../sql/execution/stat/StatFunctions.scala | 80 +++++++++++++++++++ .../apache/spark/sql/JavaDataFrameSuite.java | 7 ++ .../apache/spark/sql/DataFrameStatSuite.scala | 18 ++++- 7 files changed, 157 insertions(+), 5 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py index 6d54b9e49ed10..b60b991dd4d8b 100644 --- a/python/pyspark/sql/__init__.py +++ b/python/pyspark/sql/__init__.py @@ -54,7 +54,9 @@ from pyspark.sql.types import Row from pyspark.sql.context import SQLContext, HiveContext from pyspark.sql.dataframe import DataFrame, GroupedData, Column, SchemaRDD, DataFrameNaFunctions +from pyspark.sql.dataframe import DataFrameStatFunctions __all__ = [ - 'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row', 'DataFrameNaFunctions' + 'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row', + 'DataFrameNaFunctions', 'DataFrameStatFunctions' ] diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 5908ebc990a56..1f08c2df9305b 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -34,7 +34,8 @@ from pyspark.sql.types import _create_cls, _parse_datatype_json_string -__all__ = ["DataFrame", "GroupedData", "Column", "SchemaRDD", "DataFrameNaFunctions"] +__all__ = ["DataFrame", "GroupedData", "Column", "SchemaRDD", "DataFrameNaFunctions", + "DataFrameStatFunctions"] class DataFrame(object): @@ -93,6 +94,12 @@ def na(self): """ return DataFrameNaFunctions(self) + @property + def stat(self): + """Returns a :class:`DataFrameStatFunctions` for statistic functions. + """ + return DataFrameStatFunctions(self) + @ignore_unicode_prefix def toJSON(self, use_unicode=True): """Converts a :class:`DataFrame` into a :class:`RDD` of string. @@ -868,6 +875,20 @@ def fillna(self, value, subset=None): return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx) + def cov(self, col1, col2): + """ + Calculate the sample covariance for the given columns, specified by their names, as a + double value. :func:`DataFrame.cov` and :func:`DataFrameStatFunctions.cov` are aliases. + + :param col1: The name of the first column + :param col2: The name of the second column + """ + if not isinstance(col1, str): + raise ValueError("col1 should be a string.") + if not isinstance(col2, str): + raise ValueError("col2 should be a string.") + return self._jdf.stat().cov(col1, col2) + @ignore_unicode_prefix def withColumn(self, colName, col): """Returns a new :class:`DataFrame` by adding a column. @@ -1311,6 +1332,19 @@ def fill(self, value, subset=None): fill.__doc__ = DataFrame.fillna.__doc__ +class DataFrameStatFunctions(object): + """Functionality for statistic functions with :class:`DataFrame`. + """ + + def __init__(self, df): + self.df = df + + def cov(self, col1, col2): + return self.df.cov(col1, col2) + + cov.__doc__ = DataFrame.cov.__doc__ + + def _test(): import doctest from pyspark.context import SparkContext diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 5640bb5ea2346..44c8b6a1aac13 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -387,6 +387,11 @@ def test_aggregator(self): self.assertTrue(95 < g.agg(functions.approxCountDistinct(df.key)).first()[0]) self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0]) + def test_cov(self): + df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF() + cov = df.stat.cov("a", "b") + self.assertTrue(abs(cov - 55.0 / 3) < 1e-6) + def test_math_functions(self): df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF() from pyspark.sql import mathfunctions as functions diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 42e5cbc05e1e0..23652aeb7c7bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.execution.stat.FrequentItems +import org.apache.spark.sql.execution.stat._ /** * :: Experimental :: @@ -65,4 +65,14 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { def freqItems(cols: List[String]): DataFrame = { FrequentItems.singlePassFreqItems(df, cols, 0.01) } + + /** + * Calculate the sample covariance of two numerical columns of a DataFrame. + * @param col1 the name of the first column + * @param col2 the name of the second column + * @return the covariance of the two columns. + */ + def cov(col1: String, col2: String): Double = { + StatFunctions.calculateCov(df, Seq(col1, col2)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala new file mode 100644 index 0000000000000..d4a94c24d9866 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.stat + +import org.apache.spark.sql.catalyst.expressions.Cast +import org.apache.spark.sql.{Column, DataFrame} +import org.apache.spark.sql.types.{DoubleType, NumericType} + +private[sql] object StatFunctions { + + /** Helper class to simplify tracking and merging counts. */ + private class CovarianceCounter extends Serializable { + var xAvg = 0.0 + var yAvg = 0.0 + var Ck = 0.0 + var count = 0L + // add an example to the calculation + def add(x: Double, y: Double): this.type = { + val oldX = xAvg + count += 1 + xAvg += (x - xAvg) / count + yAvg += (y - yAvg) / count + Ck += (y - yAvg) * (x - oldX) + this + } + // merge counters from other partitions. Formula can be found at: + // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Covariance + def merge(other: CovarianceCounter): this.type = { + val totalCount = count + other.count + Ck += other.Ck + + (xAvg - other.xAvg) * (yAvg - other.yAvg) * count / totalCount * other.count + xAvg = (xAvg * count + other.xAvg * other.count) / totalCount + yAvg = (yAvg * count + other.yAvg * other.count) / totalCount + count = totalCount + this + } + // return the sample covariance for the observed examples + def cov: Double = Ck / (count - 1) + } + + /** + * Calculate the covariance of two numerical columns of a DataFrame. + * @param df The DataFrame + * @param cols the column names + * @return the covariance of the two columns. + */ + private[sql] def calculateCov(df: DataFrame, cols: Seq[String]): Double = { + require(cols.length == 2, "Currently cov supports calculating the covariance " + + "between two columns.") + cols.map(name => (name, df.schema.fields.find(_.name == name))).foreach { case (name, data) => + require(data.nonEmpty, s"Couldn't find column with name $name") + require(data.get.dataType.isInstanceOf[NumericType], "Covariance calculation for columns " + + s"with dataType ${data.get.dataType} not supported.") + } + val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType))) + val counts = df.select(columns:_*).rdd.aggregate(new CovarianceCounter)( + seqOp = (counter, row) => { + counter.add(row.getDouble(0), row.getDouble(1)) + }, + combOp = (baseCounter, other) => { + baseCounter.merge(other) + }) + counts.cov + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index ebe96e649d940..96fe66d0b84a6 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -186,4 +186,11 @@ public void testFrequentItems() { DataFrame results = df.stat().freqItems(cols, 0.2); Assert.assertTrue(results.collect()[0].getSeq(0).contains(1)); } + + @Test + public void testCovariance() { + DataFrame df = context.table("testData2"); + Double result = df.stat().cov("a", "b"); + Assert.assertTrue(Math.abs(result) < 1e-6); + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index bb1d29c71d23b..4f5a2ff696789 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -25,10 +25,11 @@ import org.apache.spark.sql.test.TestSQLContext.implicits._ class DataFrameStatSuite extends FunSuite { + import TestData._ val sqlCtx = TestSQLContext - + def toLetter(i: Int): String = (i + 97).toChar.toString + test("Frequent Items") { - def toLetter(i: Int): String = (i + 96).toChar.toString val rows = Array.tabulate(1000) { i => if (i % 3 == 0) (1, toLetter(1), -1.0) else (i, toLetter(i), i * -1.0) } @@ -44,4 +45,17 @@ class DataFrameStatSuite extends FunSuite { items2.getSeq[Double](0) should contain (-1.0) } + + test("covariance") { + val rows = Array.tabulate(10)(i => (i, 2.0 * i, toLetter(i))) + val df = sqlCtx.sparkContext.parallelize(rows).toDF("singles", "doubles", "letters") + + val results = df.stat.cov("singles", "doubles") + assert(math.abs(results - 55.0 / 3) < 1e-6) + intercept[IllegalArgumentException] { + df.stat.cov("singles", "letters") // doesn't accept non-numerical dataTypes + } + val decimalRes = decimalData.stat.cov("a", "b") + assert(math.abs(decimalRes) < 1e-6) + } } From b1f4ca82d170935d15f1fe6beb9af0743b4d81cd Mon Sep 17 00:00:00 2001 From: Hari Shreedharan Date: Fri, 1 May 2015 15:32:09 -0500 Subject: [PATCH 10/91] [SPARK-5342] [YARN] Allow long running Spark apps to run on secure YARN/HDFS Take 2. Does the same thing as #4688, but fixes Hadoop-1 build. Author: Hari Shreedharan Closes #5823 from harishreedharan/kerberos-longrunning and squashes the following commits: 3c86bba [Hari Shreedharan] Import fixes. Import postfixOps explicitly. 4d04301 [Hari Shreedharan] Minor formatting fixes. b5e7a72 [Hari Shreedharan] Remove reflection, use a method in SparkHadoopUtil to update the token renewer. 7bff6e9 [Hari Shreedharan] Make sure all required classes are present in the jar. Fix import order. e851f70 [Hari Shreedharan] Move the ExecutorDelegationTokenRenewer to yarn module. Use reflection to use it. 36eb8a9 [Hari Shreedharan] Change the renewal interval config param. Fix a bunch of comments. 611923a [Hari Shreedharan] Make sure the namenodes are listed correctly for creating tokens. 09fe224 [Hari Shreedharan] Use token.renew to get token's renewal interval rather than using hdfs-site.xml 6963bbc [Hari Shreedharan] Schedule renewal in AM before starting user class. Else, a restarted AM cannot access HDFS if the user class tries to. 072659e [Hari Shreedharan] Fix build failure caused by thread factory getting moved to ThreadUtils. f041dd3 [Hari Shreedharan] Merge branch 'master' into kerberos-longrunning 42eead4 [Hari Shreedharan] Remove RPC part. Refactor and move methods around, use renewal interval rather than max lifetime to create new tokens. ebb36f5 [Hari Shreedharan] Merge branch 'master' into kerberos-longrunning bc083e3 [Hari Shreedharan] Overload RegisteredExecutor to send tokens. Minor doc updates. 7b19643 [Hari Shreedharan] Merge branch 'master' into kerberos-longrunning 8a4f268 [Hari Shreedharan] Added docs in the security guide. Changed some code to ensure that the renewer objects are created only if required. e800c8b [Hari Shreedharan] Restore original RegisteredExecutor message, and send new tokens via NewTokens message. 0e9507e [Hari Shreedharan] Merge branch 'master' into kerberos-longrunning 7f1bc58 [Hari Shreedharan] Minor fixes, cleanup. bcd11f9 [Hari Shreedharan] Refactor AM and Executor token update code into separate classes, also send tokens via akka on executor startup. f74303c [Hari Shreedharan] Move the new logic into specialized classes. Add cleanup for old credentials files. 2f9975c [Hari Shreedharan] Ensure new tokens are written out immediately on AM restart. Also, pikc up the latest suffix from HDFS if the AM is restarted. 61b2b27 [Hari Shreedharan] Account for AM restarts by making sure lastSuffix is read from the files on HDFS. 62c45ce [Hari Shreedharan] Relogin from keytab periodically. fa233bd [Hari Shreedharan] Adding logging, fixing minor formatting and ordering issues. 42813b4 [Hari Shreedharan] Remove utils.sh, which was re-added due to merge with master. 0de27ee [Hari Shreedharan] Merge branch 'master' into kerberos-longrunning 55522e3 [Hari Shreedharan] Fix failure caused by Preconditions ambiguity. 9ef5f1b [Hari Shreedharan] Added explanation of how the credentials refresh works, some other minor fixes. f4fd711 [Hari Shreedharan] Fix SparkConf usage. 2debcea [Hari Shreedharan] Change the file structure for credentials files. I will push a followup patch which adds a cleanup mechanism for old credentials files. The credentials files are small and few enough for it to cause issues on HDFS. af6d5f0 [Hari Shreedharan] Cleaning up files where changes weren't required. f0f54cb [Hari Shreedharan] Be more defensive when updating the credentials file. f6954da [Hari Shreedharan] Got rid of Akka communication to renew, instead the executors check a known file's modification time to read the credentials. 5c11c3e [Hari Shreedharan] Move tests to YarnSparkHadoopUtil to fix compile issues. b4cb917 [Hari Shreedharan] Send keytab to AM via DistributedCache rather than directly via HDFS 0985b4e [Hari Shreedharan] Write tokens to HDFS and read them back when required, rather than sending them over the wire. d79b2b9 [Hari Shreedharan] Make sure correct credentials are passed to FileSystem#addDelegationTokens() 8c6928a [Hari Shreedharan] Fix issue caused by direct creation of Actor object. fb27f46 [Hari Shreedharan] Make sure principal and keytab are set before CoarseGrainedSchedulerBackend is started. Also schedule re-logins in CoarseGrainedSchedulerBackend#start() 41efde0 [Hari Shreedharan] Merge branch 'master' into kerberos-longrunning d282d7a [Hari Shreedharan] Fix ClientSuite to set YARN mode, so that the correct class is used in tests. bcfc374 [Hari Shreedharan] Fix Hadoop-1 build by adding no-op methods in SparkHadoopUtil, with impl in YarnSparkHadoopUtil. f8fe694 [Hari Shreedharan] Handle None if keytab-login is not scheduled. 2b0d745 [Hari Shreedharan] [SPARK-5342][YARN] Allow long running Spark apps to run on secure YARN/HDFS. ccba5bc [Hari Shreedharan] WIP: More changes wrt kerberos 77914dd [Hari Shreedharan] WIP: Add kerberos principal and keytab to YARN client. --- .../apache/spark/deploy/SparkHadoopUtil.scala | 81 ++++++- .../org/apache/spark/deploy/SparkSubmit.scala | 4 + .../spark/deploy/SparkSubmitArguments.scala | 15 ++ .../CoarseGrainedExecutorBackend.scala | 9 + .../CoarseGrainedSchedulerBackend.scala | 4 +- docs/security.md | 2 + .../launcher/SparkSubmitOptionParser.java | 6 +- .../yarn/AMDelegationTokenRenewer.scala | 205 ++++++++++++++++++ .../spark/deploy/yarn/ApplicationMaster.scala | 17 +- .../org/apache/spark/deploy/yarn/Client.scala | 123 +++++++---- .../spark/deploy/yarn/ClientArguments.scala | 10 + .../yarn/ExecutorDelegationTokenUpdater.scala | 106 +++++++++ .../deploy/yarn/YarnSparkHadoopUtil.scala | 69 +++++- .../spark/deploy/yarn/ClientSuite.scala | 51 ----- .../spark/deploy/yarn/YarnClusterSuite.scala | 2 + .../yarn/YarnSparkHadoopUtilSuite.scala | 62 +++++- 16 files changed, 657 insertions(+), 109 deletions(-) create mode 100644 yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala create mode 100644 yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index cfaebf9ea5050..b563034457a91 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -17,12 +17,16 @@ package org.apache.spark.deploy +import java.io.{ByteArrayInputStream, DataInputStream} import java.lang.reflect.Method import java.security.PrivilegedExceptionAction +import java.util.{Arrays, Comparator} +import com.google.common.primitives.Longs import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter} import org.apache.hadoop.fs.FileSystem.Statistics +import org.apache.hadoop.hdfs.security.token.delegation.DelegationTokenIdentifier import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.mapreduce.JobContext import org.apache.hadoop.security.{Credentials, UserGroupInformation} @@ -32,6 +36,8 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.Utils import scala.collection.JavaConversions._ +import scala.concurrent.duration._ +import scala.language.postfixOps /** * :: DeveloperApi :: @@ -39,7 +45,8 @@ import scala.collection.JavaConversions._ */ @DeveloperApi class SparkHadoopUtil extends Logging { - val conf: Configuration = newConfiguration(new SparkConf()) + private val sparkConf = new SparkConf() + val conf: Configuration = newConfiguration(sparkConf) UserGroupInformation.setConfiguration(conf) /** @@ -201,6 +208,61 @@ class SparkHadoopUtil extends Logging { if (baseStatus.isDir) recurse(basePath) else Array(baseStatus) } + /** + * Lists all the files in a directory with the specified prefix, and does not end with the + * given suffix. The returned {{FileStatus}} instances are sorted by the modification times of + * the respective files. + */ + def listFilesSorted( + remoteFs: FileSystem, + dir: Path, + prefix: String, + exclusionSuffix: String): Array[FileStatus] = { + val fileStatuses = remoteFs.listStatus(dir, + new PathFilter { + override def accept(path: Path): Boolean = { + val name = path.getName + name.startsWith(prefix) && !name.endsWith(exclusionSuffix) + } + }) + Arrays.sort(fileStatuses, new Comparator[FileStatus] { + override def compare(o1: FileStatus, o2: FileStatus): Int = { + Longs.compare(o1.getModificationTime, o2.getModificationTime) + } + }) + fileStatuses + } + + /** + * How much time is remaining (in millis) from now to (fraction * renewal time for the token that + * is valid the latest)? + * This will return -ve (or 0) value if the fraction of validity has already expired. + */ + def getTimeFromNowToRenewal( + sparkConf: SparkConf, + fraction: Double, + credentials: Credentials): Long = { + val now = System.currentTimeMillis() + + val renewalInterval = + sparkConf.getLong("spark.yarn.token.renewal.interval", (24 hours).toMillis) + + credentials.getAllTokens.filter(_.getKind == DelegationTokenIdentifier.HDFS_DELEGATION_KIND) + .map { t => + val identifier = new DelegationTokenIdentifier() + identifier.readFields(new DataInputStream(new ByteArrayInputStream(t.getIdentifier))) + (identifier.getIssueDate + fraction * renewalInterval).toLong - now + }.foldLeft(0L)(math.max) + } + + + private[spark] def getSuffixForCredentialsPath(credentialsPath: Path): Int = { + val fileName = credentialsPath.getName + fileName.substring( + fileName.lastIndexOf(SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM) + 1).toInt + } + + private val HADOOP_CONF_PATTERN = "(\\$\\{hadoopconf-[^\\}\\$\\s]+\\})".r.unanchored /** @@ -231,6 +293,17 @@ class SparkHadoopUtil extends Logging { } } } + + /** + * Start a thread to periodically update the current user's credentials with new delegation + * tokens so that writes to HDFS do not fail. + */ + private[spark] def startExecutorDelegationTokenRenewer(conf: SparkConf) {} + + /** + * Stop the thread that does the delegation token updates. + */ + private[spark] def stopExecutorDelegationTokenRenewer() {} } object SparkHadoopUtil { @@ -251,6 +324,10 @@ object SparkHadoopUtil { } } + val SPARK_YARN_CREDS_TEMP_EXTENSION = ".tmp" + + val SPARK_YARN_CREDS_COUNTER_DELIM = "-" + def get: SparkHadoopUtil = { hadoop } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index b8ae4af18d1d1..af38bf80e4f0b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -400,6 +400,10 @@ object SparkSubmit { OptionAssigner(args.archives, YARN, CLUSTER, clOption = "--archives"), OptionAssigner(args.jars, YARN, CLUSTER, clOption = "--addJars"), + // Yarn client or cluster + OptionAssigner(args.principal, YARN, ALL_DEPLOY_MODES, clOption = "--principal"), + OptionAssigner(args.keytab, YARN, ALL_DEPLOY_MODES, clOption = "--keytab"), + // Other options OptionAssigner(args.executorCores, STANDALONE, ALL_DEPLOY_MODES, sysProp = "spark.executor.cores"), diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index c621b8fc86f94..c0e4c771908b3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -63,6 +63,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S var action: SparkSubmitAction = null val sparkProperties: HashMap[String, String] = new HashMap[String, String]() var proxyUser: String = null + var principal: String = null + var keytab: String = null // Standalone cluster mode only var supervise: Boolean = false @@ -393,6 +395,12 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S case PROXY_USER => proxyUser = value + case PRINCIPAL => + principal = value + + case KEYTAB => + keytab = value + case HELP => printUsageAndExit(0) @@ -506,6 +514,13 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | --num-executors NUM Number of executors to launch (Default: 2). | --archives ARCHIVES Comma separated list of archives to be extracted into the | working directory of each executor. + | --principal PRINCIPAL Principal to be used to login to KDC, while running on + | secure HDFS. + | --keytab KEYTAB The full path to the file that contains the keytab for the + | principal specified above. This keytab will be copied to + | the node running the Application Master via the Secure + | Distributed Cache, for renewing the login tickets and the + | delegation tokens periodically. """.stripMargin ) SparkSubmit.exitFn() diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 79aed90b53e2f..ed159dec4f998 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -20,6 +20,8 @@ package org.apache.spark.executor import java.net.URL import java.nio.ByteBuffer +import org.apache.hadoop.conf.Configuration + import scala.collection.mutable import scala.util.{Failure, Success} @@ -168,6 +170,12 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { driverConf.set(key, value) } } + if (driverConf.contains("spark.yarn.credentials.file")) { + logInfo("Will periodically update credentials from: " + + driverConf.get("spark.yarn.credentials.file")) + SparkHadoopUtil.get.startExecutorDelegationTokenRenewer(driverConf) + } + val env = SparkEnv.createExecutorEnv( driverConf, executorId, hostname, port, cores, isLocal = false) @@ -183,6 +191,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { env.rpcEnv.setupEndpoint("WorkerWatcher", new WorkerWatcher(env.rpcEnv, url)) } env.rpcEnv.awaitTermination() + SparkHadoopUtil.get.stopExecutorDelegationTokenRenewer() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 7352fa1fe9ebd..f107148f3b8c6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -68,6 +68,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp class DriverEndpoint(override val rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) extends ThreadSafeRpcEndpoint with Logging { + override protected def log = CoarseGrainedSchedulerBackend.this.log private val addressToExecutorId = new HashMap[RpcAddress, String] @@ -112,6 +113,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // Ignoring the task kill since the executor is not registered. logWarning(s"Attempted to kill task $taskId for unknown executor $executorId.") } + } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { @@ -122,7 +124,6 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } else { logInfo("Registered executor: " + executorRef + " with ID " + executorId) context.reply(RegisteredExecutor) - addressToExecutorId(executorRef.address) = executorId totalCoreCount.addAndGet(cores) totalRegisteredExecutors.addAndGet(1) @@ -243,6 +244,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp properties += ((key, value)) } } + // TODO (prashant) send conf instead of properties driverEndpoint = rpcEnv.setupEndpoint( CoarseGrainedSchedulerBackend.ENDPOINT_NAME, new DriverEndpoint(rpcEnv, properties)) diff --git a/docs/security.md b/docs/security.md index c034ba12ff1fc..d4ffa60e59a33 100644 --- a/docs/security.md +++ b/docs/security.md @@ -32,6 +32,8 @@ SSL must be configured on each node and configured for each component involved i ### YARN mode The key-store can be prepared on the client side and then distributed and used by the executors as the part of the application. It is possible because the user is able to deploy files before the application is started in YARN by using `spark.yarn.dist.files` or `spark.yarn.dist.archives` configuration settings. The responsibility for encryption of transferring these files is on YARN side and has nothing to do with Spark. +For long-running apps like Spark Streaming apps to be able to write to HDFS, it is possible to pass a principal and keytab to `spark-submit` via the `--principal` and `--keytab` parameters respectively. The keytab passed in will be copied over to the machine running the Application Master via the Hadoop Distributed Cache (securely - if YARN is configured with SSL and HDFS encryption is enabled). The Kerberos login will be periodically renewed using this principal and keytab and the delegation tokens required for HDFS will be generated periodically so the application can continue writing to HDFS. + ### Standalone mode The user needs to provide key-stores and configuration options for master and workers. They have to be set by attaching appropriate Java system properties in `SPARK_MASTER_OPTS` and in `SPARK_WORKER_OPTS` environment variables, or just in `SPARK_DAEMON_JAVA_OPTS`. In this mode, the user may allow the executors to use the SSL settings inherited from the worker which spawned that executor. It can be accomplished by setting `spark.ssl.useNodeLocalConf` to `true`. If that parameter is set, the settings provided by user on the client side, are not used by the executors. diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java index 8526d2e7cfa3f..229000087688f 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java @@ -69,8 +69,10 @@ class SparkSubmitOptionParser { // YARN-only options. protected final String ARCHIVES = "--archives"; protected final String EXECUTOR_CORES = "--executor-cores"; - protected final String QUEUE = "--queue"; + protected final String KEYTAB = "--keytab"; protected final String NUM_EXECUTORS = "--num-executors"; + protected final String PRINCIPAL = "--principal"; + protected final String QUEUE = "--queue"; /** * This is the canonical list of spark-submit options. Each entry in the array contains the @@ -96,11 +98,13 @@ class SparkSubmitOptionParser { { EXECUTOR_MEMORY }, { FILES }, { JARS }, + { KEYTAB }, { KILL_SUBMISSION }, { MASTER }, { NAME }, { NUM_EXECUTORS }, { PACKAGES }, + { PRINCIPAL }, { PROPERTIES_FILE }, { PROXY_USER }, { PY_FILES }, diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala new file mode 100644 index 0000000000000..aaae6f9734a85 --- /dev/null +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.yarn + +import java.security.PrivilegedExceptionAction +import java.util.concurrent.{Executors, TimeUnit} + +import scala.language.postfixOps + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.security.UserGroupInformation +import org.apache.spark.deploy.SparkHadoopUtil + +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.util.ThreadUtils + +/* + * The following methods are primarily meant to make sure long-running apps like Spark + * Streaming apps can run without interruption while writing to secure HDFS. The + * scheduleLoginFromKeytab method is called on the driver when the + * CoarseGrainedScheduledBackend starts up. This method wakes up a thread that logs into the KDC + * once 75% of the renewal interval of the original delegation tokens used for the container + * has elapsed. It then creates new delegation tokens and writes them to HDFS in a + * pre-specified location - the prefix of which is specified in the sparkConf by + * spark.yarn.credentials.file (so the file(s) would be named c-1, c-2 etc. - each update goes + * to a new file, with a monotonically increasing suffix). After this, the credentials are + * updated once 75% of the new tokens renewal interval has elapsed. + * + * On the executor side, the updateCredentialsIfRequired method is called once 80% of the + * validity of the original tokens has elapsed. At that time the executor finds the + * credentials file with the latest timestamp and checks if it has read those credentials + * before (by keeping track of the suffix of the last file it read). If a new file has + * appeared, it will read the credentials and update the currently running UGI with it. This + * process happens again once 80% of the validity of this has expired. + */ +private[yarn] class AMDelegationTokenRenewer( + sparkConf: SparkConf, + hadoopConf: Configuration) extends Logging { + + private var lastCredentialsFileSuffix = 0 + + private val delegationTokenRenewer = + Executors.newSingleThreadScheduledExecutor( + ThreadUtils.namedThreadFactory("Delegation Token Refresh Thread")) + + private val hadoopUtil = YarnSparkHadoopUtil.get + + private val daysToKeepFiles = sparkConf.getInt("spark.yarn.credentials.file.retention.days", 5) + private val numFilesToKeep = sparkConf.getInt("spark.yarn.credentials.file.retention.count", 5) + + /** + * Schedule a login from the keytab and principal set using the --principal and --keytab + * arguments to spark-submit. This login happens only when the credentials of the current user + * are about to expire. This method reads spark.yarn.principal and spark.yarn.keytab from + * SparkConf to do the login. This method is a no-op in non-YARN mode. + * + */ + private[spark] def scheduleLoginFromKeytab(): Unit = { + val principal = sparkConf.get("spark.yarn.principal") + val keytab = sparkConf.get("spark.yarn.keytab") + + /** + * Schedule re-login and creation of new tokens. If tokens have already expired, this method + * will synchronously create new ones. + */ + def scheduleRenewal(runnable: Runnable): Unit = { + val credentials = UserGroupInformation.getCurrentUser.getCredentials + val renewalInterval = hadoopUtil.getTimeFromNowToRenewal(sparkConf, 0.75, credentials) + // Run now! + if (renewalInterval <= 0) { + logInfo("HDFS tokens have expired, creating new tokens now.") + runnable.run() + } else { + logInfo(s"Scheduling login from keytab in $renewalInterval millis.") + delegationTokenRenewer.schedule(runnable, renewalInterval, TimeUnit.MILLISECONDS) + } + } + + // This thread periodically runs on the driver to update the delegation tokens on HDFS. + val driverTokenRenewerRunnable = + new Runnable { + override def run(): Unit = { + try { + writeNewTokensToHDFS(principal, keytab) + cleanupOldFiles() + } catch { + case e: Exception => + // Log the error and try to write new tokens back in an hour + logWarning("Failed to write out new credentials to HDFS, will try again in an " + + "hour! If this happens too often tasks will fail.", e) + delegationTokenRenewer.schedule(this, 1, TimeUnit.HOURS) + return + } + scheduleRenewal(this) + } + } + // Schedule update of credentials. This handles the case of updating the tokens right now + // as well, since the renenwal interval will be 0, and the thread will get scheduled + // immediately. + scheduleRenewal(driverTokenRenewerRunnable) + } + + // Keeps only files that are newer than daysToKeepFiles days, and deletes everything else. At + // least numFilesToKeep files are kept for safety + private def cleanupOldFiles(): Unit = { + import scala.concurrent.duration._ + try { + val remoteFs = FileSystem.get(hadoopConf) + val credentialsPath = new Path(sparkConf.get("spark.yarn.credentials.file")) + val thresholdTime = System.currentTimeMillis() - (daysToKeepFiles days).toMillis + hadoopUtil.listFilesSorted( + remoteFs, credentialsPath.getParent, + credentialsPath.getName, SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION) + .dropRight(numFilesToKeep) + .takeWhile(_.getModificationTime < thresholdTime) + .foreach(x => remoteFs.delete(x.getPath, true)) + } catch { + // Such errors are not fatal, so don't throw. Make sure they are logged though + case e: Exception => + logWarning("Error while attempting to cleanup old tokens. If you are seeing many such " + + "warnings there may be an issue with your HDFS cluster.", e) + } + } + + private def writeNewTokensToHDFS(principal: String, keytab: String): Unit = { + // Keytab is copied by YARN to the working directory of the AM, so full path is + // not needed. + + // HACK: + // HDFS will not issue new delegation tokens, if the Credentials object + // passed in already has tokens for that FS even if the tokens are expired (it really only + // checks if there are tokens for the service, and not if they are valid). So the only real + // way to get new tokens is to make sure a different Credentials object is used each time to + // get new tokens and then the new tokens are copied over the the current user's Credentials. + // So: + // - we login as a different user and get the UGI + // - use that UGI to get the tokens (see doAs block below) + // - copy the tokens over to the current user's credentials (this will overwrite the tokens + // in the current user's Credentials object for this FS). + // The login to KDC happens each time new tokens are required, but this is rare enough to not + // have to worry about (like once every day or so). This makes this code clearer than having + // to login and then relogin every time (the HDFS API may not relogin since we don't use this + // UGI directly for HDFS communication. + logInfo(s"Attempting to login to KDC using principal: $principal") + val keytabLoggedInUGI = UserGroupInformation.loginUserFromKeytabAndReturnUGI(principal, keytab) + logInfo("Successfully logged into KDC.") + val tempCreds = keytabLoggedInUGI.getCredentials + val credentialsPath = new Path(sparkConf.get("spark.yarn.credentials.file")) + val dst = credentialsPath.getParent + keytabLoggedInUGI.doAs(new PrivilegedExceptionAction[Void] { + // Get a copy of the credentials + override def run(): Void = { + val nns = YarnSparkHadoopUtil.get.getNameNodesToAccess(sparkConf) + dst + hadoopUtil.obtainTokensForNamenodes(nns, hadoopConf, tempCreds) + null + } + }) + // Add the temp credentials back to the original ones. + UserGroupInformation.getCurrentUser.addCredentials(tempCreds) + val remoteFs = FileSystem.get(hadoopConf) + // If lastCredentialsFileSuffix is 0, then the AM is either started or restarted. If the AM + // was restarted, then the lastCredentialsFileSuffix might be > 0, so find the newest file + // and update the lastCredentialsFileSuffix. + if (lastCredentialsFileSuffix == 0) { + hadoopUtil.listFilesSorted( + remoteFs, credentialsPath.getParent, + credentialsPath.getName, SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION) + .lastOption.foreach { status => + lastCredentialsFileSuffix = hadoopUtil.getSuffixForCredentialsPath(status.getPath) + } + } + val nextSuffix = lastCredentialsFileSuffix + 1 + val tokenPathStr = + sparkConf.get("spark.yarn.credentials.file") + + SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM + nextSuffix + val tokenPath = new Path(tokenPathStr) + val tempTokenPath = new Path(tokenPathStr + SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION) + logInfo("Writing out delegation tokens to " + tempTokenPath.toString) + val credentials = UserGroupInformation.getCurrentUser.getCredentials + credentials.writeTokenStorageFile(tempTokenPath, hadoopConf) + logInfo(s"Delegation Tokens written out successfully. Renaming file to $tokenPathStr") + remoteFs.rename(tempTokenPath, tokenPath) + logInfo("Delegation token file rename complete.") + lastCredentialsFileSuffix = nextSuffix + } + + def stop(): Unit = { + delegationTokenRenewer.shutdown() + } +} diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 27f804782f355..e1694c1f64a9f 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -75,6 +75,8 @@ private[spark] class ApplicationMaster( // Fields used in cluster mode. private val sparkContextRef = new AtomicReference[SparkContext](null) + private var delegationTokenRenewerOption: Option[AMDelegationTokenRenewer] = None + final def run(): Int = { try { val appAttemptId = client.getAttemptId() @@ -129,6 +131,15 @@ private[spark] class ApplicationMaster( // doAs in order for the credentials to be passed on to the executor containers. val securityMgr = new SecurityManager(sparkConf) + // If the credentials file config is present, we must periodically renew tokens. So create + // a new AMDelegationTokenRenewer + if (sparkConf.contains("spark.yarn.credentials.file")) { + delegationTokenRenewerOption = Some(new AMDelegationTokenRenewer(sparkConf, yarnConf)) + // If a principal and keytab have been set, use that to create new credentials for executors + // periodically + delegationTokenRenewerOption.foreach(_.scheduleLoginFromKeytab()) + } + if (isClusterMode) { runDriver(securityMgr) } else { @@ -193,6 +204,7 @@ private[spark] class ApplicationMaster( logDebug("shutting down user thread") userClassThread.interrupt() } + if (!inShutdown) delegationTokenRenewerOption.foreach(_.stop()) } } } @@ -240,12 +252,12 @@ private[spark] class ApplicationMaster( host: String, port: String, isClusterMode: Boolean): Unit = { - val driverEndpont = rpcEnv.setupEndpointRef( + val driverEndpoint = rpcEnv.setupEndpointRef( SparkEnv.driverActorSystemName, RpcAddress(host, port.toInt), YarnSchedulerBackend.ENDPOINT_NAME) amEndpoint = - rpcEnv.setupEndpoint("YarnAM", new AMEndpoint(rpcEnv, driverEndpont, isClusterMode)) + rpcEnv.setupEndpoint("YarnAM", new AMEndpoint(rpcEnv, driverEndpoint, isClusterMode)) } private def runDriver(securityMgr: SecurityManager): Unit = { @@ -499,6 +511,7 @@ private[spark] class ApplicationMaster( override def onStart(): Unit = { driver.send(RegisterClusterManager(self)) + } override def receive: PartialFunction[Any, Unit] = { diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 31ab6b491ec2a..20ecaf092e3f8 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -17,9 +17,11 @@ package org.apache.spark.deploy.yarn -import java.io.{File, FileOutputStream} +import java.io.{ByteArrayInputStream, DataInputStream, File, FileOutputStream} import java.net.{InetAddress, UnknownHostException, URI, URISyntaxException} import java.nio.ByteBuffer +import java.security.PrivilegedExceptionAction +import java.util.UUID import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.JavaConversions._ @@ -36,7 +38,6 @@ import org.apache.hadoop.hdfs.security.token.delegation.DelegationTokenIdentifie import org.apache.hadoop.fs._ import org.apache.hadoop.fs.permission.FsPermission import org.apache.hadoop.io.Text -import org.apache.hadoop.mapred.Master import org.apache.hadoop.mapreduce.MRJobConfig import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.hadoop.security.token.{TokenIdentifier, Token} @@ -50,8 +51,8 @@ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.exceptions.ApplicationNotFoundException import org.apache.hadoop.yarn.util.Records -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkException} import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkException} import org.apache.spark.util.Utils private[spark] class Client( @@ -69,11 +70,13 @@ private[spark] class Client( private val yarnClient = YarnClient.createYarnClient private val yarnConf = new YarnConfiguration(hadoopConf) - private val credentials = UserGroupInformation.getCurrentUser.getCredentials + private var credentials: Credentials = null private val amMemoryOverhead = args.amMemoryOverhead // MB private val executorMemoryOverhead = args.executorMemoryOverhead // MB private val distCacheMgr = new ClientDistributedCacheManager() private val isClusterMode = args.isClusterMode + + private var loginFromKeytab = false private val fireAndForget = isClusterMode && !sparkConf.getBoolean("spark.yarn.submit.waitAppCompletion", true) @@ -88,6 +91,8 @@ private[spark] class Client( * available in the alpha API. */ def submitApplication(): ApplicationId = { + // Setup the credentials before doing anything else, so we have don't have issues at any point. + setupCredentials() yarnClient.init(yarnConf) yarnClient.start() @@ -219,12 +224,12 @@ private[spark] class Client( // and add them as local resources to the application master. val fs = FileSystem.get(hadoopConf) val dst = new Path(fs.getHomeDirectory(), appStagingDir) - val nns = getNameNodesToAccess(sparkConf) + dst + val nns = YarnSparkHadoopUtil.get.getNameNodesToAccess(sparkConf) + dst + YarnSparkHadoopUtil.get.obtainTokensForNamenodes(nns, hadoopConf, credentials) // Used to keep track of URIs added to the distributed cache. If the same URI is added // multiple times, YARN will fail to launch containers for the app with an internal // error. val distributedUris = new HashSet[String] - obtainTokensForNamenodes(nns, hadoopConf, credentials) obtainTokenForHiveMetastore(hadoopConf, credentials) obtainTokenForHBase(hadoopConf, credentials) @@ -243,6 +248,20 @@ private[spark] class Client( "for alternatives.") } + // If we passed in a keytab, make sure we copy the keytab to the staging directory on + // HDFS, and setup the relevant environment vars, so the AM can login again. + if (loginFromKeytab) { + logInfo("To enable the AM to login from keytab, credentials are being copied over to the AM" + + " via the YARN Secure Distributed Cache.") + val localUri = new URI(args.keytab) + val localPath = getQualifiedLocalPath(localUri, hadoopConf) + val destinationPath = copyFileToRemote(dst, localPath, replication) + val destFs = FileSystem.get(destinationPath.toUri(), hadoopConf) + distCacheMgr.addResource( + destFs, hadoopConf, destinationPath, localResources, LocalResourceType.FILE, + sparkConf.get("spark.yarn.keytab"), statCache, appMasterOnly = true) + } + def addDistributedUri(uri: URI): Boolean = { val uriStr = uri.toString() if (distributedUris.contains(uriStr)) { @@ -387,6 +406,28 @@ private[spark] class Client( } } + /** + * Get the renewal interval for tokens. + */ + private def getTokenRenewalInterval(stagingDirPath: Path): Long = { + // We cannot use the tokens generated above since those have renewer yarn. Trying to renew + // those will fail with an access control issue. So create new tokens with the logged in + // user as renewer. + val creds = new Credentials() + val nns = YarnSparkHadoopUtil.get.getNameNodesToAccess(sparkConf) + stagingDirPath + YarnSparkHadoopUtil.get.obtainTokensForNamenodes( + nns, hadoopConf, creds, Some(sparkConf.get("spark.yarn.principal"))) + val t = creds.getAllTokens + .filter(_.getKind == DelegationTokenIdentifier.HDFS_DELEGATION_KIND) + .head + val newExpiration = t.renew(hadoopConf) + val identifier = new DelegationTokenIdentifier() + identifier.readFields(new DataInputStream(new ByteArrayInputStream(t.getIdentifier))) + val interval = newExpiration - identifier.getIssueDate + logInfo(s"Renewal Interval set to $interval") + interval + } + /** * Set up the environment for launching our ApplicationMaster container. */ @@ -398,7 +439,16 @@ private[spark] class Client( env("SPARK_YARN_MODE") = "true" env("SPARK_YARN_STAGING_DIR") = stagingDir env("SPARK_USER") = UserGroupInformation.getCurrentUser().getShortUserName() - + if (loginFromKeytab) { + val remoteFs = FileSystem.get(hadoopConf) + val stagingDirPath = new Path(remoteFs.getHomeDirectory, stagingDir) + val credentialsFile = "credentials-" + UUID.randomUUID().toString + sparkConf.set( + "spark.yarn.credentials.file", new Path(stagingDirPath, credentialsFile).toString) + logInfo(s"Credentials file set to: $credentialsFile") + val renewalInterval = getTokenRenewalInterval(stagingDirPath) + sparkConf.set("spark.yarn.token.renewal.interval", renewalInterval.toString) + } // Set the environment variables to be passed on to the executors. distCacheMgr.setDistFilesEnv(env) distCacheMgr.setDistArchivesEnv(env) @@ -463,7 +513,6 @@ private[spark] class Client( private def createContainerLaunchContext(newAppResponse: GetNewApplicationResponse) : ContainerLaunchContext = { logInfo("Setting up container launch context for our AM") - val appId = newAppResponse.getApplicationId val appStagingDir = getAppStagingDir(appId) val localResources = prepareLocalResources(appStagingDir) @@ -638,6 +687,24 @@ private[spark] class Client( amContainer } + def setupCredentials(): Unit = { + if (args.principal != null) { + require(args.keytab != null, "Keytab must be specified when principal is specified.") + logInfo("Attempting to login to the Kerberos" + + s" using principal: ${args.principal} and keytab: ${args.keytab}") + val f = new File(args.keytab) + // Generate a file name that can be used for the keytab file, that does not conflict + // with any user file. + val keytabFileName = f.getName + "-" + UUID.randomUUID().toString + UserGroupInformation.loginUserFromKeytab(args.principal, args.keytab) + loginFromKeytab = true + sparkConf.set("spark.yarn.keytab", keytabFileName) + sparkConf.set("spark.yarn.principal", args.principal) + logInfo("Successfully logged into the KDC.") + } + credentials = UserGroupInformation.getCurrentUser.getCredentials + } + /** * Report the state of an application until it has exited, either successfully or * due to some failure, then return a pair of the yarn application state (FINISHED, FAILED, @@ -993,46 +1060,6 @@ object Client extends Logging { private def addClasspathEntry(path: String, env: HashMap[String, String]): Unit = YarnSparkHadoopUtil.addPathToEnvironment(env, Environment.CLASSPATH.name, path) - /** - * Get the list of namenodes the user may access. - */ - private[yarn] def getNameNodesToAccess(sparkConf: SparkConf): Set[Path] = { - sparkConf.get("spark.yarn.access.namenodes", "") - .split(",") - .map(_.trim()) - .filter(!_.isEmpty) - .map(new Path(_)) - .toSet - } - - private[yarn] def getTokenRenewer(conf: Configuration): String = { - val delegTokenRenewer = Master.getMasterPrincipal(conf) - logDebug("delegation token renewer is: " + delegTokenRenewer) - if (delegTokenRenewer == null || delegTokenRenewer.length() == 0) { - val errorMessage = "Can't get Master Kerberos principal for use as renewer" - logError(errorMessage) - throw new SparkException(errorMessage) - } - delegTokenRenewer - } - - /** - * Obtains tokens for the namenodes passed in and adds them to the credentials. - */ - private def obtainTokensForNamenodes( - paths: Set[Path], - conf: Configuration, - creds: Credentials): Unit = { - if (UserGroupInformation.isSecurityEnabled()) { - val delegTokenRenewer = getTokenRenewer(conf) - paths.foreach { dst => - val dstFs = dst.getFileSystem(conf) - logDebug("getting token for namenode: " + dst) - dstFs.addDelegationTokens(delegTokenRenewer, creds) - } - } - } - /** * Obtains token for the Hive metastore and adds them to the credentials. */ diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 1423533470fc0..5653c9f14dc6d 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -42,6 +42,8 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) var amCores: Int = 1 var appName: String = "Spark" var priority = 0 + var principal: String = null + var keytab: String = null def isClusterMode: Boolean = userClass != null private var driverMemory: Int = 512 // MB @@ -231,6 +233,14 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) archives = value args = tail + case ("--principal") :: value :: tail => + principal = value + args = tail + + case ("--keytab") :: value :: tail => + keytab = value + args = tail + case Nil => case _ => diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala new file mode 100644 index 0000000000000..229c2c4d5eb36 --- /dev/null +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.yarn + +import java.util.concurrent.{Executors, TimeUnit} + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.security.{Credentials, UserGroupInformation} + +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.util.{ThreadUtils, Utils} + +import scala.util.control.NonFatal + +private[spark] class ExecutorDelegationTokenUpdater( + sparkConf: SparkConf, + hadoopConf: Configuration) extends Logging { + + @volatile private var lastCredentialsFileSuffix = 0 + + private val credentialsFile = sparkConf.get("spark.yarn.credentials.file") + + private val delegationTokenRenewer = + Executors.newSingleThreadScheduledExecutor( + ThreadUtils.namedThreadFactory("Delegation Token Refresh Thread")) + + // On the executor, this thread wakes up and picks up new tokens from HDFS, if any. + private val executorUpdaterRunnable = + new Runnable { + override def run(): Unit = Utils.logUncaughtExceptions(updateCredentialsIfRequired()) + } + + def updateCredentialsIfRequired(): Unit = { + try { + val credentialsFilePath = new Path(credentialsFile) + val remoteFs = FileSystem.get(hadoopConf) + SparkHadoopUtil.get.listFilesSorted( + remoteFs, credentialsFilePath.getParent, + credentialsFilePath.getName, SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION) + .lastOption.foreach { credentialsStatus => + val suffix = SparkHadoopUtil.get.getSuffixForCredentialsPath(credentialsStatus.getPath) + if (suffix > lastCredentialsFileSuffix) { + logInfo("Reading new delegation tokens from " + credentialsStatus.getPath) + val newCredentials = getCredentialsFromHDFSFile(remoteFs, credentialsStatus.getPath) + lastCredentialsFileSuffix = suffix + UserGroupInformation.getCurrentUser.addCredentials(newCredentials) + logInfo("Tokens updated from credentials file.") + } else { + // Check every hour to see if new credentials arrived. + logInfo("Updated delegation tokens were expected, but the driver has not updated the " + + "tokens yet, will check again in an hour.") + delegationTokenRenewer.schedule(executorUpdaterRunnable, 1, TimeUnit.HOURS) + return + } + } + val timeFromNowToRenewal = + SparkHadoopUtil.get.getTimeFromNowToRenewal( + sparkConf, 0.8, UserGroupInformation.getCurrentUser.getCredentials) + if (timeFromNowToRenewal <= 0) { + executorUpdaterRunnable.run() + } else { + logInfo(s"Scheduling token refresh from HDFS in $timeFromNowToRenewal millis.") + delegationTokenRenewer.schedule( + executorUpdaterRunnable, timeFromNowToRenewal, TimeUnit.MILLISECONDS) + } + } catch { + // Since the file may get deleted while we are reading it, catch the Exception and come + // back in an hour to try again + case NonFatal(e) => + logWarning("Error while trying to update credentials, will try again in 1 hour", e) + delegationTokenRenewer.schedule(executorUpdaterRunnable, 1, TimeUnit.HOURS) + } + } + + private def getCredentialsFromHDFSFile(remoteFs: FileSystem, tokenPath: Path): Credentials = { + val stream = remoteFs.open(tokenPath) + try { + val newCredentials = new Credentials() + newCredentials.readTokenStorageStream(stream) + newCredentials + } finally { + stream.close() + } + } + + def stop(): Unit = { + delegationTokenRenewer.shutdown() + } + +} diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 5881dc5ffa3ad..ba91872107d0c 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -24,18 +24,19 @@ import java.util.regex.Pattern import scala.collection.mutable.HashMap import scala.util.Try +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path import org.apache.hadoop.io.Text -import org.apache.hadoop.mapred.JobConf +import org.apache.hadoop.mapred.{Master, JobConf} import org.apache.hadoop.security.Credentials import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.yarn.api.records.{Priority, ApplicationAccessType} -import org.apache.hadoop.conf.Configuration -import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.util.Utils /** @@ -43,6 +44,8 @@ import org.apache.spark.util.Utils */ class YarnSparkHadoopUtil extends SparkHadoopUtil { + private var tokenRenewer: Option[ExecutorDelegationTokenUpdater] = None + override def transferCredentials(source: UserGroupInformation, dest: UserGroupInformation) { dest.addCredentials(source.getCredentials()) } @@ -82,6 +85,57 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { if (credentials != null) credentials.getSecretKey(new Text(key)) else null } + /** + * Get the list of namenodes the user may access. + */ + def getNameNodesToAccess(sparkConf: SparkConf): Set[Path] = { + sparkConf.get("spark.yarn.access.namenodes", "") + .split(",") + .map(_.trim()) + .filter(!_.isEmpty) + .map(new Path(_)) + .toSet + } + + def getTokenRenewer(conf: Configuration): String = { + val delegTokenRenewer = Master.getMasterPrincipal(conf) + logDebug("delegation token renewer is: " + delegTokenRenewer) + if (delegTokenRenewer == null || delegTokenRenewer.length() == 0) { + val errorMessage = "Can't get Master Kerberos principal for use as renewer" + logError(errorMessage) + throw new SparkException(errorMessage) + } + delegTokenRenewer + } + + /** + * Obtains tokens for the namenodes passed in and adds them to the credentials. + */ + def obtainTokensForNamenodes( + paths: Set[Path], + conf: Configuration, + creds: Credentials, + renewer: Option[String] = None + ): Unit = { + if (UserGroupInformation.isSecurityEnabled()) { + val delegTokenRenewer = renewer.getOrElse(getTokenRenewer(conf)) + paths.foreach { dst => + val dstFs = dst.getFileSystem(conf) + logInfo("getting token for namenode: " + dst) + dstFs.addDelegationTokens(delegTokenRenewer, creds) + } + } + } + + private[spark] override def startExecutorDelegationTokenRenewer(sparkConf: SparkConf): Unit = { + tokenRenewer = Some(new ExecutorDelegationTokenUpdater(sparkConf, conf)) + tokenRenewer.get.updateCredentialsIfRequired() + } + + private[spark] override def stopExecutorDelegationTokenRenewer(): Unit = { + tokenRenewer.foreach(_.stop()) + } + } object YarnSparkHadoopUtil { @@ -100,6 +154,14 @@ object YarnSparkHadoopUtil { // request types (like map/reduce in hadoop for example) val RM_REQUEST_PRIORITY = Priority.newInstance(1) + def get: YarnSparkHadoopUtil = { + val yarnMode = java.lang.Boolean.valueOf( + System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE"))) + if (!yarnMode) { + throw new SparkException("YarnSparkHadoopUtil is not available in non-YARN mode!") + } + SparkHadoopUtil.get.asInstanceOf[YarnSparkHadoopUtil] + } /** * Add a path variable to the given environment map. * If the map already contains this key, append the value to the existing value instead. @@ -212,3 +274,4 @@ object YarnSparkHadoopUtil { classPathSeparatorField.get(null).asInstanceOf[String] } } + diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index a51c2005cb472..508819e242a26 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -151,57 +151,6 @@ class ClientSuite extends FunSuite with Matchers with BeforeAndAfterAll { } } - test("check access nns empty") { - val sparkConf = new SparkConf() - sparkConf.set("spark.yarn.access.namenodes", "") - val nns = Client.getNameNodesToAccess(sparkConf) - nns should be(Set()) - } - - test("check access nns unset") { - val sparkConf = new SparkConf() - val nns = Client.getNameNodesToAccess(sparkConf) - nns should be(Set()) - } - - test("check access nns") { - val sparkConf = new SparkConf() - sparkConf.set("spark.yarn.access.namenodes", "hdfs://nn1:8032") - val nns = Client.getNameNodesToAccess(sparkConf) - nns should be(Set(new Path("hdfs://nn1:8032"))) - } - - test("check access nns space") { - val sparkConf = new SparkConf() - sparkConf.set("spark.yarn.access.namenodes", "hdfs://nn1:8032, ") - val nns = Client.getNameNodesToAccess(sparkConf) - nns should be(Set(new Path("hdfs://nn1:8032"))) - } - - test("check access two nns") { - val sparkConf = new SparkConf() - sparkConf.set("spark.yarn.access.namenodes", "hdfs://nn1:8032,hdfs://nn2:8032") - val nns = Client.getNameNodesToAccess(sparkConf) - nns should be(Set(new Path("hdfs://nn1:8032"), new Path("hdfs://nn2:8032"))) - } - - test("check token renewer") { - val hadoopConf = new Configuration() - hadoopConf.set("yarn.resourcemanager.address", "myrm:8033") - hadoopConf.set("yarn.resourcemanager.principal", "yarn/myrm:8032@SPARKTEST.COM") - val renewer = Client.getTokenRenewer(hadoopConf) - renewer should be ("yarn/myrm:8032@SPARKTEST.COM") - } - - test("check token renewer default") { - val hadoopConf = new Configuration() - val caught = - intercept[SparkException] { - Client.getTokenRenewer(hadoopConf) - } - assert(caught.getMessage === "Can't get Master Kerberos principal for use as renewer") - } - object Fixtures { val knownDefYarnAppCP: Seq[String] = diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 3877da4120e7c..d3c606e0ed998 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -86,6 +86,7 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit tempDir = Utils.createTempDir() logConfDir = new File(tempDir, "log4j") logConfDir.mkdir() + System.setProperty("SPARK_YARN_MODE", "true") val logConfFile = new File(logConfDir, "log4j.properties") Files.write(LOG4J_CONF, logConfFile, UTF_8) @@ -128,6 +129,7 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit override def afterAll() { yarnCluster.stop() + System.clearProperty("SPARK_YARN_MODE") super.afterAll() } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala index 9395316b71ff4..e10b985c3c236 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala @@ -20,6 +20,8 @@ package org.apache.spark.deploy.yarn import java.io.{File, IOException} import com.google.common.io.{ByteStreams, Files} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.yarn.conf.YarnConfiguration @@ -27,7 +29,7 @@ import org.scalatest.{FunSuite, Matchers} import org.apache.hadoop.yarn.api.records.ApplicationAccessType -import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.util.Utils @@ -173,4 +175,62 @@ class YarnSparkHadoopUtilSuite extends FunSuite with Matchers with Logging { YarnSparkHadoopUtil.getClassPathSeparator() should be (":") } } + + test("check access nns empty") { + val sparkConf = new SparkConf() + val util = new YarnSparkHadoopUtil + sparkConf.set("spark.yarn.access.namenodes", "") + val nns = util.getNameNodesToAccess(sparkConf) + nns should be(Set()) + } + + test("check access nns unset") { + val sparkConf = new SparkConf() + val util = new YarnSparkHadoopUtil + val nns = util.getNameNodesToAccess(sparkConf) + nns should be(Set()) + } + + test("check access nns") { + val sparkConf = new SparkConf() + sparkConf.set("spark.yarn.access.namenodes", "hdfs://nn1:8032") + val util = new YarnSparkHadoopUtil + val nns = util.getNameNodesToAccess(sparkConf) + nns should be(Set(new Path("hdfs://nn1:8032"))) + } + + test("check access nns space") { + val sparkConf = new SparkConf() + sparkConf.set("spark.yarn.access.namenodes", "hdfs://nn1:8032, ") + val util = new YarnSparkHadoopUtil + val nns = util.getNameNodesToAccess(sparkConf) + nns should be(Set(new Path("hdfs://nn1:8032"))) + } + + test("check access two nns") { + val sparkConf = new SparkConf() + sparkConf.set("spark.yarn.access.namenodes", "hdfs://nn1:8032,hdfs://nn2:8032") + val util = new YarnSparkHadoopUtil + val nns = util.getNameNodesToAccess(sparkConf) + nns should be(Set(new Path("hdfs://nn1:8032"), new Path("hdfs://nn2:8032"))) + } + + test("check token renewer") { + val hadoopConf = new Configuration() + hadoopConf.set("yarn.resourcemanager.address", "myrm:8033") + hadoopConf.set("yarn.resourcemanager.principal", "yarn/myrm:8032@SPARKTEST.COM") + val util = new YarnSparkHadoopUtil + val renewer = util.getTokenRenewer(hadoopConf) + renewer should be ("yarn/myrm:8032@SPARKTEST.COM") + } + + test("check token renewer default") { + val hadoopConf = new Configuration() + val util = new YarnSparkHadoopUtil + val caught = + intercept[SparkException] { + util.getTokenRenewer(hadoopConf) + } + assert(caught.getMessage === "Can't get Master Kerberos principal for use as renewer") + } } From 5c1fabafabdb87de7a92acbefbf294b24d0713fc Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Fri, 1 May 2015 14:42:58 -0700 Subject: [PATCH 11/91] Ignore flakey test in SparkSubmitUtilsSuite --- .../scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala index 1b2b699cb11e6..2df2597e058cd 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala @@ -117,7 +117,7 @@ class SparkSubmitUtilsSuite extends FunSuite with BeforeAndAfterAll { } } - test("neglects Spark and Spark's dependencies") { + ignore("neglects Spark and Spark's dependencies") { val components = Seq("bagel_", "catalyst_", "core_", "graphx_", "hive_", "mllib_", "repl_", "sql_", "streaming_", "yarn_", "network-common_", "network-shuffle_", "network-yarn_") From 41c6a44b1a6ae5c70a8e8dc82d0062de9bdee5b3 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Fri, 1 May 2015 16:47:00 -0700 Subject: [PATCH 12/91] [SPARK-7312][SQL] SPARK-6913 broke jdk6 build JIRA: https://issues.apache.org/jira/browse/SPARK-7312 Author: Yin Huai Closes #5847 from yhuai/jdbcJava6 and squashes the following commits: 68433a2 [Yin Huai] compile with Java 6 --- .../src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala index ae9af1eabe68e..3a6c2c1e9101f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import java.sql.{Connection, Driver, DriverManager, DriverPropertyInfo, PreparedStatement} +import java.sql.{Connection, Driver, DriverManager, DriverPropertyInfo, PreparedStatement, SQLFeatureNotSupportedException} import java.util.Properties import scala.collection.mutable @@ -195,7 +195,9 @@ package object jdbc { override def getMinorVersion: Int = wrapped.getMinorVersion - override def getParentLogger: java.util.logging.Logger = wrapped.getParentLogger + def getParentLogger: java.util.logging.Logger = + throw new SQLFeatureNotSupportedException( + s"${this.getClass().getName}.getParentLogger is not yet implemented.") override def connect(url: String, info: Properties): Connection = wrapped.connect(url, info) From e6fb37712eb1762d8184edc897bf2d468db8d254 Mon Sep 17 00:00:00 2001 From: "Rajendra Gokhale (rvgcentos)" Date: Fri, 1 May 2015 17:01:36 -0700 Subject: [PATCH 13/91] [SPARK-7304] [BUILD] Include $@ in call to mvn consistently in make-distribution.sh Adding the $ allows the caller of this script to supply additional arguments to the mvn command and is consistent with how mvn is being invoked elsewhere in the scripts Author: Rajendra Gokhale (rvgcentos) Closes #5846 from palamau/master and squashes the following commits: e5f2adb [Rajendra Gokhale (rvgcentos)] Add $@ in call to mvn consistently in make-distribution.sh --- make-distribution.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/make-distribution.sh b/make-distribution.sh index cb65932b4abc0..92177e19fe6be 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -26,6 +26,7 @@ set -o pipefail set -e +set -x # Figure out where the Spark framework is installed SPARK_HOME="$(cd "`dirname "$0"`"; pwd)" @@ -126,7 +127,7 @@ if [ ! $(command -v "$MVN") ] ; then exit -1; fi -VERSION=$("$MVN" help:evaluate -Dexpression=project.version 2>/dev/null | grep -v "INFO" | tail -n 1) +VERSION=$("$MVN" help:evaluate -Dexpression=project.version $@ 2>/dev/null | grep -v "INFO" | tail -n 1) SCALA_VERSION=$("$MVN" help:evaluate -Dexpression=scala.binary.version $@ 2>/dev/null\ | grep -v "INFO"\ | tail -n 1) From 98e7045805282988da907793a844fa53d4c293c9 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Fri, 1 May 2015 19:39:30 -0500 Subject: [PATCH 14/91] [SPARK-6999] [SQL] Remove the infinite recursive method (useless) Remove the method, since it causes infinite recursive calls. And seems it's a dummy method, since we have the API: `def createDataFrame(rowRDD: JavaRDD[Row], schema: StructType): DataFrame` Author: Cheng Hao Closes #5804 from chenghao-intel/spark_6999 and squashes the following commits: 63220a8 [Cheng Hao] remove the infinite recursive method (useless) --- .../scala/org/apache/spark/sql/SQLContext.scala | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index bd4a55fa132fb..5116fcefd4bf2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -428,20 +428,6 @@ class SQLContext(@transient val sparkContext: SparkContext) createDataFrame(rowRDD.rdd, schema) } - /** - * Creates a [[DataFrame]] from an [[JavaRDD]] containing [[Row]]s by applying - * a seq of names of columns to this RDD, the data type for each column will - * be inferred by the first row. - * - * @param rowRDD an JavaRDD of Row - * @param columns names for each column - * @return DataFrame - * @group dataframes - */ - def createDataFrame(rowRDD: JavaRDD[Row], columns: java.util.List[String]): DataFrame = { - createDataFrame(rowRDD.rdd, columns.toSeq) - } - /** * Applies a schema to an RDD of Java Beans. * From ebc25a4ddfe07a67668217cec59893bc3b8cf730 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 1 May 2015 17:41:55 -0700 Subject: [PATCH 15/91] [SPARK-7309] [CORE] [STREAMING] Shutdown the thread pools in ReceivedBlockHandler and DAGScheduler Shutdown the thread pools in ReceivedBlockHandler and DAGScheduler when stopping them. Author: zsxwing Closes #5845 from zsxwing/SPARK-7309 and squashes the following commits: 6c004fd [zsxwing] Shutdown the thread pools in ReceivedBlockHandler and DAGScheduler --- .../src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 1 + .../apache/spark/streaming/receiver/ReceivedBlockHandler.scala | 1 + 2 files changed, 2 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 05b8ab0d0a1f9..5d812918a13d1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1399,6 +1399,7 @@ class DAGScheduler( def stop() { logInfo("Stopping DAGScheduler") + messageScheduler.shutdownNow() eventProcessLoop.stop() taskScheduler.stop() } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala index 4b3d9ee4b0090..651b534ac1900 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala @@ -190,6 +190,7 @@ private[streaming] class WriteAheadLogBasedBlockHandler( def stop() { writeAheadLog.close() + executionContext.shutdown() } } From b88c275e6ef6b17cd34d1c2c780b8959b41222c0 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Fri, 1 May 2015 17:46:06 -0700 Subject: [PATCH 16/91] [SPARK-7112][Streaming][WIP] Add a InputInfoTracker to track all the input streams Author: jerryshao Author: Saisai Shao Closes #5680 from jerryshao/SPARK-7111 and squashes the following commits: 339f854 [Saisai Shao] Add an end-to-end test 812bcaf [jerryshao] Continue address the comments abd0036 [jerryshao] Address the comments 727264e [jerryshao] Fix comment typo 6682bef [jerryshao] Fix compile issue 8325787 [jerryshao] Fix rebase issue 17fa251 [jerryshao] Refactor to build InputInfoTracker ee1b536 [jerryshao] Add DirectStreamTracker to track the direct streams --- .../spark/streaming/StreamingContext.scala | 4 +- .../streaming/dstream/InputDStream.scala | 3 + .../dstream/ReceiverInputDStream.scala | 9 ++- .../spark/streaming/scheduler/BatchInfo.scala | 8 +- .../scheduler/InputInfoTracker.scala | 62 +++++++++++++++ .../streaming/scheduler/JobGenerator.scala | 8 +- .../streaming/scheduler/JobScheduler.scala | 4 + .../spark/streaming/scheduler/JobSet.scala | 4 +- .../spark/streaming/ui/BatchUIData.scala | 2 +- .../ui/StreamingJobProgressListener.scala | 31 +++++--- .../spark/streaming/ui/StreamingPage.scala | 4 +- .../spark/streaming/InputStreamsSuite.scala | 33 +++++++- .../streaming/StreamingListenerSuite.scala | 15 ++++ .../spark/streaming/TestSuiteBase.scala | 8 +- .../scheduler/InputInfoTrackerSuite.scala | 79 +++++++++++++++++++ .../StreamingJobProgressListenerSuite.scala | 19 ++--- 16 files changed, 247 insertions(+), 46 deletions(-) create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 90c8b47aebce0..117cb59fb61c9 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -159,7 +159,7 @@ class StreamingContext private[streaming] ( } } - private val nextReceiverInputStreamId = new AtomicInteger(0) + private val nextInputStreamId = new AtomicInteger(0) private[streaming] var checkpointDir: String = { if (isCheckpointPresent) { @@ -241,7 +241,7 @@ class StreamingContext private[streaming] ( if (isCheckpointPresent) cp_ else null } - private[streaming] def getNewReceiverStreamId() = nextReceiverInputStreamId.getAndIncrement() + private[streaming] def getNewInputStreamId() = nextInputStreamId.getAndIncrement() /** * Create an input stream with any arbitrary user implemented receiver. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala index e652702e213ef..e4ad4b509d8d8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala @@ -41,6 +41,9 @@ abstract class InputDStream[T: ClassTag] (@transient ssc_ : StreamingContext) ssc.graph.addInputStream(this) + /** This is an unique identifier for the input stream. */ + val id = ssc.getNewInputStreamId() + /** * Checks whether the 'time' is valid wrt slideDuration for generating RDD. * Additionally it also ensures valid times are in strictly increasing order. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index 4c7fd2c57c006..ba88416ef4009 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -24,7 +24,7 @@ import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.streaming._ import org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD import org.apache.spark.streaming.receiver.{Receiver, WriteAheadLogBasedStoreResult} -import org.apache.spark.streaming.scheduler.ReceivedBlockInfo +import org.apache.spark.streaming.scheduler.{InputInfo, ReceivedBlockInfo} /** * Abstract class for defining any [[org.apache.spark.streaming.dstream.InputDStream]] @@ -39,9 +39,6 @@ import org.apache.spark.streaming.scheduler.ReceivedBlockInfo abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingContext) extends InputDStream[T](ssc_) { - /** This is an unique identifier for the receiver input stream. */ - val id = ssc.getNewReceiverStreamId() - /** * Gets the receiver object that will be sent to the worker nodes * to receive data. This method needs to defined by any specific implementation @@ -72,6 +69,10 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont val blockStoreResults = blockInfos.map { _.blockStoreResult } val blockIds = blockStoreResults.map { _.blockId.asInstanceOf[BlockId] }.toArray + // Register the input blocks information into InputInfoTracker + val inputInfo = InputInfo(id, blockInfos.map(_.numRecords).sum) + ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo) + // Check whether all the results are of the same type val resultTypes = blockStoreResults.map { _.getClass }.distinct if (resultTypes.size > 1) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala index 92dc113f397ca..5b9bfbf9b01e3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala @@ -24,6 +24,7 @@ import org.apache.spark.streaming.Time * :: DeveloperApi :: * Class having information on completed batches. * @param batchTime Time of the batch + * @param streamIdToNumRecords A map of input stream id to record number * @param submissionTime Clock time of when jobs of this batch was submitted to * the streaming scheduler queue * @param processingStartTime Clock time of when the first job of this batch started processing @@ -32,7 +33,7 @@ import org.apache.spark.streaming.Time @DeveloperApi case class BatchInfo( batchTime: Time, - receivedBlockInfo: Map[Int, Array[ReceivedBlockInfo]], + streamIdToNumRecords: Map[Int, Long], submissionTime: Long, processingStartTime: Option[Long], processingEndTime: Option[Long] @@ -58,4 +59,9 @@ case class BatchInfo( */ def totalDelay: Option[Long] = schedulingDelay.zip(processingDelay) .map(x => x._1 + x._2).headOption + + /** + * The number of recorders received by the receivers in this batch. + */ + def numRecords: Long = streamIdToNumRecords.values.sum } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala new file mode 100644 index 0000000000000..a72efccf2f994 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import scala.collection.mutable + +import org.apache.spark.Logging +import org.apache.spark.streaming.{Time, StreamingContext} + +/** To track the information of input stream at specified batch time. */ +private[streaming] case class InputInfo(inputStreamId: Int, numRecords: Long) + +/** + * This class manages all the input streams as well as their input data statistics. The information + * will be exposed through StreamingListener for monitoring. + */ +private[streaming] class InputInfoTracker(ssc: StreamingContext) extends Logging { + + // Map to track all the InputInfo related to specific batch time and input stream. + private val batchTimeToInputInfos = new mutable.HashMap[Time, mutable.HashMap[Int, InputInfo]] + + /** Report the input information with batch time to the tracker */ + def reportInfo(batchTime: Time, inputInfo: InputInfo): Unit = synchronized { + val inputInfos = batchTimeToInputInfos.getOrElseUpdate(batchTime, + new mutable.HashMap[Int, InputInfo]()) + + if (inputInfos.contains(inputInfo.inputStreamId)) { + throw new IllegalStateException(s"Input stream ${inputInfo.inputStreamId}} for batch" + + s"$batchTime is already added into InputInfoTracker, this is a illegal state") + } + inputInfos += ((inputInfo.inputStreamId, inputInfo)) + } + + /** Get the all the input stream's information of specified batch time */ + def getInfo(batchTime: Time): Map[Int, InputInfo] = synchronized { + val inputInfos = batchTimeToInputInfos.get(batchTime) + // Convert mutable HashMap to immutable Map for the caller + inputInfos.map(_.toMap).getOrElse(Map[Int, InputInfo]()) + } + + /** Cleanup the tracked input information older than threshold batch time */ + def cleanup(batchThreshTime: Time): Unit = synchronized { + val timesToCleanup = batchTimeToInputInfos.keys.filter(_ < batchThreshTime) + logInfo(s"remove old batch metadata: ${timesToCleanup.mkString(" ")}") + batchTimeToInputInfos --= timesToCleanup + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 2467d50839add..9f93d6cbc3c20 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -243,9 +243,9 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { graph.generateJobs(time) // generate jobs using allocated block } match { case Success(jobs) => - val receivedBlockInfos = - jobScheduler.receiverTracker.getBlocksOfBatch(time).mapValues { _.toArray } - jobScheduler.submitJobSet(JobSet(time, jobs, receivedBlockInfos)) + val streamIdToInputInfos = jobScheduler.inputInfoTracker.getInfo(time) + val streamIdToNumRecords = streamIdToInputInfos.mapValues(_.numRecords) + jobScheduler.submitJobSet(JobSet(time, jobs, streamIdToNumRecords)) case Failure(e) => jobScheduler.reportError("Error generating jobs for time " + time, e) } @@ -266,6 +266,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { // checkpointing of this batch to complete. val maxRememberDuration = graph.getMaxInputStreamRememberDuration() jobScheduler.receiverTracker.cleanupOldBlocksAndBatches(time - maxRememberDuration) + jobScheduler.inputInfoTracker.cleanup(time - maxRememberDuration) markBatchFullyProcessed(time) } } @@ -278,6 +279,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { // been saved to checkpoints, so its safe to delete block metadata and data WAL files val maxRememberDuration = graph.getMaxInputStreamRememberDuration() jobScheduler.receiverTracker.cleanupOldBlocksAndBatches(time - maxRememberDuration) + jobScheduler.inputInfoTracker.cleanup(time - maxRememberDuration) markBatchFullyProcessed(time) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index c7a2c1141a128..1d1ddaaccf217 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -50,6 +50,9 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { // These two are created only when scheduler starts. // eventLoop not being null means the scheduler has been started and not stopped var receiverTracker: ReceiverTracker = null + // A tracker to track all the input stream information as well as processed record number + var inputInfoTracker: InputInfoTracker = null + private var eventLoop: EventLoop[JobSchedulerEvent] = null def start(): Unit = synchronized { @@ -65,6 +68,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { listenerBus.start(ssc.sparkContext) receiverTracker = new ReceiverTracker(ssc) + inputInfoTracker = new InputInfoTracker(ssc) receiverTracker.start() jobGenerator.start() logInfo("Started JobScheduler") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala index 24b3794236ea5..e6be63b2ddbdc 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala @@ -28,7 +28,7 @@ private[streaming] case class JobSet( time: Time, jobs: Seq[Job], - receivedBlockInfo: Map[Int, Array[ReceivedBlockInfo]] = Map.empty) { + streamIdToNumRecords: Map[Int, Long] = Map.empty) { private val incompleteJobs = new HashSet[Job]() private val submissionTime = System.currentTimeMillis() // when this jobset was submitted @@ -64,7 +64,7 @@ case class JobSet( def toBatchInfo: BatchInfo = { new BatchInfo( time, - receivedBlockInfo, + streamIdToNumRecords, submissionTime, if (processingStartTime >= 0 ) Some(processingStartTime) else None, if (processingEndTime >= 0 ) Some(processingEndTime) else None diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala index f45c291b7c0fe..99e10d2b0be12 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala @@ -66,7 +66,7 @@ private[ui] object BatchUIData { def apply(batchInfo: BatchInfo): BatchUIData = { new BatchUIData( batchInfo.batchTime, - batchInfo.receivedBlockInfo.mapValues(_.map(_.numRecords).sum), + batchInfo.streamIdToNumRecords, batchInfo.submissionTime, batchInfo.processingStartTime, batchInfo.processingEndTime diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala index 34b55717a1db2..d2729fa70d6d2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala @@ -188,25 +188,26 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) } def receivedRecordsDistributions: Map[Int, Option[Distribution]] = synchronized { - val latestBatches = retainedBatches.reverse.take(batchUIDataLimit) - (0 until numReceivers).map { receiverId => - val recordsOfParticularReceiver = latestBatches.map { batch => - // calculate records per second for each batch - batch.receiverNumRecords.get(receiverId).sum.toDouble * 1000 / batchDuration - } - val distributionOption = Distribution(recordsOfParticularReceiver) - (receiverId, distributionOption) + val latestBatchInfos = retainedBatches.reverse.take(batchUIDataLimit) + val latestReceiverNumRecords = latestBatchInfos.map(_.receiverNumRecords) + val streamIds = ssc.graph.getInputStreams().map(_.id) + streamIds.map { id => + val recordsOfParticularReceiver = + latestReceiverNumRecords.map(v => v.getOrElse(id, 0L).toDouble * 1000 / batchDuration) + val distribution = Distribution(recordsOfParticularReceiver) + (id, distribution) }.toMap } def lastReceivedBatchRecords: Map[Int, Long] = synchronized { - val lastReceivedBlockInfoOption = lastReceivedBatch.map(_.receiverNumRecords) - lastReceivedBlockInfoOption.map { lastReceivedBlockInfo => - (0 until numReceivers).map { receiverId => - (receiverId, lastReceivedBlockInfo.getOrElse(receiverId, 0L)) + val lastReceiverNumRecords = lastReceivedBatch.map(_.receiverNumRecords) + val streamIds = ssc.graph.getInputStreams().map(_.id) + lastReceiverNumRecords.map { receiverNumRecords => + streamIds.map { id => + (id, receiverNumRecords.getOrElse(id, 0L)) }.toMap }.getOrElse { - (0 until numReceivers).map(receiverId => (receiverId, 0L)).toMap + streamIds.map(id => (id, 0L)).toMap } } @@ -214,6 +215,10 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) receiverInfos.get(receiverId) } + def receiverIds(): Iterable[Int] = synchronized { + receiverInfos.keys + } + def lastCompletedBatch: Option[BatchUIData] = synchronized { completedBatchUIData.sortBy(_.batchTime)(Time.ordering).lastOption } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index 07fa285642eec..db37ae815bdf5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -95,7 +95,7 @@ private[ui] class StreamingPage(parent: StreamingTab) "Maximum rate\n[events/sec]", "Last Error" ) - val dataRows = (0 until listener.numReceivers).map { receiverId => + val dataRows = listener.receiverIds().map { receiverId => val receiverInfo = listener.receiverInfo(receiverId) val receiverName = receiverInfo.map(_.name).getOrElse(s"Receiver-$receiverId") val receiverActive = receiverInfo.map { info => @@ -114,7 +114,7 @@ private[ui] class StreamingPage(parent: StreamingTab) }.getOrElse(emptyCell) Seq(receiverName, receiverActive, receiverLocation, receiverLastBatchRecords) ++ receivedRecordStats ++ Seq(receiverLastError) - } + }.toSeq Some(listingTable(headerRow, dataRows)) } else { None diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index e6ac4975c5e68..eb136758249d5 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -27,17 +27,18 @@ import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer, SynchronizedQu import scala.language.postfixOps import com.google.common.io.Files +import org.apache.hadoop.io.{Text, LongWritable} +import org.apache.hadoop.mapreduce.lib.input.TextInputFormat +import org.apache.hadoop.fs.Path import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually._ import org.apache.spark.Logging +import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.util.{ManualClock, Utils} +import org.apache.spark.streaming.dstream.{InputDStream, ReceiverInputDStream} import org.apache.spark.streaming.receiver.Receiver -import org.apache.spark.rdd.RDD -import org.apache.hadoop.io.{Text, LongWritable} -import org.apache.hadoop.mapreduce.lib.input.TextInputFormat -import org.apache.hadoop.fs.Path class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { @@ -278,6 +279,30 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } } + test("test track the number of input stream") { + val ssc = new StreamingContext(conf, batchDuration) + + class TestInputDStream extends InputDStream[String](ssc) { + def start() { } + def stop() { } + def compute(validTime: Time): Option[RDD[String]] = None + } + + class TestReceiverInputDStream extends ReceiverInputDStream[String](ssc) { + def getReceiver: Receiver[String] = null + } + + // Register input streams + val receiverInputStreams = Array(new TestReceiverInputDStream, new TestReceiverInputDStream) + val inputStreams = Array(new TestInputDStream, new TestInputDStream, new TestInputDStream) + + assert(ssc.graph.getInputStreams().length == receiverInputStreams.length + inputStreams.length) + assert(ssc.graph.getReceiverInputStreams().length == receiverInputStreams.length) + assert(ssc.graph.getReceiverInputStreams() === receiverInputStreams) + assert(ssc.graph.getInputStreams().map(_.id) === Array.tabulate(5)(i => i)) + assert(receiverInputStreams.map(_.id) === Array(0, 1)) + } + def testFileStream(newFilesOnly: Boolean) { val testDir: File = null try { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index 9020be166acf0..312cce408cfe7 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -57,6 +57,11 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { info.totalDelay should be (None) }) + batchInfosSubmitted.foreach { info => + info.numRecords should be (1L) + info.streamIdToNumRecords should be (Map(0 -> 1L)) + } + isInIncreasingOrder(batchInfosSubmitted.map(_.submissionTime)) should be (true) // SPARK-6766: processingStartTime of batch info should not be None when starting @@ -70,6 +75,11 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { info.totalDelay should be (None) }) + batchInfosStarted.foreach { info => + info.numRecords should be (1L) + info.streamIdToNumRecords should be (Map(0 -> 1L)) + } + isInIncreasingOrder(batchInfosStarted.map(_.submissionTime)) should be (true) isInIncreasingOrder(batchInfosStarted.map(_.processingStartTime.get)) should be (true) @@ -86,6 +96,11 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { info.totalDelay.get should be >= 0L }) + batchInfosCompleted.foreach { info => + info.numRecords should be (1L) + info.streamIdToNumRecords should be (Map(0 -> 1L)) + } + isInIncreasingOrder(batchInfosCompleted.map(_.submissionTime)) should be (true) isInIncreasingOrder(batchInfosCompleted.map(_.processingStartTime.get)) should be (true) isInIncreasingOrder(batchInfosCompleted.map(_.processingEndTime.get)) should be (true) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index c3cae8aeb6d15..2ba86aeaf9d51 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -29,10 +29,10 @@ import org.scalatest.time.{Span, Seconds => ScalaTestSeconds} import org.scalatest.concurrent.Eventually.timeout import org.scalatest.concurrent.PatienceConfiguration -import org.apache.spark.streaming.dstream.{DStream, InputDStream, ForEachDStream} -import org.apache.spark.streaming.scheduler.{StreamingListenerBatchStarted, StreamingListenerBatchCompleted, StreamingListener} import org.apache.spark.{SparkConf, Logging} import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.dstream.{DStream, InputDStream, ForEachDStream} +import org.apache.spark.streaming.scheduler._ import org.apache.spark.util.{ManualClock, Utils} /** @@ -57,6 +57,10 @@ class TestInputStream[T: ClassTag](ssc_ : StreamingContext, input: Seq[Seq[T]], return None } + // Report the input data's information to InputInfoTracker for testing + val inputInfo = InputInfo(id, selectedInput.length.toLong) + ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo) + val rdd = ssc.sc.makeRDD(selectedInput, numPartitions) logInfo("Created RDD " + rdd.id + " with " + selectedInput) Some(rdd) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala new file mode 100644 index 0000000000000..5478b41845943 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import org.scalatest.{BeforeAndAfter, FunSuite} + +import org.apache.spark.SparkConf +import org.apache.spark.streaming.{Time, Duration, StreamingContext} + +class InputInfoTrackerSuite extends FunSuite with BeforeAndAfter { + + private var ssc: StreamingContext = _ + + before { + val conf = new SparkConf().setMaster("local[2]").setAppName("DirectStreamTacker") + if (ssc == null) { + ssc = new StreamingContext(conf, Duration(1000)) + } + } + + after { + if (ssc != null) { + ssc.stop() + ssc = null + } + } + + test("test report and get InputInfo from InputInfoTracker") { + val inputInfoTracker = new InputInfoTracker(ssc) + + val streamId1 = 0 + val streamId2 = 1 + val time = Time(0L) + val inputInfo1 = InputInfo(streamId1, 100L) + val inputInfo2 = InputInfo(streamId2, 300L) + inputInfoTracker.reportInfo(time, inputInfo1) + inputInfoTracker.reportInfo(time, inputInfo2) + + val batchTimeToInputInfos = inputInfoTracker.getInfo(time) + assert(batchTimeToInputInfos.size == 2) + assert(batchTimeToInputInfos.keys === Set(streamId1, streamId2)) + assert(batchTimeToInputInfos(streamId1) === inputInfo1) + assert(batchTimeToInputInfos(streamId2) === inputInfo2) + assert(inputInfoTracker.getInfo(time)(streamId1) === inputInfo1) + } + + test("test cleanup InputInfo from InputInfoTracker") { + val inputInfoTracker = new InputInfoTracker(ssc) + + val streamId1 = 0 + val inputInfo1 = InputInfo(streamId1, 100L) + val inputInfo2 = InputInfo(streamId1, 300L) + inputInfoTracker.reportInfo(Time(0), inputInfo1) + inputInfoTracker.reportInfo(Time(1), inputInfo2) + + inputInfoTracker.cleanup(Time(0)) + assert(inputInfoTracker.getInfo(Time(0))(streamId1) === inputInfo1) + assert(inputInfoTracker.getInfo(Time(1))(streamId1) === inputInfo2) + + inputInfoTracker.cleanup(Time(1)) + assert(inputInfoTracker.getInfo(Time(0)).get(streamId1) === None) + assert(inputInfoTracker.getInfo(Time(1))(streamId1) === inputInfo2) + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala index fa89536de4054..e874536e63518 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala @@ -49,13 +49,10 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { val ssc = setupStreams(input, operation) val listener = new StreamingJobProgressListener(ssc) - val receivedBlockInfo = Map( - 0 -> Array(ReceivedBlockInfo(0, 100, null), ReceivedBlockInfo(0, 200, null)), - 1 -> Array(ReceivedBlockInfo(1, 300, null)) - ) + val streamIdToNumRecords = Map(0 -> 300L, 1 -> 300L) // onBatchSubmitted - val batchInfoSubmitted = BatchInfo(Time(1000), receivedBlockInfo, 1000, None, None) + val batchInfoSubmitted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, None, None) listener.onBatchSubmitted(StreamingListenerBatchSubmitted(batchInfoSubmitted)) listener.waitingBatches should be (List(BatchUIData(batchInfoSubmitted))) listener.runningBatches should be (Nil) @@ -67,7 +64,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { listener.numTotalReceivedRecords should be (0) // onBatchStarted - val batchInfoStarted = BatchInfo(Time(1000), receivedBlockInfo, 1000, Some(2000), None) + val batchInfoStarted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, Some(2000), None) listener.onBatchStarted(StreamingListenerBatchStarted(batchInfoStarted)) listener.waitingBatches should be (Nil) listener.runningBatches should be (List(BatchUIData(batchInfoStarted))) @@ -106,7 +103,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { OutputOpIdAndSparkJobId(1, 1)) // onBatchCompleted - val batchInfoCompleted = BatchInfo(Time(1000), receivedBlockInfo, 1000, Some(2000), None) + val batchInfoCompleted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, Some(2000), None) listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) listener.waitingBatches should be (Nil) listener.runningBatches should be (Nil) @@ -144,11 +141,9 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { val limit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 100) val listener = new StreamingJobProgressListener(ssc) - val receivedBlockInfo = Map( - 0 -> Array(ReceivedBlockInfo(0, 100, null), ReceivedBlockInfo(0, 200, null)), - 1 -> Array(ReceivedBlockInfo(1, 300, null)) - ) - val batchInfoCompleted = BatchInfo(Time(1000), receivedBlockInfo, 1000, Some(2000), None) + val streamIdToNumRecords = Map(0 -> 300L, 1 -> 300L) + + val batchInfoCompleted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, Some(2000), None) for(_ <- 0 until (limit + 10)) { listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) From 4786484076865c56c3fc23c49819b9be2933d287 Mon Sep 17 00:00:00 2001 From: cody koeninger Date: Fri, 1 May 2015 17:54:56 -0700 Subject: [PATCH 17/91] [SPARK-2808][Streaming][Kafka] update kafka to 0.8.2 i don't think this should be merged until after 1.3.0 is final Author: cody koeninger Author: Helena Edelson Closes #4537 from koeninger/wip-2808-kafka-0.8.2-upgrade and squashes the following commits: 803aa2c [cody koeninger] [SPARK-2808][Streaming][Kafka] code cleanup per TD e6dfaf6 [cody koeninger] [SPARK-2808][Streaming][Kafka] pointless whitespace change to trigger jenkins again 1770abc [cody koeninger] [SPARK-2808][Streaming][Kafka] make waitUntilLeaderOffset easier to call, call it from python tests as well d4267e9 [cody koeninger] [SPARK-2808][Streaming][Kafka] fix stderr redirect in python test script 30d991d [cody koeninger] [SPARK-2808][Streaming][Kafka] remove stderr prints since it breaks python 3 syntax 1d896e2 [cody koeninger] [SPARK-2808][Streaming][Kafka] add even even more logging to python test 4c4557f [cody koeninger] [SPARK-2808][Streaming][Kafka] add even more logging to python test 115aeee [cody koeninger] Merge branch 'master' into wip-2808-kafka-0.8.2-upgrade 2712649 [cody koeninger] [SPARK-2808][Streaming][Kafka] add more logging to python test, see why its timing out in jenkins 2b92d3f [cody koeninger] [SPARK-2808][Streaming][Kafka] wait for leader offsets in the java test as well 3824ce3 [cody koeninger] [SPARK-2808][Streaming][Kafka] naming / comments per tdas 61b3464 [cody koeninger] [SPARK-2808][Streaming][Kafka] delay for second send in boundary condition test af6f3ec [cody koeninger] [SPARK-2808][Streaming][Kafka] delay test until latest leader offset matches expected value 9edab4c [cody koeninger] [SPARK-2808][Streaming][Kafka] more shots in the dark on jenkins failing test c70ee43 [cody koeninger] [SPARK-2808][Streaming][Kafka] add more asserts to test, try to figure out why it fails on jenkins but not locally 1d10751 [cody koeninger] Merge branch 'master' into wip-2808-kafka-0.8.2-upgrade ed02d2c [cody koeninger] [SPARK-2808][Streaming][Kafka] move default argument for api version to overloaded method, for binary compat 407382e [cody koeninger] [SPARK-2808][Streaming][Kafka] update kafka to 0.8.2.1 77de6c2 [cody koeninger] Merge branch 'master' into wip-2808-kafka-0.8.2-upgrade 6953429 [cody koeninger] [SPARK-2808][Streaming][Kafka] update kafka to 0.8.2 2e67c66 [Helena Edelson] #SPARK-2808 Update to Kafka 0.8.2.0 GA from beta. d9dc2bc [Helena Edelson] Merge remote-tracking branch 'upstream/master' into wip-2808-kafka-0.8.2-upgrade e768164 [Helena Edelson] #2808 update kafka to version 0.8.2 --- external/kafka/pom.xml | 2 +- .../spark/streaming/kafka/KafkaCluster.scala | 51 +++++++++++++++---- .../streaming/kafka/KafkaTestUtils.scala | 35 +++++++++++-- .../streaming/kafka/JavaKafkaRDDSuite.java | 3 ++ .../spark/streaming/kafka/KafkaRDDSuite.scala | 32 ++++++++---- python/pyspark/streaming/tests.py | 8 +-- 6 files changed, 104 insertions(+), 27 deletions(-) diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index f695cff410a18..243ce6eaca658 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -44,7 +44,7 @@ org.apache.kafka kafka_${scala.binary.version} - 0.8.1.1 + 0.8.2.1 com.sun.jmx diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala index bd767031c1849..6cf254a7b69cb 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala @@ -20,9 +20,10 @@ package org.apache.spark.streaming.kafka import scala.util.control.NonFatal import scala.util.Random import scala.collection.mutable.ArrayBuffer +import scala.collection.JavaConverters._ import java.util.Properties import kafka.api._ -import kafka.common.{ErrorMapping, OffsetMetadataAndError, TopicAndPartition} +import kafka.common.{ErrorMapping, OffsetAndMetadata, OffsetMetadataAndError, TopicAndPartition} import kafka.consumer.{ConsumerConfig, SimpleConsumer} import org.apache.spark.SparkException @@ -220,12 +221,22 @@ class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable { // https://cwiki.apache.org/confluence/display/KAFKA/A+Guide+To+The+Kafka+Protocol#AGuideToTheKafkaProtocol-OffsetCommit/FetchAPI // scalastyle:on + // this 0 here indicates api version, in this case the original ZK backed api. + private def defaultConsumerApiVersion: Short = 0 + /** Requires Kafka >= 0.8.1.1 */ def getConsumerOffsets( groupId: String, topicAndPartitions: Set[TopicAndPartition] + ): Either[Err, Map[TopicAndPartition, Long]] = + getConsumerOffsets(groupId, topicAndPartitions, defaultConsumerApiVersion) + + def getConsumerOffsets( + groupId: String, + topicAndPartitions: Set[TopicAndPartition], + consumerApiVersion: Short ): Either[Err, Map[TopicAndPartition, Long]] = { - getConsumerOffsetMetadata(groupId, topicAndPartitions).right.map { r => + getConsumerOffsetMetadata(groupId, topicAndPartitions, consumerApiVersion).right.map { r => r.map { kv => kv._1 -> kv._2.offset } @@ -236,9 +247,16 @@ class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable { def getConsumerOffsetMetadata( groupId: String, topicAndPartitions: Set[TopicAndPartition] + ): Either[Err, Map[TopicAndPartition, OffsetMetadataAndError]] = + getConsumerOffsetMetadata(groupId, topicAndPartitions, defaultConsumerApiVersion) + + def getConsumerOffsetMetadata( + groupId: String, + topicAndPartitions: Set[TopicAndPartition], + consumerApiVersion: Short ): Either[Err, Map[TopicAndPartition, OffsetMetadataAndError]] = { var result = Map[TopicAndPartition, OffsetMetadataAndError]() - val req = OffsetFetchRequest(groupId, topicAndPartitions.toSeq) + val req = OffsetFetchRequest(groupId, topicAndPartitions.toSeq, consumerApiVersion) val errs = new Err withBrokers(Random.shuffle(config.seedBrokers), errs) { consumer => val resp = consumer.fetchOffsets(req) @@ -266,24 +284,39 @@ class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable { def setConsumerOffsets( groupId: String, offsets: Map[TopicAndPartition, Long] + ): Either[Err, Map[TopicAndPartition, Short]] = + setConsumerOffsets(groupId, offsets, defaultConsumerApiVersion) + + def setConsumerOffsets( + groupId: String, + offsets: Map[TopicAndPartition, Long], + consumerApiVersion: Short ): Either[Err, Map[TopicAndPartition, Short]] = { - setConsumerOffsetMetadata(groupId, offsets.map { kv => - kv._1 -> OffsetMetadataAndError(kv._2) - }) + val meta = offsets.map { kv => + kv._1 -> OffsetAndMetadata(kv._2) + } + setConsumerOffsetMetadata(groupId, meta, consumerApiVersion) } /** Requires Kafka >= 0.8.1.1 */ def setConsumerOffsetMetadata( groupId: String, - metadata: Map[TopicAndPartition, OffsetMetadataAndError] + metadata: Map[TopicAndPartition, OffsetAndMetadata] + ): Either[Err, Map[TopicAndPartition, Short]] = + setConsumerOffsetMetadata(groupId, metadata, defaultConsumerApiVersion) + + def setConsumerOffsetMetadata( + groupId: String, + metadata: Map[TopicAndPartition, OffsetAndMetadata], + consumerApiVersion: Short ): Either[Err, Map[TopicAndPartition, Short]] = { var result = Map[TopicAndPartition, Short]() - val req = OffsetCommitRequest(groupId, metadata) + val req = OffsetCommitRequest(groupId, metadata, consumerApiVersion) val errs = new Err val topicAndPartitions = metadata.keySet withBrokers(Random.shuffle(config.seedBrokers), errs) { consumer => val resp = consumer.commitOffsets(req) - val respMap = resp.requestInfo + val respMap = resp.commitStatus val needed = topicAndPartitions.diff(result.keySet) needed.foreach { tp: TopicAndPartition => respMap.get(tp).foreach { err: Short => diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala index 13e9475065979..6dc4e9517d5a4 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala @@ -29,10 +29,12 @@ import scala.language.postfixOps import scala.util.control.NonFatal import kafka.admin.AdminUtils +import kafka.api.Request +import kafka.common.TopicAndPartition import kafka.producer.{KeyedMessage, Producer, ProducerConfig} import kafka.serializer.StringEncoder import kafka.server.{KafkaConfig, KafkaServer} -import kafka.utils.ZKStringSerializer +import kafka.utils.{ZKStringSerializer, ZkUtils} import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer} import org.I0Itec.zkclient.ZkClient @@ -227,12 +229,35 @@ private class KafkaTestUtils extends Logging { tryAgain(1) } - private def waitUntilMetadataIsPropagated(topic: String, partition: Int): Unit = { + /** Wait until the leader offset for the given topic/partition equals the specified offset */ + def waitUntilLeaderOffset( + topic: String, + partition: Int, + offset: Long): Unit = { eventually(Time(10000), Time(100)) { + val kc = new KafkaCluster(Map("metadata.broker.list" -> brokerAddress)) + val tp = TopicAndPartition(topic, partition) + val llo = kc.getLatestLeaderOffsets(Set(tp)).right.get.apply(tp).offset assert( - server.apis.metadataCache.containsTopicAndPartition(topic, partition), - s"Partition [$topic, $partition] metadata not propagated after timeout" - ) + llo == offset, + s"$topic $partition $offset not reached after timeout") + } + } + + private def waitUntilMetadataIsPropagated(topic: String, partition: Int): Unit = { + def isPropagated = server.apis.metadataCache.getPartitionInfo(topic, partition) match { + case Some(partitionState) => + val leaderAndInSyncReplicas = partitionState.leaderIsrAndControllerEpoch.leaderAndIsr + + ZkUtils.getLeaderForPartition(zkClient, topic, partition).isDefined && + Request.isValidBrokerId(leaderAndInSyncReplicas.leader) && + leaderAndInSyncReplicas.isr.size >= 1 + + case _ => + false + } + eventually(Time(10000), Time(100)) { + assert(isPropagated, s"Partition [$topic, $partition] metadata not propagated after timeout") } } diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java index a9dc6e50613ca..5cf379635354f 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java @@ -72,6 +72,9 @@ public void testKafkaRDD() throws InterruptedException { HashMap kafkaParams = new HashMap(); kafkaParams.put("metadata.broker.list", kafkaTestUtils.brokerAddress()); + kafkaTestUtils.waitUntilLeaderOffset(topic1, 0, topic1data.length); + kafkaTestUtils.waitUntilLeaderOffset(topic2, 0, topic2data.length); + OffsetRange[] offsetRanges = { OffsetRange.create(topic1, 0, 0, 1), OffsetRange.create(topic2, 0, 0, 1) diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala index 7d26ce50875b3..39c3fb448ff57 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala @@ -53,14 +53,15 @@ class KafkaRDDSuite extends FunSuite with BeforeAndAfterAll { } test("basic usage") { - val topic = "topicbasic" + val topic = s"topicbasic-${Random.nextInt}" kafkaTestUtils.createTopic(topic) val messages = Set("the", "quick", "brown", "fox") kafkaTestUtils.sendMessages(topic, messages.toArray) - val kafkaParams = Map("metadata.broker.list" -> kafkaTestUtils.brokerAddress, - "group.id" -> s"test-consumer-${Random.nextInt(10000)}") + "group.id" -> s"test-consumer-${Random.nextInt}") + + kafkaTestUtils.waitUntilLeaderOffset(topic, 0, messages.size) val offsetRanges = Array(OffsetRange(topic, 0, 0, messages.size)) @@ -73,27 +74,38 @@ class KafkaRDDSuite extends FunSuite with BeforeAndAfterAll { test("iterator boundary conditions") { // the idea is to find e.g. off-by-one errors between what kafka has available and the rdd - val topic = "topic1" + val topic = s"topicboundary-${Random.nextInt}" val sent = Map("a" -> 5, "b" -> 3, "c" -> 10) kafkaTestUtils.createTopic(topic) val kafkaParams = Map("metadata.broker.list" -> kafkaTestUtils.brokerAddress, - "group.id" -> s"test-consumer-${Random.nextInt(10000)}") + "group.id" -> s"test-consumer-${Random.nextInt}") val kc = new KafkaCluster(kafkaParams) // this is the "lots of messages" case kafkaTestUtils.sendMessages(topic, sent) + val sentCount = sent.values.sum + kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sentCount) + // rdd defined from leaders after sending messages, should get the number sent val rdd = getRdd(kc, Set(topic)) assert(rdd.isDefined) - assert(rdd.get.count === sent.values.sum, "didn't get all sent messages") - val ranges = rdd.get.asInstanceOf[HasOffsetRanges] - .offsetRanges.map(o => TopicAndPartition(o.topic, o.partition) -> o.untilOffset).toMap + val ranges = rdd.get.asInstanceOf[HasOffsetRanges].offsetRanges + val rangeCount = ranges.map(o => o.untilOffset - o.fromOffset).sum - kc.setConsumerOffsets(kafkaParams("group.id"), ranges) + assert(rangeCount === sentCount, "offset range didn't include all sent messages") + assert(rdd.get.count === sentCount, "didn't get all sent messages") + + val rangesMap = ranges.map(o => TopicAndPartition(o.topic, o.partition) -> o.untilOffset).toMap + + // make sure consumer offsets are committed before the next getRdd call + kc.setConsumerOffsets(kafkaParams("group.id"), rangesMap).fold( + err => throw new Exception(err.mkString("\n")), + _ => () + ) // this is the "0 messages" case val rdd2 = getRdd(kc, Set(topic)) @@ -101,6 +113,8 @@ class KafkaRDDSuite extends FunSuite with BeforeAndAfterAll { val sentOnlyOne = Map("d" -> 1) kafkaTestUtils.sendMessages(topic, sentOnlyOne) + kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sentCount + 1) + assert(rdd2.isDefined) assert(rdd2.get.count === 0, "got messages when there shouldn't be any") diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 7c06c203455d9..33ea8c9293d74 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -606,7 +606,6 @@ def _validateRddResult(self, sendData, rdd): result = {} for i in rdd.map(lambda x: x[1]).collect(): result[i] = result.get(i, 0) + 1 - self.assertEqual(sendData, result) def test_kafka_stream(self): @@ -616,6 +615,7 @@ def test_kafka_stream(self): self._kafkaTestUtils.createTopic(topic) self._kafkaTestUtils.sendMessages(topic, sendData) + self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values())) stream = KafkaUtils.createStream(self.ssc, self._kafkaTestUtils.zkAddress(), "test-streaming-consumer", {topic: 1}, @@ -631,6 +631,7 @@ def test_kafka_direct_stream(self): self._kafkaTestUtils.createTopic(topic) self._kafkaTestUtils.sendMessages(topic, sendData) + self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values())) stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams) self._validateStreamResult(sendData, stream) @@ -645,6 +646,7 @@ def test_kafka_direct_stream_from_offset(self): self._kafkaTestUtils.createTopic(topic) self._kafkaTestUtils.sendMessages(topic, sendData) + self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values())) stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams, fromOffsets) self._validateStreamResult(sendData, stream) @@ -659,7 +661,7 @@ def test_kafka_rdd(self): self._kafkaTestUtils.createTopic(topic) self._kafkaTestUtils.sendMessages(topic, sendData) - + self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values())) rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges) self._validateRddResult(sendData, rdd) @@ -675,7 +677,7 @@ def test_kafka_rdd_with_leaders(self): self._kafkaTestUtils.createTopic(topic) self._kafkaTestUtils.sendMessages(topic, sendData) - + self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values())) rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges, leaders) self._validateRddResult(sendData, rdd) From ae98eec730125c1153dcac9ea941959cc79e4f42 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 1 May 2015 18:02:10 -0700 Subject: [PATCH 18/91] [SPARK-3444] Provide an easy way to change log level Add support for changing the log level at run time through the SparkContext. Based on an earlier PR, #2433 includes CR feedback from pwendel & davies Author: Holden Karau Closes #5791 from holdenk/SPARK-3444-provide-an-easy-way-to-change-log-level-r2 and squashes the following commits: 3bf3be9 [Holden Karau] fix exception 42ba873 [Holden Karau] fix exception 9117244 [Holden Karau] Only allow valid log levels, throw exception if invalid log level. 338d7bf [Holden Karau] rename setLoggingLevel to setLogLevel fac14a0 [Holden Karau] Fix style errors d9d03f3 [Holden Karau] Add support for changing the log level at run time through the SparkContext. Based on an earlier PR, #2433 includes CR feedback from @pwendel & @davies --- .../scala/org/apache/spark/SparkContext.scala | 13 ++++++ .../spark/api/java/JavaSparkContext.scala | 8 ++++ .../scala/org/apache/spark/util/Utils.scala | 7 ++++ .../org/apache/spark/util/UtilsSuite.scala | 40 ++++++++++++------- python/pyspark/context.py | 7 ++++ python/pyspark/sql/dataframe.py | 2 +- 6 files changed, 61 insertions(+), 16 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 3f7cba6dbcdb5..4ef90546a2452 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -347,6 +347,19 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli value } + /** Control our logLevel. This overrides any user-defined log settings. + * @param logLevel The desired log level as a string. + * Valid log levels include: ALL, DEBUG, ERROR, FATAL, INFO, OFF, TRACE, WARN + */ + def setLogLevel(logLevel: String) { + val validLevels = Seq("ALL", "DEBUG", "ERROR", "FATAL", "INFO", "OFF", "TRACE", "WARN") + if (!validLevels.contains(logLevel)) { + throw new IllegalArgumentException( + s"Supplied level $logLevel did not match one of: ${validLevels.mkString(",")}") + } + Utils.setLogLevel(org.apache.log4j.Level.toLevel(logLevel)) + } + try { _conf = config.clone() _conf.validateSettings() diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 3be6783bba49d..02e49a853c5f7 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -755,6 +755,14 @@ class JavaSparkContext(val sc: SparkContext) */ def getLocalProperty(key: String): String = sc.getLocalProperty(key) + /** Control our logLevel. This overrides any user-defined log settings. + * @param logLevel The desired log level as a string. + * Valid log levels include: ALL, DEBUG, ERROR, FATAL, INFO, OFF, TRACE, WARN + */ + def setLogLevel(logLevel: String) { + sc.setLogLevel(logLevel) + } + /** * Assigns a group ID to all the jobs started by this thread until the group ID is set to a * different value or cleared. diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 4b5a5df5ef7b7..844f0cd22d95d 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2022,6 +2022,13 @@ private[spark] object Utils extends Logging { } } + /** + * configure a new log4j level + */ + def setLogLevel(l: org.apache.log4j.Level) { + org.apache.log4j.Logger.getRootLogger().setLevel(l) + } + /** * config a log4j properties used for testsuite */ diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 62a3cbcdf69ea..651ead6ff1de2 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -35,9 +35,10 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark.network.util.ByteUnit +import org.apache.spark.Logging import org.apache.spark.SparkConf -class UtilsSuite extends FunSuite with ResetSystemProperties { +class UtilsSuite extends FunSuite with ResetSystemProperties with Logging { test("timeConversion") { // Test -1 @@ -68,7 +69,7 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { intercept[NumberFormatException] { Utils.timeStringAsMs("600l") } - + intercept[NumberFormatException] { Utils.timeStringAsMs("This breaks 600s") } @@ -99,7 +100,7 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { assert(Utils.byteStringAsGb("1k") === 0) assert(Utils.byteStringAsGb("1t") === ByteUnit.TiB.toGiB(1)) assert(Utils.byteStringAsGb("1p") === ByteUnit.PiB.toGiB(1)) - + assert(Utils.byteStringAsMb("1") === 1) assert(Utils.byteStringAsMb("1m") === 1) assert(Utils.byteStringAsMb("1048575b") === 0) @@ -118,7 +119,7 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { assert(Utils.byteStringAsKb("1g") === ByteUnit.GiB.toKiB(1)) assert(Utils.byteStringAsKb("1t") === ByteUnit.TiB.toKiB(1)) assert(Utils.byteStringAsKb("1p") === ByteUnit.PiB.toKiB(1)) - + assert(Utils.byteStringAsBytes("1") === 1) assert(Utils.byteStringAsBytes("1k") === ByteUnit.KiB.toBytes(1)) assert(Utils.byteStringAsBytes("1m") === ByteUnit.MiB.toBytes(1)) @@ -127,17 +128,17 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { assert(Utils.byteStringAsBytes("1p") === ByteUnit.PiB.toBytes(1)) // Overflow handling, 1073741824p exceeds Long.MAX_VALUE if converted straight to Bytes - // This demonstrates that we can have e.g 1024^3 PB without overflowing. + // This demonstrates that we can have e.g 1024^3 PB without overflowing. assert(Utils.byteStringAsGb("1073741824p") === ByteUnit.PiB.toGiB(1073741824)) assert(Utils.byteStringAsMb("1073741824p") === ByteUnit.PiB.toMiB(1073741824)) - + // Run this to confirm it doesn't throw an exception - assert(Utils.byteStringAsBytes("9223372036854775807") === 9223372036854775807L) + assert(Utils.byteStringAsBytes("9223372036854775807") === 9223372036854775807L) assert(ByteUnit.PiB.toPiB(9223372036854775807L) === 9223372036854775807L) - + // Test overflow exception intercept[IllegalArgumentException] { - // This value exceeds Long.MAX when converted to bytes + // This value exceeds Long.MAX when converted to bytes Utils.byteStringAsBytes("9223372036854775808") } @@ -146,22 +147,22 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { // This value exceeds Long.MAX when converted to TB ByteUnit.PiB.toTiB(9223372036854775807L) } - + // Test fractional string intercept[NumberFormatException] { Utils.byteStringAsMb("0.064") } - + // Test fractional string intercept[NumberFormatException] { Utils.byteStringAsMb("0.064m") } - + // Test invalid strings intercept[NumberFormatException] { Utils.byteStringAsBytes("500ub") } - + // Test invalid strings intercept[NumberFormatException] { Utils.byteStringAsBytes("This breaks 600b") @@ -174,12 +175,12 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { intercept[NumberFormatException] { Utils.byteStringAsBytes("600gb This breaks") } - + intercept[NumberFormatException] { Utils.byteStringAsBytes("This 123mb breaks") } } - + test("bytesToString") { assert(Utils.bytesToString(10) === "10.0 B") assert(Utils.bytesToString(1500) === "1500.0 B") @@ -475,6 +476,15 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { } } + // Test for using the util function to change our log levels. + test("log4j log level change") { + Utils.setLogLevel(org.apache.log4j.Level.ALL) + assert(log.isInfoEnabled()) + Utils.setLogLevel(org.apache.log4j.Level.ERROR) + assert(!log.isInfoEnabled()) + assert(log.isErrorEnabled()) + } + test("deleteRecursively") { val tempDir1 = Utils.createTempDir() assert(tempDir1.exists()) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index b006120eb266d..31992795a9e45 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -267,6 +267,13 @@ def __exit__(self, type, value, trace): """ self.stop() + def setLogLevel(self, logLevel): + """ + Control our logLevel. This overrides any user-defined log settings. + Valid log levels include: ALL, DEBUG, ERROR, FATAL, INFO, OFF, TRACE, WARN + """ + self._jsc.setLogLevel(logLevel) + @classmethod def setSystemProperty(cls, key, value): """ diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 1f08c2df9305b..5ff49cac5522b 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1,6 +1,6 @@ # # Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with +# contir[butor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. # The ASF licenses this file to You under the Apache License, Version 2.0 # (the "License"); you may not use this file except in compliance with From 099327d5376554134c9af49bc2045add4cfb024d Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Fri, 1 May 2015 18:32:46 -0700 Subject: [PATCH 19/91] [SPARK-6954] [YARN] ExecutorAllocationManager can end up requesting a negative n... ...umber of executors Author: Sandy Ryza Closes #5704 from sryza/sandy-spark-6954 and squashes the following commits: b7890fb [Sandy Ryza] Avoid ramping up to an existing number of executors 6eb516a [Sandy Ryza] SPARK-6954. ExecutorAllocationManager can end up requesting a negative number of executors --- .../spark/ExecutorAllocationManager.scala | 101 +++++------ .../ExecutorAllocationManagerSuite.scala | 171 ++++++++++-------- 2 files changed, 148 insertions(+), 124 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index b986fa87dc2f4..228d9149df2a2 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -27,13 +27,20 @@ import org.apache.spark.util.{ThreadUtils, Clock, SystemClock, Utils} /** * An agent that dynamically allocates and removes executors based on the workload. * - * The add policy depends on whether there are backlogged tasks waiting to be scheduled. If - * the scheduler queue is not drained in N seconds, then new executors are added. If the queue - * persists for another M seconds, then more executors are added and so on. The number added - * in each round increases exponentially from the previous round until an upper bound on the - * number of executors has been reached. The upper bound is based both on a configured property - * and on the number of tasks pending: the policy will never increase the number of executor - * requests past the number needed to handle all pending tasks. + * The ExecutorAllocationManager maintains a moving target number of executors which is periodically + * synced to the cluster manager. The target starts at a configured initial value and changes with + * the number of pending and running tasks. + * + * Decreasing the target number of executors happens when the current target is more than needed to + * handle the current load. The target number of executors is always truncated to the number of + * executors that could run all current running and pending tasks at once. + * + * Increasing the target number of executors happens in response to backlogged tasks waiting to be + * scheduled. If the scheduler queue is not drained in N seconds, then new executors are added. If + * the queue persists for another M seconds, then more executors are added and so on. The number + * added in each round increases exponentially from the previous round until an upper bound has been + * reached. The upper bound is based both on a configured property and on the current number of + * running and pending tasks, as described above. * * The rationale for the exponential increase is twofold: (1) Executors should be added slowly * in the beginning in case the number of extra executors needed turns out to be small. Otherwise, @@ -105,8 +112,10 @@ private[spark] class ExecutorAllocationManager( // Number of executors to add in the next round private var numExecutorsToAdd = 1 - // Number of executors that have been requested but have not registered yet - private var numExecutorsPending = 0 + // The desired number of executors at this moment in time. If all our executors were to die, this + // is the number of executors we would immediately want from the cluster manager. + private var numExecutorsTarget = + conf.getInt("spark.dynamicAllocation.initialExecutors", minNumExecutors) // Executors that have been requested to be removed but have not been killed yet private val executorsPendingToRemove = new mutable.HashSet[String] @@ -199,13 +208,6 @@ private[spark] class ExecutorAllocationManager( executor.awaitTermination(10, TimeUnit.SECONDS) } - /** - * The number of executors we would have if the cluster manager were to fulfill all our existing - * requests. - */ - private def targetNumExecutors(): Int = - numExecutorsPending + executorIds.size - executorsPendingToRemove.size - /** * The maximum number of executors we would need under the current load to satisfy all running * and pending tasks, rounded up. @@ -227,7 +229,7 @@ private[spark] class ExecutorAllocationManager( private def schedule(): Unit = synchronized { val now = clock.getTimeMillis - addOrCancelExecutorRequests(now) + updateAndSyncNumExecutorsTarget(now) removeTimes.retain { case (executorId, expireTime) => val expired = now >= expireTime @@ -239,26 +241,28 @@ private[spark] class ExecutorAllocationManager( } /** + * Updates our target number of executors and syncs the result with the cluster manager. + * * Check to see whether our existing allocation and the requests we've made previously exceed our - * current needs. If so, let the cluster manager know so that it can cancel pending requests that - * are unneeded. + * current needs. If so, truncate our target and let the cluster manager know so that it can + * cancel pending requests that are unneeded. * * If not, and the add time has expired, see if we can request new executors and refresh the add * time. * * @return the delta in the target number of executors. */ - private def addOrCancelExecutorRequests(now: Long): Int = synchronized { - val currentTarget = targetNumExecutors + private def updateAndSyncNumExecutorsTarget(now: Long): Int = synchronized { val maxNeeded = maxNumExecutorsNeeded - if (maxNeeded < currentTarget) { + if (maxNeeded < numExecutorsTarget) { // The target number exceeds the number we actually need, so stop adding new - // executors and inform the cluster manager to cancel the extra pending requests. - val newTotalExecutors = math.max(maxNeeded, minNumExecutors) - client.requestTotalExecutors(newTotalExecutors) + // executors and inform the cluster manager to cancel the extra pending requests + val oldNumExecutorsTarget = numExecutorsTarget + numExecutorsTarget = math.max(maxNeeded, minNumExecutors) + client.requestTotalExecutors(numExecutorsTarget) numExecutorsToAdd = 1 - updateNumExecutorsPending(newTotalExecutors) + numExecutorsTarget - oldNumExecutorsTarget } else if (addTime != NOT_SET && now >= addTime) { val delta = addExecutors(maxNeeded) logDebug(s"Starting timer to add more executors (to " + @@ -281,21 +285,30 @@ private[spark] class ExecutorAllocationManager( */ private def addExecutors(maxNumExecutorsNeeded: Int): Int = { // Do not request more executors if it would put our target over the upper bound - val currentTarget = targetNumExecutors - if (currentTarget >= maxNumExecutors) { - logDebug(s"Not adding executors because there are already ${executorIds.size} " + - s"registered and $numExecutorsPending pending executor(s) (limit $maxNumExecutors)") + if (numExecutorsTarget >= maxNumExecutors) { + val numExecutorsPending = numExecutorsTarget - executorIds.size + logDebug(s"Not adding executors because there are already ${executorIds.size} registered " + + s"and ${numExecutorsPending} pending executor(s) (limit $maxNumExecutors)") numExecutorsToAdd = 1 return 0 } - val actualMaxNumExecutors = math.min(maxNumExecutors, maxNumExecutorsNeeded) - val newTotalExecutors = math.min(currentTarget + numExecutorsToAdd, actualMaxNumExecutors) - val addRequestAcknowledged = testing || client.requestTotalExecutors(newTotalExecutors) + val oldNumExecutorsTarget = numExecutorsTarget + // There's no point in wasting time ramping up to the number of executors we already have, so + // make sure our target is at least as much as our current allocation: + numExecutorsTarget = math.max(numExecutorsTarget, executorIds.size) + // Boost our target with the number to add for this round: + numExecutorsTarget += numExecutorsToAdd + // Ensure that our target doesn't exceed what we need at the present moment: + numExecutorsTarget = math.min(numExecutorsTarget, maxNumExecutorsNeeded) + // Ensure that our target fits within configured bounds: + numExecutorsTarget = math.max(math.min(numExecutorsTarget, maxNumExecutors), minNumExecutors) + + val addRequestAcknowledged = testing || client.requestTotalExecutors(numExecutorsTarget) if (addRequestAcknowledged) { - val delta = updateNumExecutorsPending(newTotalExecutors) + val delta = numExecutorsTarget - oldNumExecutorsTarget logInfo(s"Requesting $delta new executor(s) because tasks are backlogged" + - s" (new desired total will be $newTotalExecutors)") + s" (new desired total will be $numExecutorsTarget)") numExecutorsToAdd = if (delta == numExecutorsToAdd) { numExecutorsToAdd * 2 } else { @@ -304,23 +317,11 @@ private[spark] class ExecutorAllocationManager( delta } else { logWarning( - s"Unable to reach the cluster manager to request $newTotalExecutors total executors!") + s"Unable to reach the cluster manager to request $numExecutorsTarget total executors!") 0 } } - /** - * Given the new target number of executors, update the number of pending executor requests, - * and return the delta from the old number of pending requests. - */ - private def updateNumExecutorsPending(newTotalExecutors: Int): Int = { - val newNumExecutorsPending = - newTotalExecutors - executorIds.size + executorsPendingToRemove.size - val delta = newNumExecutorsPending - numExecutorsPending - numExecutorsPending = newNumExecutorsPending - delta - } - /** * Request the cluster manager to remove the given executor. * Return whether the request is received. @@ -372,10 +373,6 @@ private[spark] class ExecutorAllocationManager( // as idle again so as not to forget that it is a candidate for removal. (see SPARK-4951) executorIds.filter(listener.isExecutorIdle).foreach(onExecutorIdle) logInfo(s"New executor $executorId has registered (new total is ${executorIds.size})") - if (numExecutorsPending > 0) { - numExecutorsPending -= 1 - logDebug(s"Decremented number of pending executors ($numExecutorsPending left)") - } } else { logWarning(s"Duplicate executor $executorId has registered") } diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 22acc270b983e..49e6de4e0bafd 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -78,7 +78,7 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext wit test("starting state") { sc = createSparkContext() val manager = sc.executorAllocationManager.get - assert(numExecutorsPending(manager) === 0) + assert(numExecutorsTarget(manager) === 1) assert(executorsPendingToRemove(manager).isEmpty) assert(executorIds(manager).isEmpty) assert(addTime(manager) === ExecutorAllocationManager.NOT_SET) @@ -91,108 +91,108 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext wit sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 1000))) // Keep adding until the limit is reached - assert(numExecutorsPending(manager) === 0) + assert(numExecutorsTarget(manager) === 1) assert(numExecutorsToAdd(manager) === 1) assert(addExecutors(manager) === 1) - assert(numExecutorsPending(manager) === 1) + assert(numExecutorsTarget(manager) === 2) assert(numExecutorsToAdd(manager) === 2) assert(addExecutors(manager) === 2) - assert(numExecutorsPending(manager) === 3) + assert(numExecutorsTarget(manager) === 4) assert(numExecutorsToAdd(manager) === 4) assert(addExecutors(manager) === 4) - assert(numExecutorsPending(manager) === 7) + assert(numExecutorsTarget(manager) === 8) assert(numExecutorsToAdd(manager) === 8) - assert(addExecutors(manager) === 3) // reached the limit of 10 - assert(numExecutorsPending(manager) === 10) + assert(addExecutors(manager) === 2) // reached the limit of 10 + assert(numExecutorsTarget(manager) === 10) assert(numExecutorsToAdd(manager) === 1) assert(addExecutors(manager) === 0) - assert(numExecutorsPending(manager) === 10) + assert(numExecutorsTarget(manager) === 10) assert(numExecutorsToAdd(manager) === 1) // Register previously requested executors onExecutorAdded(manager, "first") - assert(numExecutorsPending(manager) === 9) + assert(numExecutorsTarget(manager) === 10) onExecutorAdded(manager, "second") onExecutorAdded(manager, "third") onExecutorAdded(manager, "fourth") - assert(numExecutorsPending(manager) === 6) + assert(numExecutorsTarget(manager) === 10) onExecutorAdded(manager, "first") // duplicates should not count onExecutorAdded(manager, "second") - assert(numExecutorsPending(manager) === 6) + assert(numExecutorsTarget(manager) === 10) // Try adding again // This should still fail because the number pending + running is still at the limit assert(addExecutors(manager) === 0) - assert(numExecutorsPending(manager) === 6) + assert(numExecutorsTarget(manager) === 10) assert(numExecutorsToAdd(manager) === 1) assert(addExecutors(manager) === 0) - assert(numExecutorsPending(manager) === 6) + assert(numExecutorsTarget(manager) === 10) assert(numExecutorsToAdd(manager) === 1) } test("add executors capped by num pending tasks") { - sc = createSparkContext(1, 10) + sc = createSparkContext(0, 10) val manager = sc.executorAllocationManager.get sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 5))) // Verify that we're capped at number of tasks in the stage - assert(numExecutorsPending(manager) === 0) + assert(numExecutorsTarget(manager) === 0) assert(numExecutorsToAdd(manager) === 1) assert(addExecutors(manager) === 1) - assert(numExecutorsPending(manager) === 1) + assert(numExecutorsTarget(manager) === 1) assert(numExecutorsToAdd(manager) === 2) assert(addExecutors(manager) === 2) - assert(numExecutorsPending(manager) === 3) + assert(numExecutorsTarget(manager) === 3) assert(numExecutorsToAdd(manager) === 4) assert(addExecutors(manager) === 2) - assert(numExecutorsPending(manager) === 5) + assert(numExecutorsTarget(manager) === 5) assert(numExecutorsToAdd(manager) === 1) - // Verify that running a task reduces the cap + // Verify that running a task doesn't affect the target sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(1, 3))) sc.listenerBus.postToAll(SparkListenerExecutorAdded( 0L, "executor-1", new ExecutorInfo("host1", 1, Map.empty))) sc.listenerBus.postToAll(SparkListenerTaskStart(1, 0, createTaskInfo(0, 0, "executor-1"))) - assert(numExecutorsPending(manager) === 4) + assert(numExecutorsTarget(manager) === 5) assert(addExecutors(manager) === 1) - assert(numExecutorsPending(manager) === 5) + assert(numExecutorsTarget(manager) === 6) assert(numExecutorsToAdd(manager) === 2) assert(addExecutors(manager) === 2) - assert(numExecutorsPending(manager) === 7) + assert(numExecutorsTarget(manager) === 8) assert(numExecutorsToAdd(manager) === 4) assert(addExecutors(manager) === 0) - assert(numExecutorsPending(manager) === 7) + assert(numExecutorsTarget(manager) === 8) assert(numExecutorsToAdd(manager) === 1) - // Verify that re-running a task doesn't reduce the cap further + // Verify that re-running a task doesn't blow things up sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(2, 3))) sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, createTaskInfo(0, 0, "executor-1"))) sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, createTaskInfo(1, 0, "executor-1"))) assert(addExecutors(manager) === 1) - assert(numExecutorsPending(manager) === 8) + assert(numExecutorsTarget(manager) === 9) assert(numExecutorsToAdd(manager) === 2) assert(addExecutors(manager) === 1) - assert(numExecutorsPending(manager) === 9) + assert(numExecutorsTarget(manager) === 10) assert(numExecutorsToAdd(manager) === 1) // Verify that running a task once we're at our limit doesn't blow things up sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, createTaskInfo(0, 1, "executor-1"))) assert(addExecutors(manager) === 0) - assert(numExecutorsPending(manager) === 9) + assert(numExecutorsTarget(manager) === 10) } test("cancel pending executors when no longer needed") { - sc = createSparkContext(1, 10) + sc = createSparkContext(0, 10) val manager = sc.executorAllocationManager.get sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(2, 5))) - assert(numExecutorsPending(manager) === 0) + assert(numExecutorsTarget(manager) === 0) assert(numExecutorsToAdd(manager) === 1) assert(addExecutors(manager) === 1) - assert(numExecutorsPending(manager) === 1) + assert(numExecutorsTarget(manager) === 1) assert(numExecutorsToAdd(manager) === 2) assert(addExecutors(manager) === 2) - assert(numExecutorsPending(manager) === 3) + assert(numExecutorsTarget(manager) === 3) val task1Info = createTaskInfo(0, 0, "executor-1") sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, task1Info)) @@ -266,7 +266,6 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext wit // Add a few executors assert(addExecutors(manager) === 1) assert(addExecutors(manager) === 2) - assert(addExecutors(manager) === 4) onExecutorAdded(manager, "1") onExecutorAdded(manager, "2") onExecutorAdded(manager, "3") @@ -274,55 +273,57 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext wit onExecutorAdded(manager, "5") onExecutorAdded(manager, "6") onExecutorAdded(manager, "7") - assert(executorIds(manager).size === 7) + onExecutorAdded(manager, "8") + assert(executorIds(manager).size === 8) // Remove until limit assert(removeExecutor(manager, "1")) assert(removeExecutor(manager, "2")) - assert(!removeExecutor(manager, "3")) // lower limit reached - assert(!removeExecutor(manager, "4")) + assert(removeExecutor(manager, "3")) + assert(!removeExecutor(manager, "4")) // lower limit reached + assert(!removeExecutor(manager, "5")) onExecutorRemoved(manager, "1") onExecutorRemoved(manager, "2") + onExecutorRemoved(manager, "3") assert(executorIds(manager).size === 5) // Add until limit - assert(addExecutors(manager) === 5) // upper limit reached + assert(addExecutors(manager) === 2) // upper limit reached assert(addExecutors(manager) === 0) - assert(!removeExecutor(manager, "3")) // still at lower limit - assert(!removeExecutor(manager, "4")) - onExecutorAdded(manager, "8") + assert(!removeExecutor(manager, "4")) // still at lower limit + assert(!removeExecutor(manager, "5")) onExecutorAdded(manager, "9") onExecutorAdded(manager, "10") onExecutorAdded(manager, "11") onExecutorAdded(manager, "12") + onExecutorAdded(manager, "13") assert(executorIds(manager).size === 10) // Remove succeeds again, now that we are no longer at the lower limit - assert(removeExecutor(manager, "3")) assert(removeExecutor(manager, "4")) assert(removeExecutor(manager, "5")) assert(removeExecutor(manager, "6")) + assert(removeExecutor(manager, "7")) assert(executorIds(manager).size === 10) - assert(addExecutors(manager) === 1) - onExecutorRemoved(manager, "3") + assert(addExecutors(manager) === 0) onExecutorRemoved(manager, "4") + onExecutorRemoved(manager, "5") assert(executorIds(manager).size === 8) - // Add succeeds again, now that we are no longer at the upper limit - // Number of executors added restarts at 1 - assert(addExecutors(manager) === 2) - assert(addExecutors(manager) === 1) // upper limit reached + // Number of executors pending restarts at 1 + assert(numExecutorsToAdd(manager) === 1) assert(addExecutors(manager) === 0) assert(executorIds(manager).size === 8) - onExecutorRemoved(manager, "5") onExecutorRemoved(manager, "6") - onExecutorAdded(manager, "13") + onExecutorRemoved(manager, "7") onExecutorAdded(manager, "14") + onExecutorAdded(manager, "15") assert(executorIds(manager).size === 8) assert(addExecutors(manager) === 0) // still at upper limit - onExecutorAdded(manager, "15") onExecutorAdded(manager, "16") + onExecutorAdded(manager, "17") assert(executorIds(manager).size === 10) + assert(numExecutorsTarget(manager) === 10) } test("starting/canceling add timer") { @@ -405,33 +406,33 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext wit } test("mock polling loop with no events") { - sc = createSparkContext(1, 20) + sc = createSparkContext(0, 20) val manager = sc.executorAllocationManager.get val clock = new ManualClock(2020L) manager.setClock(clock) // No events - we should not be adding or removing - assert(numExecutorsPending(manager) === 0) + assert(numExecutorsTarget(manager) === 0) assert(executorsPendingToRemove(manager).isEmpty) schedule(manager) - assert(numExecutorsPending(manager) === 0) + assert(numExecutorsTarget(manager) === 0) assert(executorsPendingToRemove(manager).isEmpty) clock.advance(100L) schedule(manager) - assert(numExecutorsPending(manager) === 0) + assert(numExecutorsTarget(manager) === 0) assert(executorsPendingToRemove(manager).isEmpty) clock.advance(1000L) schedule(manager) - assert(numExecutorsPending(manager) === 0) + assert(numExecutorsTarget(manager) === 0) assert(executorsPendingToRemove(manager).isEmpty) clock.advance(10000L) schedule(manager) - assert(numExecutorsPending(manager) === 0) + assert(numExecutorsTarget(manager) === 0) assert(executorsPendingToRemove(manager).isEmpty) } test("mock polling loop add behavior") { - sc = createSparkContext(1, 20) + sc = createSparkContext(0, 20) val clock = new ManualClock(2020L) val manager = sc.executorAllocationManager.get manager.setClock(clock) @@ -441,43 +442,43 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext wit onSchedulerBacklogged(manager) clock.advance(schedulerBacklogTimeout * 1000 / 2) schedule(manager) - assert(numExecutorsPending(manager) === 0) // timer not exceeded yet + assert(numExecutorsTarget(manager) === 0) // timer not exceeded yet clock.advance(schedulerBacklogTimeout * 1000) schedule(manager) - assert(numExecutorsPending(manager) === 1) // first timer exceeded + assert(numExecutorsTarget(manager) === 1) // first timer exceeded clock.advance(sustainedSchedulerBacklogTimeout * 1000 / 2) schedule(manager) - assert(numExecutorsPending(manager) === 1) // second timer not exceeded yet + assert(numExecutorsTarget(manager) === 1) // second timer not exceeded yet clock.advance(sustainedSchedulerBacklogTimeout * 1000) schedule(manager) - assert(numExecutorsPending(manager) === 1 + 2) // second timer exceeded + assert(numExecutorsTarget(manager) === 1 + 2) // second timer exceeded clock.advance(sustainedSchedulerBacklogTimeout * 1000) schedule(manager) - assert(numExecutorsPending(manager) === 1 + 2 + 4) // third timer exceeded + assert(numExecutorsTarget(manager) === 1 + 2 + 4) // third timer exceeded // Scheduler queue drained onSchedulerQueueEmpty(manager) clock.advance(sustainedSchedulerBacklogTimeout * 1000) schedule(manager) - assert(numExecutorsPending(manager) === 7) // timer is canceled + assert(numExecutorsTarget(manager) === 7) // timer is canceled clock.advance(sustainedSchedulerBacklogTimeout * 1000) schedule(manager) - assert(numExecutorsPending(manager) === 7) + assert(numExecutorsTarget(manager) === 7) // Scheduler queue backlogged again onSchedulerBacklogged(manager) clock.advance(schedulerBacklogTimeout * 1000) schedule(manager) - assert(numExecutorsPending(manager) === 7 + 1) // timer restarted + assert(numExecutorsTarget(manager) === 7 + 1) // timer restarted clock.advance(sustainedSchedulerBacklogTimeout * 1000) schedule(manager) - assert(numExecutorsPending(manager) === 7 + 1 + 2) + assert(numExecutorsTarget(manager) === 7 + 1 + 2) clock.advance(sustainedSchedulerBacklogTimeout * 1000) schedule(manager) - assert(numExecutorsPending(manager) === 7 + 1 + 2 + 4) + assert(numExecutorsTarget(manager) === 7 + 1 + 2 + 4) clock.advance(sustainedSchedulerBacklogTimeout * 1000) schedule(manager) - assert(numExecutorsPending(manager) === 20) // limit reached + assert(numExecutorsTarget(manager) === 20) // limit reached } test("mock polling loop remove behavior") { @@ -671,6 +672,31 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext wit assert(!removeTimes(manager).contains("executor-1")) } + test("avoid ramp up when target < running executors") { + sc = createSparkContext(0, 100000) + val manager = sc.executorAllocationManager.get + val stage1 = createStageInfo(0, 1000) + sc.listenerBus.postToAll(SparkListenerStageSubmitted(stage1)) + + assert(addExecutors(manager) === 1) + assert(addExecutors(manager) === 2) + assert(addExecutors(manager) === 4) + assert(addExecutors(manager) === 8) + assert(numExecutorsTarget(manager) === 15) + (0 until 15).foreach { i => + onExecutorAdded(manager, s"executor-$i") + } + assert(executorIds(manager).size === 15) + sc.listenerBus.postToAll(SparkListenerStageCompleted(stage1)) + + adjustRequestedExecutors(manager) + assert(numExecutorsTarget(manager) === 0) + + sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(1, 1000))) + addExecutors(manager) + assert(numExecutorsTarget(manager) === 16) + } + private def createSparkContext(minExecutors: Int = 1, maxExecutors: Int = 5): SparkContext = { val conf = new SparkConf() .setMaster("local") @@ -713,7 +739,7 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { * ------------------------------------------------------- */ private val _numExecutorsToAdd = PrivateMethod[Int]('numExecutorsToAdd) - private val _numExecutorsPending = PrivateMethod[Int]('numExecutorsPending) + private val _numExecutorsTarget = PrivateMethod[Int]('numExecutorsTarget) private val _maxNumExecutorsNeeded = PrivateMethod[Int]('maxNumExecutorsNeeded) private val _executorsPendingToRemove = PrivateMethod[collection.Set[String]]('executorsPendingToRemove) @@ -722,7 +748,8 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { private val _removeTimes = PrivateMethod[collection.Map[String, Long]]('removeTimes) private val _schedule = PrivateMethod[Unit]('schedule) private val _addExecutors = PrivateMethod[Int]('addExecutors) - private val _addOrCancelExecutorRequests = PrivateMethod[Int]('addOrCancelExecutorRequests) + private val _updateAndSyncNumExecutorsTarget = + PrivateMethod[Int]('updateAndSyncNumExecutorsTarget) private val _removeExecutor = PrivateMethod[Boolean]('removeExecutor) private val _onExecutorAdded = PrivateMethod[Unit]('onExecutorAdded) private val _onExecutorRemoved = PrivateMethod[Unit]('onExecutorRemoved) @@ -735,8 +762,8 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { manager invokePrivate _numExecutorsToAdd() } - private def numExecutorsPending(manager: ExecutorAllocationManager): Int = { - manager invokePrivate _numExecutorsPending() + private def numExecutorsTarget(manager: ExecutorAllocationManager): Int = { + manager invokePrivate _numExecutorsTarget() } private def executorsPendingToRemove( @@ -766,7 +793,7 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { } private def adjustRequestedExecutors(manager: ExecutorAllocationManager): Int = { - manager invokePrivate _addOrCancelExecutorRequests(0L) + manager invokePrivate _updateAndSyncNumExecutorsTarget(0L) } private def removeExecutor(manager: ExecutorAllocationManager, id: String): Boolean = { From 2022193412e832393a29b94609841c3ffe8e3d66 Mon Sep 17 00:00:00 2001 From: Timothy Chen Date: Fri, 1 May 2015 18:36:42 -0700 Subject: [PATCH 20/91] [SPARK-7216] [MESOS] Add driver details page to Mesos cluster UI. Add a details page that displays Mesos driver in the Mesos cluster UI Author: Timothy Chen Closes #5763 from tnachen/mesos_cluster_page and squashes the following commits: 55f36eb [Timothy Chen] Add driver details page to Mesos cluster UI. --- .../spark/deploy/mesos/ui/DriverPage.scala | 180 ++++++++++++++++++ .../deploy/mesos/ui/MesosClusterPage.scala | 9 +- .../deploy/mesos/ui/MesosClusterUI.scala | 1 + .../deploy/rest/mesos/MesosRestServer.scala | 6 +- .../cluster/mesos/MesosClusterScheduler.scala | 33 +++- .../cluster/mesos/MesosSchedulerBackend.scala | 4 +- 6 files changed, 222 insertions(+), 11 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala new file mode 100644 index 0000000000000..be8560d10fc62 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.mesos.ui + +import javax.servlet.http.HttpServletRequest + +import scala.xml.Node + +import org.apache.spark.deploy.Command +import org.apache.spark.deploy.mesos.MesosDriverDescription +import org.apache.spark.scheduler.cluster.mesos.{MesosClusterSubmissionState, MesosClusterRetryState} +import org.apache.spark.ui.{UIUtils, WebUIPage} + + +private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver") { + + override def render(request: HttpServletRequest): Seq[Node] = { + val driverId = request.getParameter("id") + require(driverId != null && driverId.nonEmpty, "Missing id parameter") + + val state = parent.scheduler.getDriverState(driverId) + if (state.isEmpty) { + val content = +
+

Cannot find driver {driverId}

+
+ return UIUtils.basicSparkPage(content, s"Details for Job $driverId") + } + val driverState = state.get + val driverHeaders = Seq("Driver property", "Value") + val schedulerHeaders = Seq("Scheduler property", "Value") + val commandEnvHeaders = Seq("Command environment variable", "Value") + val launchedHeaders = Seq("Launched property", "Value") + val commandHeaders = Seq("Comamnd property", "Value") + val retryHeaders = Seq("Last failed status", "Next retry time", "Retry count") + val driverDescription = Iterable.apply(driverState.description) + val submissionState = Iterable.apply(driverState.submissionState) + val command = Iterable.apply(driverState.description.command) + val schedulerProperties = Iterable.apply(driverState.description.schedulerProperties) + val commandEnv = Iterable.apply(driverState.description.command.environment) + val driverTable = + UIUtils.listingTable(driverHeaders, driverRow, driverDescription) + val commandTable = + UIUtils.listingTable(commandHeaders, commandRow, command) + val commandEnvTable = + UIUtils.listingTable(commandEnvHeaders, propertiesRow, commandEnv) + val schedulerTable = + UIUtils.listingTable(schedulerHeaders, propertiesRow, schedulerProperties) + val launchedTable = + UIUtils.listingTable(launchedHeaders, launchedRow, submissionState) + val retryTable = + UIUtils.listingTable( + retryHeaders, retryRow, Iterable.apply(driverState.description.retryState)) + val content = +

Driver state information for driver id {driverId}

+ Back to Drivers +
+
+

Driver state: {driverState.state}

+

Driver properties

+ {driverTable} +

Driver command

+ {commandTable} +

Driver command environment

+ {commandEnvTable} +

Scheduler properties

+ {schedulerTable} +

Launched state

+ {launchedTable} +

Retry state

+ {retryTable} +
+
; + + UIUtils.basicSparkPage(content, s"Details for Job $driverId") + } + + private def launchedRow(submissionState: Option[MesosClusterSubmissionState]): Seq[Node] = { + submissionState.map { state => + + Mesos Slave ID + {state.slaveId.getValue} + + + Mesos Task ID + {state.taskId.getValue} + + + Launch Time + {state.startDate} + + + Finish Time + {state.finishDate.map(_.toString).getOrElse("")} + + + Last Task Status + {state.mesosTaskStatus.map(_.toString).getOrElse("")} + + }.getOrElse(Seq[Node]()) + } + + private def propertiesRow(properties: collection.Map[String, String]): Seq[Node] = { + properties.map { case (k, v) => + + {k}{v} + + }.toSeq + } + + private def commandRow(command: Command): Seq[Node] = { + + Main class{command.mainClass} + + + Arguments{command.arguments.mkString(" ")} + + + Class path entries{command.classPathEntries.mkString(" ")} + + + Java options{command.javaOpts.mkString((" "))} + + + Library path entries{command.libraryPathEntries.mkString((" "))} + + } + + private def driverRow(driver: MesosDriverDescription): Seq[Node] = { + + Name{driver.name} + + + Id{driver.submissionId} + + + Cores{driver.cores} + + + Memory{driver.mem} + + + Submitted{driver.submissionDate} + + + Supervise{driver.supervise} + + } + + private def retryRow(retryState: Option[MesosClusterRetryState]): Seq[Node] = { + retryState.map { state => + + + {state.lastFailureStatus} + + + {state.nextRetry} + + + {state.retries} + + + }.getOrElse(Seq[Node]()) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala index 7b2005e0f1237..7419fa9699648 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala @@ -56,8 +56,9 @@ private[mesos] class MesosClusterPage(parent: MesosClusterUI) extends WebUIPage( } private def queuedRow(submission: MesosDriverDescription): Seq[Node] = { + val id = submission.submissionId - {submission.submissionId} + {id} {submission.submissionDate} {submission.command.mainClass} cpus: {submission.cores}, mem: {submission.mem} @@ -65,8 +66,9 @@ private[mesos] class MesosClusterPage(parent: MesosClusterUI) extends WebUIPage( } private def driverRow(state: MesosClusterSubmissionState): Seq[Node] = { + val id = state.driverDescription.submissionId - {state.driverDescription.submissionId} + {id} {state.driverDescription.submissionDate} {state.driverDescription.command.mainClass} cpus: {state.driverDescription.cores}, mem: {state.driverDescription.mem} @@ -77,8 +79,9 @@ private[mesos] class MesosClusterPage(parent: MesosClusterUI) extends WebUIPage( } private def retryRow(submission: MesosDriverDescription): Seq[Node] = { + val id = submission.submissionId - {submission.submissionId} + {id} {submission.submissionDate} {submission.command.mainClass} {submission.retryState.get.lastFailureStatus} diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala index 4865d46dbc4ab..3f693545a0349 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala @@ -39,6 +39,7 @@ private[spark] class MesosClusterUI( override def initialize() { attachPage(new MesosClusterPage(this)) + attachPage(new DriverPage(this)) attachHandler(createStaticHandler(MesosClusterUI.STATIC_RESOURCE_DIR, "/static")) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala index fd17a980c9319..8198296eeb341 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala @@ -53,7 +53,7 @@ private[spark] class MesosRestServer( new MesosStatusRequestServlet(scheduler, masterConf) } -private[deploy] class MesosSubmitRequestServlet( +private[mesos] class MesosSubmitRequestServlet( scheduler: MesosClusterScheduler, conf: SparkConf) extends SubmitRequestServlet { @@ -139,7 +139,7 @@ private[deploy] class MesosSubmitRequestServlet( } } -private[deploy] class MesosKillRequestServlet(scheduler: MesosClusterScheduler, conf: SparkConf) +private[mesos] class MesosKillRequestServlet(scheduler: MesosClusterScheduler, conf: SparkConf) extends KillRequestServlet { protected override def handleKill(submissionId: String): KillSubmissionResponse = { val k = scheduler.killDriver(submissionId) @@ -148,7 +148,7 @@ private[deploy] class MesosKillRequestServlet(scheduler: MesosClusterScheduler, } } -private[deploy] class MesosStatusRequestServlet(scheduler: MesosClusterScheduler, conf: SparkConf) +private[mesos] class MesosStatusRequestServlet(scheduler: MesosClusterScheduler, conf: SparkConf) extends StatusRequestServlet { protected override def handleStatus(submissionId: String): SubmissionStatusResponse = { val d = scheduler.getDriverStatus(submissionId) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index 0396e62be5309..06f0e2881c344 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -50,12 +50,13 @@ private[spark] class MesosClusterSubmissionState( val taskId: TaskID, val slaveId: SlaveID, var mesosTaskStatus: Option[TaskStatus], - var startDate: Date) + var startDate: Date, + var finishDate: Option[Date]) extends Serializable { def copy(): MesosClusterSubmissionState = { new MesosClusterSubmissionState( - driverDescription, taskId, slaveId, mesosTaskStatus, startDate) + driverDescription, taskId, slaveId, mesosTaskStatus, startDate, finishDate) } } @@ -95,6 +96,14 @@ private[spark] class MesosClusterSchedulerState( val finishedDrivers: Iterable[MesosClusterSubmissionState], val pendingRetryDrivers: Iterable[MesosDriverDescription]) +/** + * The full state of a Mesos driver, that is being used to display driver information on the UI. + */ +private[spark] class MesosDriverState( + val state: String, + val description: MesosDriverDescription, + val submissionState: Option[MesosClusterSubmissionState] = None) + /** * A Mesos scheduler that is responsible for launching submitted Spark drivers in cluster mode * as Mesos tasks in a Mesos cluster. @@ -233,6 +242,22 @@ private[spark] class MesosClusterScheduler( s } + /** + * Gets the driver state to be displayed on the Web UI. + */ + def getDriverState(submissionId: String): Option[MesosDriverState] = { + stateLock.synchronized { + queuedDrivers.find(_.submissionId.equals(submissionId)) + .map(d => new MesosDriverState("QUEUED", d)) + .orElse(launchedDrivers.get(submissionId) + .map(d => new MesosDriverState("RUNNING", d.driverDescription, Some(d)))) + .orElse(finishedDrivers.find(_.driverDescription.submissionId.equals(submissionId)) + .map(d => new MesosDriverState("FINISHED", d.driverDescription, Some(d)))) + .orElse(pendingRetryDrivers.find(_.submissionId.equals(submissionId)) + .map(d => new MesosDriverState("RETRYING", d))) + } + } + private def isQueueFull(): Boolean = launchedDrivers.size >= queuedCapacity /** @@ -439,7 +464,7 @@ private[spark] class MesosClusterScheduler( logTrace(s"Using offer ${offer.offer.getId.getValue} to launch driver " + submission.submissionId) val newState = new MesosClusterSubmissionState(submission, taskId, offer.offer.getSlaveId, - None, new Date()) + None, new Date(), None) launchedDrivers(submission.submissionId) = newState launchedDriversState.persist(submission.submissionId, newState) afterLaunchCallback(submission.submissionId) @@ -534,6 +559,7 @@ private[spark] class MesosClusterScheduler( // Check if the driver is supervise enabled and can be relaunched. if (state.driverDescription.supervise && shouldRelaunch(status.getState)) { removeFromLaunchedDrivers(taskId) + state.finishDate = Some(new Date()) val retryState: Option[MesosClusterRetryState] = state.driverDescription.retryState val (retries, waitTimeSec) = retryState .map { rs => (rs.retries + 1, Math.min(maxRetryWaitTime, rs.waitTime * 2)) } @@ -546,6 +572,7 @@ private[spark] class MesosClusterScheduler( pendingRetryDriversState.persist(taskId, newDriverDescription) } else if (TaskState.isFinished(TaskState.fromMesos(status.getState))) { removeFromLaunchedDrivers(taskId) + state.finishDate = Some(new Date()) if (finishedDrivers.size >= retainedDrivers) { val toRemove = math.max(retainedDrivers / 10, 1) finishedDrivers.trimStart(toRemove) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 8346a2407489f..86a7d0fb587e4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -23,7 +23,7 @@ import java.util.{ArrayList => JArrayList, Collections, List => JList} import scala.collection.JavaConversions._ import scala.collection.mutable.{HashMap, HashSet} -import org.apache.mesos.Protos.{ExecutorInfo => MesosExecutorInfo, TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _} +import org.apache.mesos.Protos.{ExecutorInfo => MesosExecutorInfo, TaskInfo => MesosTaskInfo, _} import org.apache.mesos.protobuf.ByteString import org.apache.mesos.{Scheduler => MScheduler, _} import org.apache.spark.executor.MesosExecutorBackend @@ -56,7 +56,7 @@ private[spark] class MesosSchedulerBackend( // The listener bus to publish executor added/removed events. val listenerBus = sc.listenerBus - + private[mesos] val mesosExecutorCores = sc.conf.getDouble("spark.mesos.mesosExecutor.cores", 1) @volatile var appId: String = _ From b4b43df8a338a30c0eadcf10cbe3ba203dc3f861 Mon Sep 17 00:00:00 2001 From: WangTaoTheTonic Date: Fri, 1 May 2015 18:38:20 -0700 Subject: [PATCH 21/91] [SPARK-6443] [SPARK SUBMIT] Could not submit app in standalone cluster mode when HA is enabled **3/26 update:** * Akka-based: Use an array of `ActorSelection` to represent multiple master. Add an `activeMasterActor` for query status of driver. And will add lost masters( including the standby one) to `lostMasters`. When size of `lostMasters` equals or greater than # of all masters, we should give an error that all masters are not avalible. * Rest-based: When all masters are not available(throw an exception), we use akka gateway to submit apps. I have tested simply on standalone HA cluster(with two masters alive and one alive/one dead), it worked. There might remains some issues on style or message print, but we can check the solution then fix them together. /cc srowen andrewor14 Author: WangTaoTheTonic Closes #5116 from WangTaoTheTonic/SPARK-6443 and squashes the following commits: 2a28aab [WangTaoTheTonic] based the newest change https://github.com/apache/spark/pull/5144 76fd411 [WangTaoTheTonic] rebase f4f972b [WangTaoTheTonic] rebase...again a41de0b [WangTaoTheTonic] rebase 220cb3c [WangTaoTheTonic] move connect exception inside 35119a0 [WangTaoTheTonic] style and compile issues 9d636be [WangTaoTheTonic] per Andrew's comments 979760c [WangTaoTheTonic] rebase e4f4ece [WangTaoTheTonic] fix failed test 5d23958 [WangTaoTheTonic] refact some duplicated code, style and comments 7a881b3 [WangTaoTheTonic] when one of masters is gone, we still can submit 2b011c9 [WangTaoTheTonic] fix broken tests 60d97a4 [WangTaoTheTonic] rebase fa1fa80 [WangTaoTheTonic] submit app to HA cluster in standalone cluster mode --- .../org/apache/spark/deploy/Client.scala | 73 ++++++--- .../apache/spark/deploy/ClientArguments.scala | 9 +- .../org/apache/spark/deploy/SparkSubmit.scala | 8 +- .../apache/spark/deploy/master/Master.scala | 24 ++- .../deploy/rest/RestSubmissionClient.scala | 152 +++++++++++++----- .../spark/deploy/worker/WorkerArguments.scala | 2 +- .../scala/org/apache/spark/util/Utils.scala | 16 ++ .../rest/StandaloneRestSubmitSuite.scala | 40 ++--- 8 files changed, 229 insertions(+), 95 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index c2c3e9a9e4827..848b62f9de71b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -17,6 +17,7 @@ package org.apache.spark.deploy +import scala.collection.mutable.HashSet import scala.concurrent._ import akka.actor._ @@ -31,21 +32,24 @@ import org.apache.spark.util.{ActorLogReceive, AkkaUtils, RpcUtils, Utils} /** * Proxy that relays messages to the driver. + * + * We currently don't support retry if submission fails. In HA mode, client will submit request to + * all masters and see which one could handle it. */ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) extends Actor with ActorLogReceive with Logging { - var masterActor: ActorSelection = _ + private val masterActors = driverArgs.masters.map { m => + context.actorSelection(Master.toAkkaUrl(m, AkkaUtils.protocol(context.system))) + } + private val lostMasters = new HashSet[Address] + private var activeMasterActor: ActorSelection = null + val timeout = RpcUtils.askTimeout(conf) override def preStart(): Unit = { - masterActor = context.actorSelection( - Master.toAkkaUrl(driverArgs.master, AkkaUtils.protocol(context.system))) - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) - println(s"Sending ${driverArgs.cmd} command to ${driverArgs.master}") - driverArgs.cmd match { case "launch" => // TODO: We could add an env variable here and intercept it in `sc.addJar` that would @@ -79,11 +83,17 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) driverArgs.supervise, command) - masterActor ! RequestSubmitDriver(driverDescription) + // This assumes only one Master is active at a time + for (masterActor <- masterActors) { + masterActor ! RequestSubmitDriver(driverDescription) + } case "kill" => val driverId = driverArgs.driverId - masterActor ! RequestKillDriver(driverId) + // This assumes only one Master is active at a time + for (masterActor <- masterActors) { + masterActor ! RequestKillDriver(driverId) + } } } @@ -92,10 +102,9 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) println("... waiting before polling master for driver state") Thread.sleep(5000) println("... polling master for driver state") - val statusFuture = (masterActor ? RequestDriverStatus(driverId))(timeout) + val statusFuture = (activeMasterActor ? RequestDriverStatus(driverId))(timeout) .mapTo[DriverStatusResponse] val statusResponse = Await.result(statusFuture, timeout) - statusResponse.found match { case false => println(s"ERROR: Cluster master did not recognize $driverId") @@ -122,20 +131,46 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) case SubmitDriverResponse(success, driverId, message) => println(message) - if (success) pollAndReportStatus(driverId.get) else System.exit(-1) + if (success) { + activeMasterActor = context.actorSelection(sender.path) + pollAndReportStatus(driverId.get) + } else if (!Utils.responseFromBackup(message)) { + System.exit(-1) + } + case KillDriverResponse(driverId, success, message) => println(message) - if (success) pollAndReportStatus(driverId) else System.exit(-1) + if (success) { + activeMasterActor = context.actorSelection(sender.path) + pollAndReportStatus(driverId) + } else if (!Utils.responseFromBackup(message)) { + System.exit(-1) + } case DisassociatedEvent(_, remoteAddress, _) => - println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.") - System.exit(-1) + if (!lostMasters.contains(remoteAddress)) { + println(s"Error connecting to master $remoteAddress.") + lostMasters += remoteAddress + // Note that this heuristic does not account for the fact that a Master can recover within + // the lifetime of this client. Thus, once a Master is lost it is lost to us forever. This + // is not currently a concern, however, because this client does not retry submissions. + if (lostMasters.size >= masterActors.size) { + println("No master is available, exiting.") + System.exit(-1) + } + } case AssociationErrorEvent(cause, _, remoteAddress, _, _) => - println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.") - println(s"Cause was: $cause") - System.exit(-1) + if (!lostMasters.contains(remoteAddress)) { + println(s"Error connecting to master ($remoteAddress).") + println(s"Cause was: $cause") + lostMasters += remoteAddress + if (lostMasters.size >= masterActors.size) { + println("No master is available, exiting.") + System.exit(-1) + } + } } } @@ -163,7 +198,9 @@ object Client { "driverClient", Utils.localHostName(), 0, conf, new SecurityManager(conf)) // Verify driverArgs.master is a valid url so that we can use it in ClientActor safely - Master.toAkkaUrl(driverArgs.master, AkkaUtils.protocol(actorSystem)) + for (m <- driverArgs.masters) { + Master.toAkkaUrl(m, AkkaUtils.protocol(actorSystem)) + } actorSystem.actorOf(Props(classOf[ClientActor], driverArgs, conf)) actorSystem.awaitTermination() diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala index 5cbac787dceeb..316e2d59f01b8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala @@ -22,8 +22,7 @@ import java.net.{URI, URISyntaxException} import scala.collection.mutable.ListBuffer import org.apache.log4j.Level - -import org.apache.spark.util.{IntParam, MemoryParam} +import org.apache.spark.util.{IntParam, MemoryParam, Utils} /** * Command-line parser for the driver client. @@ -35,7 +34,7 @@ private[deploy] class ClientArguments(args: Array[String]) { var logLevel = Level.WARN // launch parameters - var master: String = "" + var masters: Array[String] = null var jarUrl: String = "" var mainClass: String = "" var supervise: Boolean = DEFAULT_SUPERVISE @@ -80,13 +79,13 @@ private[deploy] class ClientArguments(args: Array[String]) { } jarUrl = _jarUrl - master = _master + masters = Utils.parseStandaloneMasterUrls(_master) mainClass = _mainClass _driverOptions ++= tail case "kill" :: _master :: _driverId :: tail => cmd = "kill" - master = _master + masters = Utils.parseStandaloneMasterUrls(_master) driverId = _driverId case _ => diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index af38bf80e4f0b..42b5d41b7b526 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -118,8 +118,8 @@ object SparkSubmit { * Kill an existing submission using the REST protocol. Standalone and Mesos cluster mode only. */ private def kill(args: SparkSubmitArguments): Unit = { - new RestSubmissionClient() - .killSubmission(args.master, args.submissionToKill) + new RestSubmissionClient(args.master) + .killSubmission(args.submissionToKill) } /** @@ -127,8 +127,8 @@ object SparkSubmit { * Standalone and Mesos cluster mode only. */ private def requestStatus(args: SparkSubmitArguments): Unit = { - new RestSubmissionClient() - .requestSubmissionStatus(args.master, args.submissionToRequestStatusFor) + new RestSubmissionClient(args.master) + .requestSubmissionStatus(args.submissionToRequestStatusFor) } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index dc6077f3d132b..0fac3cdcf55e7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -254,7 +254,8 @@ private[master] class Master( case RequestSubmitDriver(description) => { if (state != RecoveryState.ALIVE) { - val msg = s"Can only accept driver submissions in ALIVE state. Current state: $state." + val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + + "Can only accept driver submissions in ALIVE state." sender ! SubmitDriverResponse(false, None, msg) } else { logInfo("Driver submitted " + description.command.mainClass) @@ -274,7 +275,8 @@ private[master] class Master( case RequestKillDriver(driverId) => { if (state != RecoveryState.ALIVE) { - val msg = s"Can only kill drivers in ALIVE state. Current state: $state." + val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + + s"Can only kill drivers in ALIVE state." sender ! KillDriverResponse(driverId, success = false, msg) } else { logInfo("Asked to kill driver " + driverId) @@ -305,12 +307,18 @@ private[master] class Master( } case RequestDriverStatus(driverId) => { - (drivers ++ completedDrivers).find(_.id == driverId) match { - case Some(driver) => - sender ! DriverStatusResponse(found = true, Some(driver.state), - driver.worker.map(_.id), driver.worker.map(_.hostPort), driver.exception) - case None => - sender ! DriverStatusResponse(found = false, None, None, None, None) + if (state != RecoveryState.ALIVE) { + val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + + "Can only request driver status in ALIVE state." + sender ! DriverStatusResponse(found = false, None, None, None, Some(new Exception(msg))) + } else { + (drivers ++ completedDrivers).find(_.id == driverId) match { + case Some(driver) => + sender ! DriverStatusResponse(found = true, Some(driver.state), + driver.worker.map(_.id), driver.worker.map(_.hostPort), driver.exception) + case None => + sender ! DriverStatusResponse(found = false, None, None, None, None) + } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala index 307cebfb4bd09..6078f50518ba4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala @@ -18,9 +18,10 @@ package org.apache.spark.deploy.rest import java.io.{DataOutputStream, FileNotFoundException} -import java.net.{HttpURLConnection, SocketException, URL} +import java.net.{ConnectException, HttpURLConnection, SocketException, URL} import javax.servlet.http.HttpServletResponse +import scala.collection.mutable import scala.io.Source import com.fasterxml.jackson.core.JsonProcessingException @@ -51,57 +52,109 @@ import org.apache.spark.util.Utils * implementation of this client can use that information to retry using the version specified * by the server. */ -private[spark] class RestSubmissionClient extends Logging { +private[spark] class RestSubmissionClient(master: String) extends Logging { import RestSubmissionClient._ private val supportedMasterPrefixes = Seq("spark://", "mesos://") + private val masters: Array[String] = Utils.parseStandaloneMasterUrls(master) + + // Set of masters that lost contact with us, used to keep track of + // whether there are masters still alive for us to communicate with + private val lostMasters = new mutable.HashSet[String] + /** * Submit an application specified by the parameters in the provided request. * * If the submission was successful, poll the status of the submission and report * it to the user. Otherwise, report the error message provided by the server. */ - def createSubmission( - master: String, - request: CreateSubmissionRequest): SubmitRestProtocolResponse = { + def createSubmission(request: CreateSubmissionRequest): SubmitRestProtocolResponse = { logInfo(s"Submitting a request to launch an application in $master.") - validateMaster(master) - val url = getSubmitUrl(master) - val response = postJson(url, request.toJson) - response match { - case s: CreateSubmissionResponse => - reportSubmissionStatus(master, s) - handleRestResponse(s) - case unexpected => - handleUnexpectedRestResponse(unexpected) + var handled: Boolean = false + var response: SubmitRestProtocolResponse = null + for (m <- masters if !handled) { + validateMaster(m) + val url = getSubmitUrl(m) + try { + response = postJson(url, request.toJson) + response match { + case s: CreateSubmissionResponse => + if (s.success) { + reportSubmissionStatus(s) + handleRestResponse(s) + handled = true + } + case unexpected => + handleUnexpectedRestResponse(unexpected) + } + } catch { + case e: SubmitRestConnectionException => + if (handleConnectionException(m)) { + throw new SubmitRestConnectionException("Unable to connect to server", e) + } + } } response } /** Request that the server kill the specified submission. */ - def killSubmission(master: String, submissionId: String): SubmitRestProtocolResponse = { + def killSubmission(submissionId: String): SubmitRestProtocolResponse = { logInfo(s"Submitting a request to kill submission $submissionId in $master.") - validateMaster(master) - val response = post(getKillUrl(master, submissionId)) - response match { - case k: KillSubmissionResponse => handleRestResponse(k) - case unexpected => handleUnexpectedRestResponse(unexpected) + var handled: Boolean = false + var response: SubmitRestProtocolResponse = null + for (m <- masters if !handled) { + validateMaster(m) + val url = getKillUrl(m, submissionId) + try { + response = post(url) + response match { + case k: KillSubmissionResponse => + if (!Utils.responseFromBackup(k.message)) { + handleRestResponse(k) + handled = true + } + case unexpected => + handleUnexpectedRestResponse(unexpected) + } + } catch { + case e: SubmitRestConnectionException => + if (handleConnectionException(m)) { + throw new SubmitRestConnectionException("Unable to connect to server", e) + } + } } response } /** Request the status of a submission from the server. */ def requestSubmissionStatus( - master: String, submissionId: String, quiet: Boolean = false): SubmitRestProtocolResponse = { logInfo(s"Submitting a request for the status of submission $submissionId in $master.") - validateMaster(master) - val response = get(getStatusUrl(master, submissionId)) - response match { - case s: SubmissionStatusResponse => if (!quiet) { handleRestResponse(s) } - case unexpected => handleUnexpectedRestResponse(unexpected) + + var handled: Boolean = false + var response: SubmitRestProtocolResponse = null + for (m <- masters if !handled) { + validateMaster(m) + val url = getStatusUrl(m, submissionId) + try { + response = get(url) + response match { + case s: SubmissionStatusResponse if s.success => + if (!quiet) { + handleRestResponse(s) + } + handled = true + case unexpected => + handleUnexpectedRestResponse(unexpected) + } + } catch { + case e: SubmitRestConnectionException => + if (handleConnectionException(m)) { + throw new SubmitRestConnectionException("Unable to connect to server", e) + } + } } response } @@ -148,11 +201,16 @@ private[spark] class RestSubmissionClient extends Logging { conn.setRequestProperty("Content-Type", "application/json") conn.setRequestProperty("charset", "utf-8") conn.setDoOutput(true) - val out = new DataOutputStream(conn.getOutputStream) - Utils.tryWithSafeFinally { - out.write(json.getBytes(Charsets.UTF_8)) - } { - out.close() + try { + val out = new DataOutputStream(conn.getOutputStream) + Utils.tryWithSafeFinally { + out.write(json.getBytes(Charsets.UTF_8)) + } { + out.close() + } + } catch { + case e: ConnectException => + throw new SubmitRestConnectionException("Connect Exception when connect to server", e) } readResponse(conn) } @@ -191,11 +249,9 @@ private[spark] class RestSubmissionClient extends Logging { } } catch { case unreachable @ (_: FileNotFoundException | _: SocketException) => - throw new SubmitRestConnectionException( - s"Unable to connect to server ${connection.getURL}", unreachable) + throw new SubmitRestConnectionException("Unable to connect to server", unreachable) case malformed @ (_: JsonProcessingException | _: SubmitRestProtocolException) => - throw new SubmitRestProtocolException( - "Malformed response received from server", malformed) + throw new SubmitRestProtocolException("Malformed response received from server", malformed) } } @@ -241,13 +297,12 @@ private[spark] class RestSubmissionClient extends Logging { /** Report the status of a newly created submission. */ private def reportSubmissionStatus( - master: String, submitResponse: CreateSubmissionResponse): Unit = { if (submitResponse.success) { val submissionId = submitResponse.submissionId if (submissionId != null) { logInfo(s"Submission successfully created as $submissionId. Polling submission state...") - pollSubmissionStatus(master, submissionId) + pollSubmissionStatus(submissionId) } else { // should never happen logError("Application successfully submitted, but submission ID was not provided!") @@ -262,9 +317,9 @@ private[spark] class RestSubmissionClient extends Logging { * Poll the status of the specified submission and log it. * This retries up to a fixed number of times before giving up. */ - private def pollSubmissionStatus(master: String, submissionId: String): Unit = { + private def pollSubmissionStatus(submissionId: String): Unit = { (1 to REPORT_DRIVER_STATUS_MAX_TRIES).foreach { _ => - val response = requestSubmissionStatus(master, submissionId, quiet = true) + val response = requestSubmissionStatus(submissionId, quiet = true) val statusResponse = response match { case s: SubmissionStatusResponse => s case _ => return // unexpected type, let upstream caller handle it @@ -302,6 +357,21 @@ private[spark] class RestSubmissionClient extends Logging { private def handleUnexpectedRestResponse(unexpected: SubmitRestProtocolResponse): Unit = { logError(s"Error: Server responded with message of unexpected type ${unexpected.messageType}.") } + + /** + * When a connection exception is caught, return true if all masters are lost. + * Note that the heuristic used here does not take into account that masters + * can recover during the lifetime of this client. This assumption should be + * harmless because this client currently does not support retrying submission + * on failure yet (SPARK-6443). + */ + private def handleConnectionException(masterUrl: String): Boolean = { + if (!lostMasters.contains(masterUrl)) { + logWarning(s"Unable to connect to server ${masterUrl}.") + lostMasters += masterUrl + } + lostMasters.size >= masters.size + } } private[spark] object RestSubmissionClient { @@ -324,10 +394,10 @@ private[spark] object RestSubmissionClient { } val sparkProperties = conf.getAll.toMap val environmentVariables = env.filter { case (k, _) => k.startsWith("SPARK_") } - val client = new RestSubmissionClient + val client = new RestSubmissionClient(master) val submitRequest = client.constructSubmitRequest( appResource, mainClass, appArgs, sparkProperties, environmentVariables) - client.createSubmission(master, submitRequest) + client.createSubmission(submitRequest) } def main(args: Array[String]): Unit = { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala index 88f9d880ac209..9678631da9f6f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala @@ -105,7 +105,7 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) { if (masters != null) { // Two positional arguments were given printUsageAndExit(1) } - masters = value.stripPrefix("spark://").split(",").map("spark://" + _) + masters = Utils.parseStandaloneMasterUrls(value) parse(tail) case Nil => diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 844f0cd22d95d..be4db02ab86d0 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2159,6 +2159,22 @@ private[spark] object Utils extends Logging { .getOrElse(UserGroupInformation.getCurrentUser().getShortUserName()) } + /** + * Split the comma delimited string of master URLs into a list. + * For instance, "spark://abc,def" becomes [spark://abc, spark://def]. + */ + def parseStandaloneMasterUrls(masterUrls: String): Array[String] = { + masterUrls.stripPrefix("spark://").split(",").map("spark://" + _) + } + + /** An identifier that backup masters use in their responses. */ + val BACKUP_STANDALONE_MASTER_PREFIX = "Current state is not alive" + + /** Return true if the response message is sent from a backup Master on standby. */ + def responseFromBackup(msg: String): Boolean = { + msg.startsWith(BACKUP_STANDALONE_MASTER_PREFIX) + } + /** * Adds a shutdown hook with default priority. * diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala index 0a318a27ac212..f4d548d9e7720 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -39,7 +39,6 @@ import org.apache.spark.deploy.master.DriverState._ * Tests for the REST application submission protocol used in standalone cluster mode. */ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { - private val client = new RestSubmissionClient private var actorSystem: Option[ActorSystem] = None private var server: Option[RestSubmissionServer] = None @@ -52,7 +51,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { val appArgs = Array("one", "two", "three") val sparkProperties = Map("spark.app.name" -> "pi") val environmentVariables = Map("SPARK_ONE" -> "UN", "SPARK_TWO" -> "DEUX") - val request = client.constructSubmitRequest( + val request = new RestSubmissionClient("spark://host:port").constructSubmitRequest( "my-app-resource", "my-main-class", appArgs, sparkProperties, environmentVariables) assert(request.action === Utils.getFormattedClassName(request)) assert(request.clientSparkVersion === SPARK_VERSION) @@ -71,7 +70,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { val request = constructSubmitRequest(masterUrl, appArgs) assert(request.appArgs === appArgs) assert(request.sparkProperties("spark.master") === masterUrl) - val response = client.createSubmission(masterUrl, request) + val response = new RestSubmissionClient(masterUrl).createSubmission(request) val submitResponse = getSubmitResponse(response) assert(submitResponse.action === Utils.getFormattedClassName(submitResponse)) assert(submitResponse.serverSparkVersion === SPARK_VERSION) @@ -102,7 +101,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { val submissionId = "my-lyft-driver" val killMessage = "your driver is killed" val masterUrl = startDummyServer(killMessage = killMessage) - val response = client.killSubmission(masterUrl, submissionId) + val response = new RestSubmissionClient(masterUrl).killSubmission(submissionId) val killResponse = getKillResponse(response) assert(killResponse.action === Utils.getFormattedClassName(killResponse)) assert(killResponse.serverSparkVersion === SPARK_VERSION) @@ -116,7 +115,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { val submissionState = KILLED val submissionException = new Exception("there was an irresponsible mix of alcohol and cars") val masterUrl = startDummyServer(state = submissionState, exception = Some(submissionException)) - val response = client.requestSubmissionStatus(masterUrl, submissionId) + val response = new RestSubmissionClient(masterUrl).requestSubmissionStatus(submissionId) val statusResponse = getStatusResponse(response) assert(statusResponse.action === Utils.getFormattedClassName(statusResponse)) assert(statusResponse.serverSparkVersion === SPARK_VERSION) @@ -129,13 +128,14 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { test("create then kill") { val masterUrl = startSmartServer() val request = constructSubmitRequest(masterUrl) - val response1 = client.createSubmission(masterUrl, request) + val client = new RestSubmissionClient(masterUrl) + val response1 = client.createSubmission(request) val submitResponse = getSubmitResponse(response1) assert(submitResponse.success) assert(submitResponse.submissionId != null) // kill submission that was just created val submissionId = submitResponse.submissionId - val response2 = client.killSubmission(masterUrl, submissionId) + val response2 = client.killSubmission(submissionId) val killResponse = getKillResponse(response2) assert(killResponse.success) assert(killResponse.submissionId === submissionId) @@ -144,13 +144,14 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { test("create then request status") { val masterUrl = startSmartServer() val request = constructSubmitRequest(masterUrl) - val response1 = client.createSubmission(masterUrl, request) + val client = new RestSubmissionClient(masterUrl) + val response1 = client.createSubmission(request) val submitResponse = getSubmitResponse(response1) assert(submitResponse.success) assert(submitResponse.submissionId != null) // request status of submission that was just created val submissionId = submitResponse.submissionId - val response2 = client.requestSubmissionStatus(masterUrl, submissionId) + val response2 = client.requestSubmissionStatus(submissionId) val statusResponse = getStatusResponse(response2) assert(statusResponse.success) assert(statusResponse.submissionId === submissionId) @@ -160,8 +161,9 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { test("create then kill then request status") { val masterUrl = startSmartServer() val request = constructSubmitRequest(masterUrl) - val response1 = client.createSubmission(masterUrl, request) - val response2 = client.createSubmission(masterUrl, request) + val client = new RestSubmissionClient(masterUrl) + val response1 = client.createSubmission(request) + val response2 = client.createSubmission(request) val submitResponse1 = getSubmitResponse(response1) val submitResponse2 = getSubmitResponse(response2) assert(submitResponse1.success) @@ -171,13 +173,13 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { val submissionId1 = submitResponse1.submissionId val submissionId2 = submitResponse2.submissionId // kill only submission 1, but not submission 2 - val response3 = client.killSubmission(masterUrl, submissionId1) + val response3 = client.killSubmission(submissionId1) val killResponse = getKillResponse(response3) assert(killResponse.success) assert(killResponse.submissionId === submissionId1) // request status for both submissions: 1 should be KILLED but 2 should be RUNNING still - val response4 = client.requestSubmissionStatus(masterUrl, submissionId1) - val response5 = client.requestSubmissionStatus(masterUrl, submissionId2) + val response4 = client.requestSubmissionStatus(submissionId1) + val response5 = client.requestSubmissionStatus(submissionId2) val statusResponse1 = getStatusResponse(response4) val statusResponse2 = getStatusResponse(response5) assert(statusResponse1.submissionId === submissionId1) @@ -189,13 +191,14 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { test("kill or request status before create") { val masterUrl = startSmartServer() val doesNotExist = "does-not-exist" + val client = new RestSubmissionClient(masterUrl) // kill a non-existent submission - val response1 = client.killSubmission(masterUrl, doesNotExist) + val response1 = client.killSubmission(doesNotExist) val killResponse = getKillResponse(response1) assert(!killResponse.success) assert(killResponse.submissionId === doesNotExist) // request status for a non-existent submission - val response2 = client.requestSubmissionStatus(masterUrl, doesNotExist) + val response2 = client.requestSubmissionStatus(doesNotExist) val statusResponse = getStatusResponse(response2) assert(!statusResponse.success) assert(statusResponse.submissionId === doesNotExist) @@ -339,6 +342,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { test("client handles faulty server") { val masterUrl = startFaultyServer() + val client = new RestSubmissionClient(masterUrl) val httpUrl = masterUrl.replace("spark://", "http://") val v = RestSubmissionServer.PROTOCOL_VERSION val submitRequestPath = s"$httpUrl/$v/submissions/create" @@ -425,7 +429,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { mainJar) ++ appArgs val args = new SparkSubmitArguments(commandLineArgs) val (_, _, sparkProperties, _) = SparkSubmit.prepareSubmitEnvironment(args) - client.constructSubmitRequest( + new RestSubmissionClient("spark://host:port").constructSubmitRequest( mainJar, mainClass, appArgs, sparkProperties.toMap, Map.empty) } @@ -492,7 +496,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { method: String, body: String = ""): (SubmitRestProtocolResponse, Int) = { val conn = sendHttpRequest(url, method, body) - (client.readResponse(conn), conn.getResponseCode) + (new RestSubmissionClient("spark://host:port").readResponse(conn), conn.getResponseCode) } } From 8f50a07d2188ccc5315d979755188b1e5d5b5471 Mon Sep 17 00:00:00 2001 From: Chris Heller Date: Fri, 1 May 2015 18:41:22 -0700 Subject: [PATCH 22/91] [SPARK-2691] [MESOS] Support for Mesos DockerInfo This patch adds partial support for running spark on mesos inside of a docker container. Only fine-grained mode is presently supported, and there is no checking done to ensure that the version of libmesos is recent enough to have a DockerInfo structure in the protobuf (other than pinning a mesos version in the pom.xml). Author: Chris Heller Closes #3074 from hellertime/SPARK-2691 and squashes the following commits: d504af6 [Chris Heller] Assist type inference f64885d [Chris Heller] Fix errant line length 17c41c0 [Chris Heller] Base Dockerfile on mesosphere/mesos image 8aebda4 [Chris Heller] Simplfy Docker image docs 1ae7f4f [Chris Heller] Style points 974bd56 [Chris Heller] Convert map to flatMap 5d8bdf7 [Chris Heller] Factor out the DockerInfo construction. 7b75a3d [Chris Heller] Align to styleguide 80108e7 [Chris Heller] Bend to the will of RAT ba77056 [Chris Heller] Explicit RAT exclude abda5e5 [Chris Heller] Wildcard .rat-excludes 2f2873c [Chris Heller] Exclude spark-mesos from RAT a589a5b [Chris Heller] Add example Dockerfile b6825ce [Chris Heller] Remove use of EasyMock eae1b86 [Chris Heller] Move properties under 'spark.mesos.' c184d00 [Chris Heller] Use map on Option to be consistent with non-coarse code fb9501a [Chris Heller] Bumped mesos version to current release fa11879 [Chris Heller] Add listenerBus to EasyMock 882151e [Chris Heller] Changes to scala style b22d42d [Chris Heller] Exclude template from RAT db536cf [Chris Heller] Remove unneeded mocks dea1bd5 [Chris Heller] Force default protocol 7dac042 [Chris Heller] Add test for DockerInfo 5456c0c [Chris Heller] Adjust syntax style 521c194 [Chris Heller] Adjust version info 6e38f70 [Chris Heller] Document Mesos Docker properties 29572ab [Chris Heller] Support all DockerInfo fields b8c0dea [Chris Heller] Support for mesos DockerInfo in coarse-mode. 482a9fd [Chris Heller] Support for mesos DockerInfo in fine-grained mode. --- .rat-excludes | 1 + conf/docker.properties.template | 3 + .../mesos/CoarseMesosSchedulerBackend.scala | 9 +- .../cluster/mesos/MesosSchedulerBackend.scala | 10 +- .../mesos/MesosSchedulerBackendUtil.scala | 142 ++++++++++++++++++ .../mesos/MesosSchedulerBackendSuite.scala | 46 ++++++ docker/spark-mesos/Dockerfile | 30 ++++ docs/running-on-mesos.md | 42 ++++++ pom.xml | 2 +- 9 files changed, 280 insertions(+), 5 deletions(-) create mode 100644 conf/docker.properties.template create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala create mode 100644 docker/spark-mesos/Dockerfile diff --git a/.rat-excludes b/.rat-excludes index 4468da19008bc..2238a5b68e359 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -15,6 +15,7 @@ TAGS RELEASE control docs +docker.properties.template fairscheduler.xml.template spark-defaults.conf.template log4j.properties diff --git a/conf/docker.properties.template b/conf/docker.properties.template new file mode 100644 index 0000000000000..26e3bfd9c5b9b --- /dev/null +++ b/conf/docker.properties.template @@ -0,0 +1,3 @@ +spark.mesos.executor.docker.image: +spark.mesos.executor.docker.volumes: /usr/local/lib:/host/usr/local/lib:ro +spark.mesos.executor.home: /opt/spark diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 3412301e64fd7..dc59545b43314 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -196,9 +196,14 @@ private[spark] class CoarseMesosSchedulerBackend( .addResources(createResource("cpus", cpusToUse)) .addResources(createResource("mem", MemoryUtils.calculateTotalMemory(sc))) - .build() + + sc.conf.getOption("spark.mesos.executor.docker.image").foreach { image => + MesosSchedulerBackendUtil + .setupContainerBuilderDockerInfo(image, sc.conf, task.getContainerBuilder()) + } + d.launchTasks( - Collections.singleton(offer.getId), Collections.singletonList(task), filters) + Collections.singleton(offer.getId), Collections.singletonList(task.build()), filters) } else { // Filter it out d.launchTasks( diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 86a7d0fb587e4..db0a080b3b0c0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -124,13 +124,19 @@ private[spark] class MesosSchedulerBackend( Value.Scalar.newBuilder() .setValue(MemoryUtils.calculateTotalMemory(sc)).build()) .build() - MesosExecutorInfo.newBuilder() + val executorInfo = MesosExecutorInfo.newBuilder() .setExecutorId(ExecutorID.newBuilder().setValue(execId).build()) .setCommand(command) .setData(ByteString.copyFrom(createExecArg())) .addResources(cpus) .addResources(memory) - .build() + + sc.conf.getOption("spark.mesos.executor.docker.image").foreach { image => + MesosSchedulerBackendUtil + .setupContainerBuilderDockerInfo(image, sc.conf, executorInfo.getContainerBuilder()) + } + + executorInfo.build() } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala new file mode 100644 index 0000000000000..928c5cfed417a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster.mesos + +import org.apache.mesos.Protos.{ContainerInfo, Volume} +import org.apache.mesos.Protos.ContainerInfo.DockerInfo + +import org.apache.spark.{Logging, SparkConf} + +/** + * A collection of utility functions which can be used by both the + * MesosSchedulerBackend and the CoarseMesosSchedulerBackend. + */ +private[mesos] object MesosSchedulerBackendUtil extends Logging { + /** + * Parse a comma-delimited list of volume specs, each of which + * takes the form [host-dir:]container-dir[:rw|:ro]. + */ + def parseVolumesSpec(volumes: String): List[Volume] = { + volumes.split(",").map(_.split(":")).flatMap { spec => + val vol: Volume.Builder = Volume + .newBuilder() + .setMode(Volume.Mode.RW) + spec match { + case Array(container_path) => + Some(vol.setContainerPath(container_path)) + case Array(container_path, "rw") => + Some(vol.setContainerPath(container_path)) + case Array(container_path, "ro") => + Some(vol.setContainerPath(container_path) + .setMode(Volume.Mode.RO)) + case Array(host_path, container_path) => + Some(vol.setContainerPath(container_path) + .setHostPath(host_path)) + case Array(host_path, container_path, "rw") => + Some(vol.setContainerPath(container_path) + .setHostPath(host_path)) + case Array(host_path, container_path, "ro") => + Some(vol.setContainerPath(container_path) + .setHostPath(host_path) + .setMode(Volume.Mode.RO)) + case spec => { + logWarning(s"Unable to parse volume specs: $volumes. " + + "Expected form: \"[host-dir:]container-dir[:rw|:ro](, ...)\"") + None + } + } + } + .map { _.build() } + .toList + } + + /** + * Parse a comma-delimited list of port mapping specs, each of which + * takes the form host_port:container_port[:udp|:tcp] + * + * Note: + * the docker form is [ip:]host_port:container_port, but the DockerInfo + * message has no field for 'ip', and instead has a 'protocol' field. + * Docker itself only appears to support TCP, so this alternative form + * anticipates the expansion of the docker form to allow for a protocol + * and leaves open the chance for mesos to begin to accept an 'ip' field + */ + def parsePortMappingsSpec(portmaps: String): List[DockerInfo.PortMapping] = { + portmaps.split(",").map(_.split(":")).flatMap { spec: Array[String] => + val portmap: DockerInfo.PortMapping.Builder = DockerInfo.PortMapping + .newBuilder() + .setProtocol("tcp") + spec match { + case Array(host_port, container_port) => + Some(portmap.setHostPort(host_port.toInt) + .setContainerPort(container_port.toInt)) + case Array(host_port, container_port, protocol) => + Some(portmap.setHostPort(host_port.toInt) + .setContainerPort(container_port.toInt) + .setProtocol(protocol)) + case spec => { + logWarning(s"Unable to parse port mapping specs: $portmaps. " + + "Expected form: \"host_port:container_port[:udp|:tcp](, ...)\"") + None + } + } + } + .map { _.build() } + .toList + } + + /** + * Construct a DockerInfo structure and insert it into a ContainerInfo + */ + def addDockerInfo( + container: ContainerInfo.Builder, + image: String, + volumes: Option[List[Volume]] = None, + network: Option[ContainerInfo.DockerInfo.Network] = None, + portmaps: Option[List[ContainerInfo.DockerInfo.PortMapping]] = None):Unit = { + + val docker = ContainerInfo.DockerInfo.newBuilder().setImage(image) + + network.foreach(docker.setNetwork) + portmaps.foreach(_.foreach(docker.addPortMappings)) + container.setType(ContainerInfo.Type.DOCKER) + container.setDocker(docker.build()) + volumes.foreach(_.foreach(container.addVolumes)) + } + + /** + * Setup a docker containerizer + */ + def setupContainerBuilderDockerInfo( + imageName: String, + conf: SparkConf, + builder: ContainerInfo.Builder): Unit = { + val volumes = conf + .getOption("spark.mesos.executor.docker.volumes") + .map(parseVolumesSpec) + val portmaps = conf + .getOption("spark.mesos.executor.docker.portmaps") + .map(parsePortMappingsSpec) + addDockerInfo( + builder, + imageName, + volumes = volumes, + portmaps = portmaps) + logDebug("setupContainerDockerInfo: using docker image: " + imageName) + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala index cdd7be0fbe5dd..ab863f3d8d672 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala @@ -73,6 +73,52 @@ class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with Mo s"cd test-app-1*; ./bin/spark-class ${classOf[MesosExecutorBackend].getName}") } + test("spark docker properties correctly populate the DockerInfo message") { + val taskScheduler = mock[TaskSchedulerImpl] + + val conf = new SparkConf() + .set("spark.mesos.executor.docker.image", "spark/mock") + .set("spark.mesos.executor.docker.volumes", "/a,/b:/b,/c:/c:rw,/d:ro,/e:/e:ro") + .set("spark.mesos.executor.docker.portmaps", "80:8080,53:53:tcp") + + val listenerBus = mock[LiveListenerBus] + listenerBus.post( + SparkListenerExecutorAdded(anyLong, "s1", new ExecutorInfo("host1", 2, Map.empty))) + + val sc = mock[SparkContext] + when(sc.executorMemory).thenReturn(100) + when(sc.getSparkHome()).thenReturn(Option("/spark-home")) + when(sc.executorEnvs).thenReturn(new mutable.HashMap[String, String]) + when(sc.conf).thenReturn(conf) + when(sc.listenerBus).thenReturn(listenerBus) + + val backend = new MesosSchedulerBackend(taskScheduler, sc, "master") + + val execInfo = backend.createExecutorInfo("mockExecutor") + assert(execInfo.getContainer.getDocker.getImage.equals("spark/mock")) + val portmaps = execInfo.getContainer.getDocker.getPortMappingsList + assert(portmaps.get(0).getHostPort.equals(80)) + assert(portmaps.get(0).getContainerPort.equals(8080)) + assert(portmaps.get(0).getProtocol.equals("tcp")) + assert(portmaps.get(1).getHostPort.equals(53)) + assert(portmaps.get(1).getContainerPort.equals(53)) + assert(portmaps.get(1).getProtocol.equals("tcp")) + val volumes = execInfo.getContainer.getVolumesList + assert(volumes.get(0).getContainerPath.equals("/a")) + assert(volumes.get(0).getMode.equals(Volume.Mode.RW)) + assert(volumes.get(1).getContainerPath.equals("/b")) + assert(volumes.get(1).getHostPath.equals("/b")) + assert(volumes.get(1).getMode.equals(Volume.Mode.RW)) + assert(volumes.get(2).getContainerPath.equals("/c")) + assert(volumes.get(2).getHostPath.equals("/c")) + assert(volumes.get(2).getMode.equals(Volume.Mode.RW)) + assert(volumes.get(3).getContainerPath.equals("/d")) + assert(volumes.get(3).getMode.equals(Volume.Mode.RO)) + assert(volumes.get(4).getContainerPath.equals("/e")) + assert(volumes.get(4).getHostPath.equals("/e")) + assert(volumes.get(4).getMode.equals(Volume.Mode.RO)) + } + test("mesos resource offers result in launching tasks") { def createOffer(id: Int, mem: Int, cpu: Int): Offer = { val builder = Offer.newBuilder() diff --git a/docker/spark-mesos/Dockerfile b/docker/spark-mesos/Dockerfile new file mode 100644 index 0000000000000..b90aef3655dee --- /dev/null +++ b/docker/spark-mesos/Dockerfile @@ -0,0 +1,30 @@ +# This is an example Dockerfile for creating a Spark image which can be +# references by the Spark property 'spark.mesos.executor.docker.image' +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +FROM mesosphere/mesos:0.20.1 + +# Update the base ubuntu image with dependencies needed for Spark +RUN apt-get update && \ + apt-get install -y python libnss3 openjdk-7-jre-headless curl + +RUN mkdir /opt/spark && \ + curl http://www.apache.org/dyn/closer.cgi/spark/spark-1.4.0/spark-1.4.0-bin-hadoop2.4.tgz \ + | tar -xzC /opt +ENV SPARK_HOME /opt/spark +ENV MESOS_NATIVE_JAVA_LIBRARY /usr/local/lib/libmesos.so diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 8f53d8201a089..5f1d6daeb27f0 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -184,6 +184,16 @@ acquire. By default, it will acquire *all* cores in the cluster (that get offere only makes sense if you run just one application at a time. You can cap the maximum number of cores using `conf.set("spark.cores.max", "10")` (for example). +# Mesos Docker Support + +Spark can make use of a Mesos Docker containerizer by setting the property `spark.mesos.executor.docker.image` +in your [SparkConf](configuration.html#spark-properties). + +The Docker image used must have an appropriate version of Spark already part of the image, or you can +have Mesos download Spark via the usual methods. + +Requires Mesos version 0.20.1 or later. + # Running Alongside Hadoop You can run Spark and Mesos alongside your existing Hadoop cluster by just launching them as a @@ -237,6 +247,38 @@ See the [configuration page](configuration.html) for information on Spark config The value can be a floating point number. + + spark.mesos.executor.docker.image + (none) + + Set the name of the docker image that the Spark executors will run in. The selected + image must have Spark installed, as well as a compatible version of the Mesos library. + The installed path of Spark in the image can be specified with spark.mesos.executor.home; + the installed path of the Mesos library can be specified with spark.executorEnv.MESOS_NATIVE_LIBRARY. + + + + spark.mesos.executor.docker.volumes + (none) + + Set the list of volumes which will be mounted into the Docker image, which was set using + spark.mesos.executor.docker.image. The format of this property is a comma-separated list of + mappings following the form passed to docker run -v. That is they take the form: + +
[host_path:]container_path[:ro|:rw]
+ + + + spark.mesos.executor.docker.portmaps + (none) + + Set the list of incoming ports exposed by the Docker image, which was set using + spark.mesos.executor.docker.image. The format of this property is a comma-separated list of + mappings which take the form: + +
host_port:container_port[:tcp|:udp]
+ + spark.mesos.executor.home driver side SPARK_HOME diff --git a/pom.xml b/pom.xml index c85c5feeaf383..4313f940036c8 100644 --- a/pom.xml +++ b/pom.xml @@ -117,7 +117,7 @@ 1.6 spark 2.0.1 - 0.21.0 + 0.21.1 shaded-protobuf 1.7.10 1.2.17 From 38d4e9e446b425ca6a8fe8d8080f387b08683842 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 1 May 2015 19:01:46 -0700 Subject: [PATCH 23/91] [SPARK-6229] Add SASL encryption to network library. There are two main parts of this change: - Extending the bootstrap mechanism in the network library to add a server-side bootstrap (which works a little bit differently than the client-side bootstrap), and to allow the bootstraps to modify the underlying channel. - Use SASL to encrypt data going through the RPC channel. The second item requires some non-optimal code to be able to work around the fact that the outbound path in netty is not thread-safe, and ordering is very important when encryption is in the picture. A lot of the changes outside the network/common library are just to adjust to the changed API for initializing the RPC server. Author: Marcelo Vanzin Closes #5377 from vanzin/SPARK-6229 and squashes the following commits: ff01966 [Marcelo Vanzin] Use fancy new size config style. be53f32 [Marcelo Vanzin] Merge branch 'master' into SPARK-6229 47d4aff [Marcelo Vanzin] Merge branch 'master' into SPARK-6229 7a2a805 [Marcelo Vanzin] Clean up some unneeded changes. 2f92237 [Marcelo Vanzin] Add comment. 67bb0c6 [Marcelo Vanzin] Revert "Avoid exposing ByteArrayWritableChannel outside of test code." 065f684 [Marcelo Vanzin] Add test to verify chunking. 3d1695d [Marcelo Vanzin] Minor cleanups. 73cff0e [Marcelo Vanzin] Skip bytes in decode path too. 318ad23 [Marcelo Vanzin] Avoid exposing ByteArrayWritableChannel outside of test code. 346f829 [Marcelo Vanzin] Avoid trip through channel selector by not reporting 0 bytes written. a4a5938 [Marcelo Vanzin] Review feedback. 4797519 [Marcelo Vanzin] Remove unused import. 9908ada [Marcelo Vanzin] Fix test, SASL backend disposal. 7fe1489 [Marcelo Vanzin] Add a test that makes sure encryption is actually enabled. adb6f9d [Marcelo Vanzin] Review feedback. cf2a605 [Marcelo Vanzin] Clean up some code. 8584323 [Marcelo Vanzin] Fix a comment. e98bc55 [Marcelo Vanzin] Add option to only allow encrypted connections to the server. dad42fc [Marcelo Vanzin] Make encryption thread-safe, less memory-intensive. b00999a [Marcelo Vanzin] Consolidate ByteArrayWritableChannel, fix SASL code to match master changes. b923cae [Marcelo Vanzin] Make SASL encryption handler thread-safe, handle FileRegion messages. 39539a7 [Marcelo Vanzin] Add config option to enable SASL encryption. 351a86f [Marcelo Vanzin] Add SASL encryption to network library. fbe6ccb [Marcelo Vanzin] Add TransportServerBootstrap, make SASL code use it. --- .../org/apache/spark/SecurityManager.scala | 17 +- .../spark/deploy/ExternalShuffleService.scala | 17 +- .../netty/NettyBlockTransferService.scala | 22 +- .../spark/network/nio/ConnectionManager.scala | 4 +- .../apache/spark/storage/BlockManager.scala | 3 +- .../spark/network/TransportContext.java | 26 +- .../client/TransportClientBootstrap.java | 4 +- .../client/TransportClientFactory.java | 5 +- .../network/sasl/SaslClientBootstrap.java | 41 +- .../spark/network/sasl/SaslEncryption.java | 291 ++++++++++++++ .../network/sasl/SaslEncryptionBackend.java | 33 ++ .../spark/network/sasl/SaslRpcHandler.java | 56 ++- .../network/sasl/SaslServerBootstrap.java | 49 +++ .../spark/network/sasl/SparkSaslClient.java | 33 +- .../spark/network/sasl/SparkSaslServer.java | 49 ++- .../spark/network/server/TransportServer.java | 19 +- .../server/TransportServerBootstrap.java | 36 ++ .../util}/ByteArrayWritableChannel.java | 26 +- .../spark/network/util/TransportConf.java | 18 + .../apache/spark/network/ProtocolSuite.java | 1 + .../protocol/MessageWithHeaderSuite.java | 2 +- .../spark/network/sasl/SparkSaslSuite.java | 358 +++++++++++++++++- .../shuffle/ExternalShuffleClient.java | 11 +- .../network/sasl/SaslIntegrationSuite.java | 9 +- .../ExternalShuffleIntegrationSuite.java | 4 +- .../shuffle/ExternalShuffleSecuritySuite.java | 27 +- .../network/yarn/YarnShuffleService.java | 15 +- 27 files changed, 1070 insertions(+), 106 deletions(-) create mode 100644 network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java create mode 100644 network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryptionBackend.java create mode 100644 network/common/src/main/java/org/apache/spark/network/sasl/SaslServerBootstrap.java create mode 100644 network/common/src/main/java/org/apache/spark/network/server/TransportServerBootstrap.java rename network/common/src/{test/java/org/apache/spark/network => main/java/org/apache/spark/network/util}/ByteArrayWritableChannel.java (70%) diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 3653f724ba192..8aed1e20e0686 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -150,8 +150,13 @@ import org.apache.spark.util.Utils * authorization. If not filter is in place the user is generally null and no authorization * can take place. * - * Connection encryption (SSL) configuration is organized hierarchically. The user can configure - * the default SSL settings which will be used for all the supported communication protocols unless + * When authentication is being used, encryption can also be enabled by setting the option + * spark.authenticate.enableSaslEncryption to true. This is only supported by communication + * channels that use the network-common library, and can be used as an alternative to SSL in those + * cases. + * + * SSL can be used for encryption for certain communication channels. The user can configure the + * default SSL settings which will be used for all the supported communication protocols unless * they are overwritten by protocol specific settings. This way the user can easily provide the * common settings for all the protocols without disabling the ability to configure each one * individually. @@ -412,6 +417,14 @@ private[spark] class SecurityManager(sparkConf: SparkConf) */ def isAuthenticationEnabled(): Boolean = authOn + /** + * Checks whether SASL encryption should be enabled. + * @return Whether to enable SASL encryption when connecting to services that support it. + */ + def isSaslEncryptionEnabled(): Boolean = { + sparkConf.getBoolean("spark.authenticate.enableSaslEncryption", false) + } + /** * Gets the user used for authenticating HTTP connections. * For now use a single hardcoded user. diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala index cd16f992a3c0a..09973a0a2c998 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala @@ -19,10 +19,12 @@ package org.apache.spark.deploy import java.util.concurrent.CountDownLatch +import scala.collection.JavaConversions._ + import org.apache.spark.{Logging, SparkConf, SecurityManager} import org.apache.spark.network.TransportContext import org.apache.spark.network.netty.SparkTransportConf -import org.apache.spark.network.sasl.SaslRpcHandler +import org.apache.spark.network.sasl.SaslServerBootstrap import org.apache.spark.network.server.TransportServer import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler import org.apache.spark.util.Utils @@ -44,10 +46,7 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana private val transportConf = SparkTransportConf.fromSparkConf(sparkConf, numUsableCores = 0) private val blockHandler = new ExternalShuffleBlockHandler(transportConf) - private val transportContext: TransportContext = { - val handler = if (useSasl) new SaslRpcHandler(blockHandler, securityManager) else blockHandler - new TransportContext(transportConf, handler) - } + private val transportContext: TransportContext = new TransportContext(transportConf, blockHandler) private var server: TransportServer = _ @@ -62,7 +61,13 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana def start() { require(server == null, "Shuffle server already started") logInfo(s"Starting shuffle service on port $port with useSasl = $useSasl") - server = transportContext.createServer(port) + val bootstraps = + if (useSasl) { + Seq(new SaslServerBootstrap(transportConf, securityManager)) + } else { + Nil + } + server = transportContext.createServer(port, bootstraps) } def stop() { diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 3f0950dae1f24..6181c0ee9fa2b 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -24,7 +24,7 @@ import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.client.{TransportClientBootstrap, RpcResponseCallback, TransportClientFactory} -import org.apache.spark.network.sasl.{SaslRpcHandler, SaslClientBootstrap} +import org.apache.spark.network.sasl.{SaslClientBootstrap, SaslServerBootstrap} import org.apache.spark.network.server._ import org.apache.spark.network.shuffle.{RetryingBlockFetcher, BlockFetchingListener, OneForOneBlockFetcher} import org.apache.spark.network.shuffle.protocol.UploadBlock @@ -49,18 +49,18 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage private[this] var appId: String = _ override def init(blockDataManager: BlockDataManager): Unit = { - val (rpcHandler: RpcHandler, bootstrap: Option[TransportClientBootstrap]) = { - val nettyRpcHandler = new NettyBlockRpcServer(serializer, blockDataManager) - if (!authEnabled) { - (nettyRpcHandler, None) - } else { - (new SaslRpcHandler(nettyRpcHandler, securityManager), - Some(new SaslClientBootstrap(transportConf, conf.getAppId, securityManager))) - } + val rpcHandler = new NettyBlockRpcServer(serializer, blockDataManager) + var serverBootstrap: Option[TransportServerBootstrap] = None + var clientBootstrap: Option[TransportClientBootstrap] = None + if (authEnabled) { + serverBootstrap = Some(new SaslServerBootstrap(transportConf, securityManager)) + clientBootstrap = Some(new SaslClientBootstrap(transportConf, conf.getAppId, securityManager, + securityManager.isSaslEncryptionEnabled())) } transportContext = new TransportContext(transportConf, rpcHandler) - clientFactory = transportContext.createClientFactory(bootstrap.toList) - server = transportContext.createServer(conf.getInt("spark.blockManager.port", 0)) + clientFactory = transportContext.createClientFactory(clientBootstrap.toList) + server = transportContext.createServer(conf.getInt("spark.blockManager.port", 0), + serverBootstrap.toList) appId = conf.getAppId logInfo("Server created on " + server.getPort) } diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index 16e905982cf64..497871ed6d5e5 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -656,7 +656,7 @@ private[nio] class ConnectionManager( connection.synchronized { if (connection.sparkSaslServer == null) { logDebug("Creating sasl Server") - connection.sparkSaslServer = new SparkSaslServer(conf.getAppId, securityManager) + connection.sparkSaslServer = new SparkSaslServer(conf.getAppId, securityManager, false) } } replyToken = connection.sparkSaslServer.response(securityMsg.getToken) @@ -800,7 +800,7 @@ private[nio] class ConnectionManager( if (!conn.isSaslComplete()) { conn.synchronized { if (conn.sparkSaslClient == null) { - conn.sparkSaslClient = new SparkSaslClient(conf.getAppId, securityManager) + conn.sparkSaslClient = new SparkSaslClient(conf.getAppId, securityManager, false) var firstResponse: Array[Byte] = null try { firstResponse = conn.sparkSaslClient.firstToken() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 402ee1c7648c5..a46fecd2274ef 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -111,7 +111,8 @@ private[spark] class BlockManager( // standard BlockTransferService to directly connect to other Executors. private[spark] val shuffleClient = if (externalShuffleServiceEnabled) { val transConf = SparkTransportConf.fromSparkConf(conf, numUsableCores) - new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled()) + new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled(), + securityManager.isSaslEncryptionEnabled()) } else { blockTransferService } diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java index 3fe69b1bd8851..b8d073fa16b4b 100644 --- a/network/common/src/main/java/org/apache/spark/network/TransportContext.java +++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java @@ -36,6 +36,7 @@ import org.apache.spark.network.server.TransportChannelHandler; import org.apache.spark.network.server.TransportRequestHandler; import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.server.TransportServerBootstrap; import org.apache.spark.network.util.NettyUtils; import org.apache.spark.network.util.TransportConf; @@ -82,13 +83,21 @@ public TransportClientFactory createClientFactory() { } /** Create a server which will attempt to bind to a specific port. */ - public TransportServer createServer(int port) { - return new TransportServer(this, port); + public TransportServer createServer(int port, List bootstraps) { + return new TransportServer(this, port, rpcHandler, bootstraps); } /** Creates a new server, binding to any available ephemeral port. */ + public TransportServer createServer(List bootstraps) { + return createServer(0, bootstraps); + } + public TransportServer createServer() { - return new TransportServer(this, 0); + return createServer(0, Lists.newArrayList()); + } + + public TransportChannelHandler initializePipeline(SocketChannel channel) { + return initializePipeline(channel, rpcHandler); } /** @@ -96,13 +105,18 @@ public TransportServer createServer() { * has a {@link org.apache.spark.network.server.TransportChannelHandler} to handle request or * response messages. * + * @param channel The channel to initialize. + * @param channelRpcHandler The RPC handler to use for the channel. + * * @return Returns the created TransportChannelHandler, which includes a TransportClient that can * be used to communicate on this channel. The TransportClient is directly associated with a * ChannelHandler to ensure all users of the same channel get the same TransportClient object. */ - public TransportChannelHandler initializePipeline(SocketChannel channel) { + public TransportChannelHandler initializePipeline( + SocketChannel channel, + RpcHandler channelRpcHandler) { try { - TransportChannelHandler channelHandler = createChannelHandler(channel); + TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler); channel.pipeline() .addLast("encoder", encoder) .addLast("frameDecoder", NettyUtils.createFrameDecoder()) @@ -123,7 +137,7 @@ public TransportChannelHandler initializePipeline(SocketChannel channel) { * ResponseMessages. The channel is expected to have been successfully created, though certain * properties (such as the remoteAddress()) may not be available yet. */ - private TransportChannelHandler createChannelHandler(Channel channel) { + private TransportChannelHandler createChannelHandler(Channel channel, RpcHandler rpcHandler) { TransportResponseHandler responseHandler = new TransportResponseHandler(channel); TransportClient client = new TransportClient(channel, responseHandler); TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client, diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java index 65e8020e34121..eaae2ee043c5a 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java @@ -17,6 +17,8 @@ package org.apache.spark.network.client; +import io.netty.channel.Channel; + /** * A bootstrap which is executed on a TransportClient before it is returned to the user. * This enables an initial exchange of information (e.g., SASL authentication tokens) on a once-per- @@ -28,5 +30,5 @@ */ public interface TransportClientBootstrap { /** Performs the bootstrapping operation, throwing an exception on failure. */ - public void doBootstrap(TransportClient client) throws RuntimeException; + void doBootstrap(TransportClient client, Channel channel) throws RuntimeException; } diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index d26b9b4d6055f..4952ffb44bb8b 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -172,12 +172,14 @@ private TransportClient createClient(InetSocketAddress address) throws IOExcepti .option(ChannelOption.ALLOCATOR, pooledAllocator); final AtomicReference clientRef = new AtomicReference(); + final AtomicReference channelRef = new AtomicReference(); bootstrap.handler(new ChannelInitializer() { @Override public void initChannel(SocketChannel ch) { TransportChannelHandler clientHandler = context.initializePipeline(ch); clientRef.set(clientHandler.getClient()); + channelRef.set(ch); } }); @@ -192,6 +194,7 @@ public void initChannel(SocketChannel ch) { } TransportClient client = clientRef.get(); + Channel channel = channelRef.get(); assert client != null : "Channel future completed successfully with null client"; // Execute any client bootstraps synchronously before marking the Client as successful. @@ -199,7 +202,7 @@ public void initChannel(SocketChannel ch) { logger.debug("Connection to {} successful, running bootstraps...", address); try { for (TransportClientBootstrap clientBootstrap : clientBootstraps) { - clientBootstrap.doBootstrap(client); + clientBootstrap.doBootstrap(client, channel); } } catch (Exception e) { // catch non-RuntimeExceptions too as bootstrap may be written in Scala long bootstrapTimeMs = (System.nanoTime() - preBootstrap) / 1000000; diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java index 33aa1344345ff..185ba2ef3bb1f 100644 --- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java @@ -17,8 +17,12 @@ package org.apache.spark.network.sasl; +import javax.security.sasl.Sasl; +import javax.security.sasl.SaslException; + import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -33,14 +37,24 @@ public class SaslClientBootstrap implements TransportClientBootstrap { private final Logger logger = LoggerFactory.getLogger(SaslClientBootstrap.class); + private final boolean encrypt; private final TransportConf conf; private final String appId; private final SecretKeyHolder secretKeyHolder; public SaslClientBootstrap(TransportConf conf, String appId, SecretKeyHolder secretKeyHolder) { + this(conf, appId, secretKeyHolder, false); + } + + public SaslClientBootstrap( + TransportConf conf, + String appId, + SecretKeyHolder secretKeyHolder, + boolean encrypt) { this.conf = conf; this.appId = appId; this.secretKeyHolder = secretKeyHolder; + this.encrypt = encrypt; } /** @@ -49,8 +63,8 @@ public SaslClientBootstrap(TransportConf conf, String appId, SecretKeyHolder sec * due to mismatch. */ @Override - public void doBootstrap(TransportClient client) { - SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder); + public void doBootstrap(TransportClient client, Channel channel) { + SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder, encrypt); try { byte[] payload = saslClient.firstToken(); @@ -62,13 +76,26 @@ public void doBootstrap(TransportClient client) { byte[] response = client.sendRpcSync(buf.array(), conf.saslRTTimeoutMs()); payload = saslClient.response(response); } + + if (encrypt) { + if (!SparkSaslServer.QOP_AUTH_CONF.equals(saslClient.getNegotiatedProperty(Sasl.QOP))) { + throw new RuntimeException( + new SaslException("Encryption requests by negotiated non-encrypted connection.")); + } + SaslEncryption.addToChannel(channel, saslClient, conf.maxSaslEncryptedBlockSize()); + saslClient = null; + logger.debug("Channel {} configured for SASL encryption.", client); + } } finally { - try { - // Once authentication is complete, the server will trust all remaining communication. - saslClient.dispose(); - } catch (RuntimeException e) { - logger.error("Error while disposing SASL client", e); + if (saslClient != null) { + try { + // Once authentication is complete, the server will trust all remaining communication. + saslClient.dispose(); + } catch (RuntimeException e) { + logger.error("Error while disposing SASL client", e); + } } } } + } diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java new file mode 100644 index 0000000000000..127335e4d35fb --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java @@ -0,0 +1,291 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; +import java.util.List; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPromise; +import io.netty.channel.FileRegion; +import io.netty.handler.codec.MessageToMessageDecoder; +import io.netty.util.AbstractReferenceCounted; +import io.netty.util.ReferenceCountUtil; + +import org.apache.spark.network.util.ByteArrayWritableChannel; +import org.apache.spark.network.util.NettyUtils; + +/** + * Provides SASL-based encription for transport channels. The single method exposed by this + * class installs the needed channel handlers on a connected channel. + */ +class SaslEncryption { + + @VisibleForTesting + static final String ENCRYPTION_HANDLER_NAME = "saslEncryption"; + + /** + * Adds channel handlers that perform encryption / decryption of data using SASL. + * + * @param channel The channel. + * @param backend The SASL backend. + * @param maxOutboundBlockSize Max size in bytes of outgoing encrypted blocks, to control + * memory usage. + */ + static void addToChannel( + Channel channel, + SaslEncryptionBackend backend, + int maxOutboundBlockSize) { + channel.pipeline() + .addFirst(ENCRYPTION_HANDLER_NAME, new EncryptionHandler(backend, maxOutboundBlockSize)) + .addFirst("saslDecryption", new DecryptionHandler(backend)) + .addFirst("saslFrameDecoder", NettyUtils.createFrameDecoder()); + } + + private static class EncryptionHandler extends ChannelOutboundHandlerAdapter { + + private final int maxOutboundBlockSize; + private final SaslEncryptionBackend backend; + + EncryptionHandler(SaslEncryptionBackend backend, int maxOutboundBlockSize) { + this.backend = backend; + this.maxOutboundBlockSize = maxOutboundBlockSize; + } + + /** + * Wrap the incoming message in an implementation that will perform encryption lazily. This is + * needed to guarantee ordering of the outgoing encrypted packets - they need to be decrypted in + * the same order, and netty doesn't have an atomic ChannelHandlerContext.write() API, so it + * does not guarantee any ordering. + */ + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) + throws Exception { + + ctx.write(new EncryptedMessage(backend, msg, maxOutboundBlockSize), promise); + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + try { + backend.dispose(); + } finally { + super.handlerRemoved(ctx); + } + } + + } + + private static class DecryptionHandler extends MessageToMessageDecoder { + + private final SaslEncryptionBackend backend; + + DecryptionHandler(SaslEncryptionBackend backend) { + this.backend = backend; + } + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf msg, List out) + throws Exception { + + byte[] data; + int offset; + int length = msg.readableBytes(); + if (msg.hasArray()) { + data = msg.array(); + offset = msg.arrayOffset(); + msg.skipBytes(length); + } else { + data = new byte[length]; + msg.readBytes(data); + offset = 0; + } + + out.add(Unpooled.wrappedBuffer(backend.unwrap(data, offset, length))); + } + + } + + @VisibleForTesting + static class EncryptedMessage extends AbstractReferenceCounted implements FileRegion { + + private final SaslEncryptionBackend backend; + private final boolean isByteBuf; + private final ByteBuf buf; + private final FileRegion region; + + /** + * A channel used to buffer input data for encryption. The channel has an upper size bound + * so that if the input is larger than the allowed buffer, it will be broken into multiple + * chunks. + */ + private final ByteArrayWritableChannel byteChannel; + + private ByteBuf currentHeader; + private ByteBuffer currentChunk; + private long currentChunkSize; + private long currentReportedBytes; + private long unencryptedChunkSize; + private long transferred; + + EncryptedMessage(SaslEncryptionBackend backend, Object msg, int maxOutboundBlockSize) { + Preconditions.checkArgument(msg instanceof ByteBuf || msg instanceof FileRegion, + "Unrecognized message type: %s", msg.getClass().getName()); + this.backend = backend; + this.isByteBuf = msg instanceof ByteBuf; + this.buf = isByteBuf ? (ByteBuf) msg : null; + this.region = isByteBuf ? null : (FileRegion) msg; + this.byteChannel = new ByteArrayWritableChannel(maxOutboundBlockSize); + } + + /** + * Returns the size of the original (unencrypted) message. + * + * This makes assumptions about how netty treats FileRegion instances, because there's no way + * to know beforehand what will be the size of the encrypted message. Namely, it assumes + * that netty will try to transfer data from this message while + * transfered() < count(). So these two methods return, technically, wrong data, + * but netty doesn't know better. + */ + @Override + public long count() { + return isByteBuf ? buf.readableBytes() : region.count(); + } + + @Override + public long position() { + return 0; + } + + /** + * Returns an approximation of the amount of data transferred. See {@link #count()}. + */ + @Override + public long transfered() { + return transferred; + } + + /** + * Transfers data from the original message to the channel, encrypting it in the process. + * + * This method also breaks down the original message into smaller chunks when needed. This + * is done to keep memory usage under control. This avoids having to copy the whole message + * data into memory at once, and can avoid ballooning memory usage when transferring large + * messages such as shuffle blocks. + * + * The {@link #transfered()} counter also behaves a little funny, in that it won't go forward + * until a whole chunk has been written. This is done because the code can't use the actual + * number of bytes written to the channel as the transferred count (see {@link #count()}). + * Instead, once an encrypted chunk is written to the output (including its header), the + * size of the original block will be added to the {@link #transfered()} amount. + */ + @Override + public long transferTo(final WritableByteChannel target, final long position) + throws IOException { + + Preconditions.checkArgument(position == transfered(), "Invalid position."); + + long reportedWritten = 0L; + long actuallyWritten = 0L; + do { + if (currentChunk == null) { + nextChunk(); + } + + if (currentHeader.readableBytes() > 0) { + int bytesWritten = target.write(currentHeader.nioBuffer()); + currentHeader.skipBytes(bytesWritten); + actuallyWritten += bytesWritten; + if (currentHeader.readableBytes() > 0) { + // Break out of loop if there are still header bytes left to write. + break; + } + } + + actuallyWritten += target.write(currentChunk); + if (!currentChunk.hasRemaining()) { + // Only update the count of written bytes once a full chunk has been written. + // See method javadoc. + long chunkBytesRemaining = unencryptedChunkSize - currentReportedBytes; + reportedWritten += chunkBytesRemaining; + transferred += chunkBytesRemaining; + currentHeader.release(); + currentHeader = null; + currentChunk = null; + currentChunkSize = 0; + currentReportedBytes = 0; + } + } while (currentChunk == null && transfered() + reportedWritten < count()); + + // Returning 0 triggers a backoff mechanism in netty which may harm performance. Instead, + // we return 1 until we can (i.e. until the reported count would actually match the size + // of the current chunk), at which point we resort to returning 0 so that the counts still + // match, at the cost of some performance. That situation should be rare, though. + if (reportedWritten != 0L) { + return reportedWritten; + } + + if (actuallyWritten > 0 && currentReportedBytes < currentChunkSize - 1) { + transferred += 1L; + currentReportedBytes += 1L; + return 1L; + } + + return 0L; + } + + private void nextChunk() throws IOException { + byteChannel.reset(); + if (isByteBuf) { + int copied = byteChannel.write(buf.nioBuffer()); + buf.skipBytes(copied); + } else { + region.transferTo(byteChannel, region.transfered()); + } + + byte[] encrypted = backend.wrap(byteChannel.getData(), 0, byteChannel.length()); + this.currentChunk = ByteBuffer.wrap(encrypted); + this.currentChunkSize = encrypted.length; + this.currentHeader = Unpooled.copyLong(8 + currentChunkSize); + this.unencryptedChunkSize = byteChannel.length(); + } + + @Override + protected void deallocate() { + if (currentHeader != null) { + currentHeader.release(); + } + if (buf != null) { + buf.release(); + } + if (region != null) { + region.release(); + } + } + + } + +} diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryptionBackend.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryptionBackend.java new file mode 100644 index 0000000000000..89b78bc7e1df1 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryptionBackend.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import javax.security.sasl.SaslException; + +interface SaslEncryptionBackend { + + /** Disposes of resources used by the backend. */ + void dispose(); + + /** Encrypt data. */ + byte[] wrap(byte[] data, int offset, int len) throws SaslException; + + /** Decrypt data. */ + byte[] unwrap(byte[] data, int offset, int len) throws SaslException; + +} diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index 026cbd260d16c..be6165caf3c74 100644 --- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -17,10 +17,10 @@ package org.apache.spark.network.sasl; -import java.util.concurrent.ConcurrentMap; +import javax.security.sasl.Sasl; -import com.google.common.collect.Maps; import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -28,6 +28,7 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.util.TransportConf; /** * RPC Handler which performs SASL authentication before delegating to a child RPC handler. @@ -37,8 +38,14 @@ * Note that the authentication process consists of multiple challenge-response pairs, each of * which are individual RPCs. */ -public class SaslRpcHandler extends RpcHandler { - private final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class); +class SaslRpcHandler extends RpcHandler { + private static final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class); + + /** Transport configuration. */ + private final TransportConf conf; + + /** The client channel. */ + private final Channel channel; /** RpcHandler we will delegate to for authenticated connections. */ private final RpcHandler delegate; @@ -46,19 +53,25 @@ public class SaslRpcHandler extends RpcHandler { /** Class which provides secret keys which are shared by server and client on a per-app basis. */ private final SecretKeyHolder secretKeyHolder; - /** Maps each channel to its SASL authentication state. */ - private final ConcurrentMap channelAuthenticationMap; + private SparkSaslServer saslServer; + private boolean isComplete; - public SaslRpcHandler(RpcHandler delegate, SecretKeyHolder secretKeyHolder) { + SaslRpcHandler( + TransportConf conf, + Channel channel, + RpcHandler delegate, + SecretKeyHolder secretKeyHolder) { + this.conf = conf; + this.channel = channel; this.delegate = delegate; this.secretKeyHolder = secretKeyHolder; - this.channelAuthenticationMap = Maps.newConcurrentMap(); + this.saslServer = null; + this.isComplete = false; } @Override public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { - SparkSaslServer saslServer = channelAuthenticationMap.get(client); - if (saslServer != null && saslServer.isComplete()) { + if (isComplete) { // Authentication complete, delegate to base handler. delegate.receive(client, message, callback); return; @@ -68,15 +81,30 @@ public void receive(TransportClient client, byte[] message, RpcResponseCallback if (saslServer == null) { // First message in the handshake, setup the necessary state. - saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder); - channelAuthenticationMap.put(client, saslServer); + saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder, + conf.saslServerAlwaysEncrypt()); } byte[] response = saslServer.response(saslMessage.payload); + callback.onSuccess(response); + + // Setup encryption after the SASL response is sent, otherwise the client can't parse the + // response. It's ok to change the channel pipeline here since we are processing an incoming + // message, so the pipeline is busy and no new incoming messages will be fed to it before this + // method returns. This assumes that the code ensures, through other means, that no outbound + // messages are being written to the channel while negotiation is still going on. if (saslServer.isComplete()) { logger.debug("SASL authentication successful for channel {}", client); + isComplete = true; + if (SparkSaslServer.QOP_AUTH_CONF.equals(saslServer.getNegotiatedProperty(Sasl.QOP))) { + logger.debug("Enabling encryption for channel {}", client); + SaslEncryption.addToChannel(channel, saslServer, conf.maxSaslEncryptedBlockSize()); + saslServer = null; + } else { + saslServer.dispose(); + saslServer = null; + } } - callback.onSuccess(response); } @Override @@ -86,9 +114,9 @@ public StreamManager getStreamManager() { @Override public void connectionTerminated(TransportClient client) { - SparkSaslServer saslServer = channelAuthenticationMap.remove(client); if (saslServer != null) { saslServer.dispose(); } } + } diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslServerBootstrap.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslServerBootstrap.java new file mode 100644 index 0000000000000..f2f983856f444 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslServerBootstrap.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import io.netty.channel.Channel; + +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.TransportServerBootstrap; +import org.apache.spark.network.util.TransportConf; + +/** + * A bootstrap which is executed on a TransportServer's client channel once a client connects + * to the server. This allows customizing the client channel to allow for things such as SASL + * authentication. + */ +public class SaslServerBootstrap implements TransportServerBootstrap { + + private final TransportConf conf; + private final SecretKeyHolder secretKeyHolder; + + public SaslServerBootstrap(TransportConf conf, SecretKeyHolder secretKeyHolder) { + this.conf = conf; + this.secretKeyHolder = secretKeyHolder; + } + + /** + * Wrap the given application handler in a SaslRpcHandler that will handle the initial SASL + * negotiation. + */ + public RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler) { + return new SaslRpcHandler(conf, channel, rpcHandler, secretKeyHolder); + } + +} diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java b/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java index 9abad1f30a259..94685e91b862e 100644 --- a/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java @@ -17,6 +17,8 @@ package org.apache.spark.network.sasl; +import java.io.IOException; +import java.util.Map; import javax.security.auth.callback.Callback; import javax.security.auth.callback.CallbackHandler; import javax.security.auth.callback.NameCallback; @@ -27,9 +29,9 @@ import javax.security.sasl.Sasl; import javax.security.sasl.SaslClient; import javax.security.sasl.SaslException; -import java.io.IOException; import com.google.common.base.Throwables; +import com.google.common.collect.ImmutableMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -40,19 +42,25 @@ * initial state to the "authenticated" state. This client initializes the protocol via a * firstToken, which is then followed by a set of challenges and responses. */ -public class SparkSaslClient { +public class SparkSaslClient implements SaslEncryptionBackend { private final Logger logger = LoggerFactory.getLogger(SparkSaslClient.class); private final String secretKeyId; private final SecretKeyHolder secretKeyHolder; + private final String expectedQop; private SaslClient saslClient; - public SparkSaslClient(String secretKeyId, SecretKeyHolder secretKeyHolder) { + public SparkSaslClient(String secretKeyId, SecretKeyHolder secretKeyHolder, boolean encrypt) { this.secretKeyId = secretKeyId; this.secretKeyHolder = secretKeyHolder; + this.expectedQop = encrypt ? QOP_AUTH_CONF : QOP_AUTH; + + Map saslProps = ImmutableMap.builder() + .put(Sasl.QOP, expectedQop) + .build(); try { this.saslClient = Sasl.createSaslClient(new String[] { DIGEST }, null, null, DEFAULT_REALM, - SASL_PROPS, new ClientCallbackHandler()); + saslProps, new ClientCallbackHandler()); } catch (SaslException e) { throw Throwables.propagate(e); } @@ -76,6 +84,11 @@ public synchronized boolean isComplete() { return saslClient != null && saslClient.isComplete(); } + /** Returns the value of a negotiated property. */ + public Object getNegotiatedProperty(String name) { + return saslClient.getNegotiatedProperty(name); + } + /** * Respond to server's SASL token. * @param token contains server's SASL token @@ -93,6 +106,7 @@ public synchronized byte[] response(byte[] token) { * Disposes of any system resources or security-sensitive information the * SaslClient might be using. */ + @Override public synchronized void dispose() { if (saslClient != null) { try { @@ -134,4 +148,15 @@ public void handle(Callback[] callbacks) throws IOException, UnsupportedCallback } } } + + @Override + public byte[] wrap(byte[] data, int offset, int len) throws SaslException { + return saslClient.wrap(data, offset, len); + } + + @Override + public byte[] unwrap(byte[] data, int offset, int len) throws SaslException { + return saslClient.unwrap(data, offset, len); + } + } diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java b/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java index e87b17ead1e1a..431cb67a2ae0b 100644 --- a/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java @@ -44,7 +44,7 @@ * initial state to the "authenticated" state. (It is not a server in the sense of accepting * connections on some socket.) */ -public class SparkSaslServer { +public class SparkSaslServer implements SaslEncryptionBackend { private final Logger logger = LoggerFactory.getLogger(SparkSaslServer.class); /** @@ -60,26 +60,37 @@ public class SparkSaslServer { static final String DIGEST = "DIGEST-MD5"; /** - * The quality of protection is just "auth". This means that we are doing - * authentication only, we are not supporting integrity or privacy protection of the - * communication channel after authentication. This could be changed to be configurable - * in the future. + * Quality of protection value that includes encryption. */ - static final Map SASL_PROPS = ImmutableMap.builder() - .put(Sasl.QOP, "auth") - .put(Sasl.SERVER_AUTH, "true") - .build(); + static final String QOP_AUTH_CONF = "auth-conf"; + + /** + * Quality of protection value that does not include encryption. + */ + static final String QOP_AUTH = "auth"; /** Identifier for a certain secret key within the secretKeyHolder. */ private final String secretKeyId; private final SecretKeyHolder secretKeyHolder; private SaslServer saslServer; - public SparkSaslServer(String secretKeyId, SecretKeyHolder secretKeyHolder) { + public SparkSaslServer( + String secretKeyId, + SecretKeyHolder secretKeyHolder, + boolean alwaysEncrypt) { this.secretKeyId = secretKeyId; this.secretKeyHolder = secretKeyHolder; + + // Sasl.QOP is a comma-separated list of supported values. The value that allows encryption + // is listed first since it's preferred over the non-encrypted one (if the client also + // lists both in the request). + String qop = alwaysEncrypt ? QOP_AUTH_CONF : String.format("%s,%s", QOP_AUTH_CONF, QOP_AUTH); + Map saslProps = ImmutableMap.builder() + .put(Sasl.SERVER_AUTH, "true") + .put(Sasl.QOP, qop) + .build(); try { - this.saslServer = Sasl.createSaslServer(DIGEST, null, DEFAULT_REALM, SASL_PROPS, + this.saslServer = Sasl.createSaslServer(DIGEST, null, DEFAULT_REALM, saslProps, new DigestCallbackHandler()); } catch (SaslException e) { throw Throwables.propagate(e); @@ -93,6 +104,11 @@ public synchronized boolean isComplete() { return saslServer != null && saslServer.isComplete(); } + /** Returns the value of a negotiated property. */ + public Object getNegotiatedProperty(String name) { + return saslServer.getNegotiatedProperty(name); + } + /** * Used to respond to server SASL tokens. * @param token Server's SASL token @@ -110,6 +126,7 @@ public synchronized byte[] response(byte[] token) { * Disposes of any system resources or security-sensitive information the * SaslServer might be using. */ + @Override public synchronized void dispose() { if (saslServer != null) { try { @@ -122,6 +139,16 @@ public synchronized void dispose() { } } + @Override + public byte[] wrap(byte[] data, int offset, int len) throws SaslException { + return saslServer.wrap(data, offset, len); + } + + @Override + public byte[] unwrap(byte[] data, int offset, int len) throws SaslException { + return saslServer.unwrap(data, offset, len); + } + /** * Implementation of javax.security.auth.callback.CallbackHandler for SASL DIGEST-MD5 mechanism. */ diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java index b7ce8541e565e..941ef95772e16 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java @@ -19,8 +19,11 @@ import java.io.Closeable; import java.net.InetSocketAddress; +import java.util.List; import java.util.concurrent.TimeUnit; +import com.google.common.base.Preconditions; +import com.google.common.collect.Lists; import io.netty.bootstrap.ServerBootstrap; import io.netty.buffer.PooledByteBufAllocator; import io.netty.channel.ChannelFuture; @@ -44,15 +47,23 @@ public class TransportServer implements Closeable { private final TransportContext context; private final TransportConf conf; + private final RpcHandler appRpcHandler; + private final List bootstraps; private ServerBootstrap bootstrap; private ChannelFuture channelFuture; private int port = -1; /** Creates a TransportServer that binds to the given port, or to any available if 0. */ - public TransportServer(TransportContext context, int portToBind) { + public TransportServer( + TransportContext context, + int portToBind, + RpcHandler appRpcHandler, + List bootstraps) { this.context = context; this.conf = context.getConf(); + this.appRpcHandler = appRpcHandler; + this.bootstraps = Lists.newArrayList(Preconditions.checkNotNull(bootstraps)); init(portToBind); } @@ -95,7 +106,11 @@ private void init(int portToBind) { bootstrap.childHandler(new ChannelInitializer() { @Override protected void initChannel(SocketChannel ch) throws Exception { - context.initializePipeline(ch); + RpcHandler rpcHandler = appRpcHandler; + for (TransportServerBootstrap bootstrap : bootstraps) { + rpcHandler = bootstrap.doBootstrap(ch, rpcHandler); + } + context.initializePipeline(ch, rpcHandler); } }); diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServerBootstrap.java b/network/common/src/main/java/org/apache/spark/network/server/TransportServerBootstrap.java new file mode 100644 index 0000000000000..05803ab1bb059 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportServerBootstrap.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.server; + +import io.netty.channel.Channel; + +/** + * A bootstrap which is executed on a TransportServer's client channel once a client connects + * to the server. This allows customizing the client channel to allow for things such as SASL + * authentication. + */ +public interface TransportServerBootstrap { + /** + * Customizes the channel to include new features, if needed. + * + * @param channel The connected channel opened by the client. + * @param rpcHandler The RPC handler for the server. + * @return The RPC handler to use for the channel. + */ + RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler); +} diff --git a/network/common/src/test/java/org/apache/spark/network/ByteArrayWritableChannel.java b/network/common/src/main/java/org/apache/spark/network/util/ByteArrayWritableChannel.java similarity index 70% rename from network/common/src/test/java/org/apache/spark/network/ByteArrayWritableChannel.java rename to network/common/src/main/java/org/apache/spark/network/util/ByteArrayWritableChannel.java index b525ed69fc9fb..b1415720045e2 100644 --- a/network/common/src/test/java/org/apache/spark/network/ByteArrayWritableChannel.java +++ b/network/common/src/main/java/org/apache/spark/network/util/ByteArrayWritableChannel.java @@ -15,11 +15,14 @@ * limitations under the License. */ -package org.apache.spark.network; +package org.apache.spark.network.util; import java.nio.ByteBuffer; import java.nio.channels.WritableByteChannel; +/** + * A writable channel that stores the written data in a byte array in memory. + */ public class ByteArrayWritableChannel implements WritableByteChannel { private final byte[] data; @@ -27,19 +30,30 @@ public class ByteArrayWritableChannel implements WritableByteChannel { public ByteArrayWritableChannel(int size) { this.data = new byte[size]; - this.offset = 0; } public byte[] getData() { return data; } + public int length() { + return offset; + } + + /** Resets the channel so that writing to it will overwrite the existing buffer. */ + public void reset() { + offset = 0; + } + + /** + * Reads from the given buffer into the internal byte array. + */ @Override public int write(ByteBuffer src) { - int available = src.remaining(); - src.get(data, offset, available); - offset += available; - return available; + int toTransfer = Math.min(src.remaining(), data.length - offset); + src.get(data, offset, toTransfer); + offset += toTransfer; + return toTransfer; } @Override diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java index 0aef7f1987315..3b2eff377955a 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -17,6 +17,8 @@ package org.apache.spark.network.util; +import com.google.common.primitives.Ints; + /** * A central location that tracks all the settings we expose to users. */ @@ -112,4 +114,20 @@ public boolean lazyFileDescriptor() { public int portMaxRetries() { return conf.getInt("spark.port.maxRetries", 16); } + + /** + * Maximum number of bytes to be encrypted at a time when SASL encryption is enabled. + */ + public int maxSaslEncryptedBlockSize() { + return Ints.checkedCast(JavaUtils.byteStringAsBytes( + conf.get("spark.network.sasl.maxEncryptedBlockSize", "64k"))); + } + + /** + * Whether the server should enforce encryption on SASL-authenticated connections. + */ + public boolean saslServerAlwaysEncrypt() { + return conf.getBoolean("spark.network.sasl.serverAlwaysEncrypt", false); + } + } diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java index 860dd6d9b3915..d500bc3c98a78 100644 --- a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java @@ -39,6 +39,7 @@ import org.apache.spark.network.protocol.RpcRequest; import org.apache.spark.network.protocol.RpcResponse; import org.apache.spark.network.protocol.StreamChunkId; +import org.apache.spark.network.util.ByteArrayWritableChannel; import org.apache.spark.network.util.NettyUtils; public class ProtocolSuite { diff --git a/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java b/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java index ff985096d72d5..6c98e733b462f 100644 --- a/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java @@ -29,7 +29,7 @@ import static org.junit.Assert.*; -import org.apache.spark.network.ByteArrayWritableChannel; +import org.apache.spark.network.util.ByteArrayWritableChannel; public class MessageWithHeaderSuite { diff --git a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index 23b4e06f064e1..be6632bb8cf49 100644 --- a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -17,12 +17,47 @@ package org.apache.spark.network.sasl; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; +import static com.google.common.base.Charsets.UTF_8; +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; +import java.io.File; +import java.util.Arrays; +import java.util.List; +import java.util.Random; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import javax.security.sasl.SaslException; + +import com.google.common.collect.Lists; +import com.google.common.io.ByteStreams; +import com.google.common.io.Files; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPromise; import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import org.apache.spark.network.TestUtils; +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.buffer.FileSegmentManagedBuffer; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.client.ChunkReceivedCallback; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.server.TransportServerBootstrap; +import org.apache.spark.network.util.ByteArrayWritableChannel; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; /** * Jointly tests SparkSaslClient and SparkSaslServer, as both are black boxes. @@ -44,8 +79,8 @@ public String getSecretKey(String appId) { @Test public void testMatching() { - SparkSaslClient client = new SparkSaslClient("shared-secret", secretKeyHolder); - SparkSaslServer server = new SparkSaslServer("shared-secret", secretKeyHolder); + SparkSaslClient client = new SparkSaslClient("shared-secret", secretKeyHolder, false); + SparkSaslServer server = new SparkSaslServer("shared-secret", secretKeyHolder, false); assertFalse(client.isComplete()); assertFalse(server.isComplete()); @@ -64,11 +99,10 @@ public void testMatching() { assertFalse(client.isComplete()); } - @Test public void testNonMatching() { - SparkSaslClient client = new SparkSaslClient("my-secret", secretKeyHolder); - SparkSaslServer server = new SparkSaslServer("your-secret", secretKeyHolder); + SparkSaslClient client = new SparkSaslClient("my-secret", secretKeyHolder, false); + SparkSaslServer server = new SparkSaslServer("your-secret", secretKeyHolder, false); assertFalse(client.isComplete()); assertFalse(server.isComplete()); @@ -86,4 +120,312 @@ public void testNonMatching() { assertFalse(server.isComplete()); } } + + @Test + public void testSaslAuthentication() throws Exception { + testBasicSasl(false); + } + + @Test + public void testSaslEncryption() throws Exception { + testBasicSasl(true); + } + + private void testBasicSasl(boolean encrypt) throws Exception { + RpcHandler rpcHandler = mock(RpcHandler.class); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocation) { + byte[] message = (byte[]) invocation.getArguments()[1]; + RpcResponseCallback cb = (RpcResponseCallback) invocation.getArguments()[2]; + assertEquals("Ping", new String(message, UTF_8)); + cb.onSuccess("Pong".getBytes(UTF_8)); + return null; + } + }) + .when(rpcHandler) + .receive(any(TransportClient.class), any(byte[].class), any(RpcResponseCallback.class)); + + SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false); + try { + byte[] response = ctx.client.sendRpcSync("Ping".getBytes(UTF_8), TimeUnit.SECONDS.toMillis(10)); + assertEquals("Pong", new String(response, UTF_8)); + } finally { + ctx.close(); + } + } + + @Test + public void testEncryptedMessage() throws Exception { + SaslEncryptionBackend backend = mock(SaslEncryptionBackend.class); + byte[] data = new byte[1024]; + new Random().nextBytes(data); + when(backend.wrap(any(byte[].class), anyInt(), anyInt())).thenReturn(data); + + ByteBuf msg = Unpooled.buffer(); + try { + msg.writeBytes(data); + + // Create a channel with a really small buffer compared to the data. This means that on each + // call, the outbound data will not be fully written, so the write() method should return a + // dummy count to keep the channel alive when possible. + ByteArrayWritableChannel channel = new ByteArrayWritableChannel(32); + + SaslEncryption.EncryptedMessage emsg = + new SaslEncryption.EncryptedMessage(backend, msg, 1024); + long count = emsg.transferTo(channel, emsg.transfered()); + assertTrue(count < data.length); + assertTrue(count > 0); + + // Here, the output buffer is full so nothing should be transferred. + assertEquals(0, emsg.transferTo(channel, emsg.transfered())); + + // Now there's room in the buffer, but not enough to transfer all the remaining data, + // so the dummy count should be returned. + channel.reset(); + assertEquals(1, emsg.transferTo(channel, emsg.transfered())); + + // Eventually, the whole message should be transferred. + for (int i = 0; i < data.length / 32 - 2; i++) { + channel.reset(); + assertEquals(1, emsg.transferTo(channel, emsg.transfered())); + } + + channel.reset(); + count = emsg.transferTo(channel, emsg.transfered()); + assertTrue("Unexpected count: " + count, count > 1 && count < data.length); + assertEquals(data.length, emsg.transfered()); + } finally { + msg.release(); + } + } + + @Test + public void testEncryptedMessageChunking() throws Exception { + File file = File.createTempFile("sasltest", ".txt"); + try { + TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + + byte[] data = new byte[8 * 1024]; + new Random().nextBytes(data); + Files.write(data, file); + + SaslEncryptionBackend backend = mock(SaslEncryptionBackend.class); + // It doesn't really matter what we return here, as long as it's not null. + when(backend.wrap(any(byte[].class), anyInt(), anyInt())).thenReturn(data); + + FileSegmentManagedBuffer msg = new FileSegmentManagedBuffer(conf, file, 0, file.length()); + SaslEncryption.EncryptedMessage emsg = + new SaslEncryption.EncryptedMessage(backend, msg.convertToNetty(), data.length / 8); + + ByteArrayWritableChannel channel = new ByteArrayWritableChannel(data.length); + while (emsg.transfered() < emsg.count()) { + channel.reset(); + emsg.transferTo(channel, emsg.transfered()); + } + + verify(backend, times(8)).wrap(any(byte[].class), anyInt(), anyInt()); + } finally { + file.delete(); + } + } + + @Test + public void testFileRegionEncryption() throws Exception { + final String blockSizeConf = "spark.network.sasl.maxEncryptedBlockSize"; + System.setProperty(blockSizeConf, "1k"); + + final AtomicReference response = new AtomicReference(); + final File file = File.createTempFile("sasltest", ".txt"); + SaslTestCtx ctx = null; + try { + final TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + StreamManager sm = mock(StreamManager.class); + when(sm.getChunk(anyLong(), anyInt())).thenAnswer(new Answer() { + @Override + public ManagedBuffer answer(InvocationOnMock invocation) { + return new FileSegmentManagedBuffer(conf, file, 0, file.length()); + } + }); + + RpcHandler rpcHandler = mock(RpcHandler.class); + when(rpcHandler.getStreamManager()).thenReturn(sm); + + byte[] data = new byte[8 * 1024]; + new Random().nextBytes(data); + Files.write(data, file); + + ctx = new SaslTestCtx(rpcHandler, true, false); + + final Object lock = new Object(); + + ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocation) { + response.set((ManagedBuffer) invocation.getArguments()[1]); + response.get().retain(); + synchronized (lock) { + lock.notifyAll(); + } + return null; + } + }).when(callback).onSuccess(anyInt(), any(ManagedBuffer.class)); + + synchronized (lock) { + ctx.client.fetchChunk(0, 0, callback); + lock.wait(10 * 1000); + } + + verify(callback, times(1)).onSuccess(anyInt(), any(ManagedBuffer.class)); + verify(callback, never()).onFailure(anyInt(), any(Throwable.class)); + + byte[] received = ByteStreams.toByteArray(response.get().createInputStream()); + assertTrue(Arrays.equals(data, received)); + } finally { + file.delete(); + if (ctx != null) { + ctx.close(); + } + if (response.get() != null) { + response.get().release(); + } + System.clearProperty(blockSizeConf); + } + } + + @Test + public void testServerAlwaysEncrypt() throws Exception { + final String alwaysEncryptConfName = "spark.network.sasl.serverAlwaysEncrypt"; + System.setProperty(alwaysEncryptConfName, "true"); + + SaslTestCtx ctx = null; + try { + ctx = new SaslTestCtx(mock(RpcHandler.class), false, false); + fail("Should have failed to connect without encryption."); + } catch (Exception e) { + assertTrue(e.getCause() instanceof SaslException); + } finally { + if (ctx != null) { + ctx.close(); + } + System.clearProperty(alwaysEncryptConfName); + } + } + + @Test + public void testDataEncryptionIsActuallyEnabled() throws Exception { + // This test sets up an encrypted connection but then, using a client bootstrap, removes + // the encryption handler from the client side. This should cause the server to not be + // able to understand RPCs sent to it and thus close the connection. + SaslTestCtx ctx = null; + try { + ctx = new SaslTestCtx(mock(RpcHandler.class), true, true); + ctx.client.sendRpcSync("Ping".getBytes(UTF_8), TimeUnit.SECONDS.toMillis(10)); + fail("Should have failed to send RPC to server."); + } catch (Exception e) { + assertFalse(e.getCause() instanceof TimeoutException); + } finally { + if (ctx != null) { + ctx.close(); + } + } + } + + private static class SaslTestCtx { + + final TransportClient client; + final TransportServer server; + + private final boolean encrypt; + private final boolean disableClientEncryption; + private final EncryptionCheckerBootstrap checker; + + SaslTestCtx( + RpcHandler rpcHandler, + boolean encrypt, + boolean disableClientEncryption) + throws Exception { + + TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + + SecretKeyHolder keyHolder = mock(SecretKeyHolder.class); + when(keyHolder.getSaslUser(anyString())).thenReturn("user"); + when(keyHolder.getSecretKey(anyString())).thenReturn("secret"); + + TransportContext ctx = new TransportContext(conf, rpcHandler); + + this.checker = new EncryptionCheckerBootstrap(); + this.server = ctx.createServer(Arrays.asList(new SaslServerBootstrap(conf, keyHolder), + checker)); + + try { + List clientBootstraps = Lists.newArrayList(); + clientBootstraps.add(new SaslClientBootstrap(conf, "user", keyHolder, encrypt)); + if (disableClientEncryption) { + clientBootstraps.add(new EncryptionDisablerBootstrap()); + } + + this.client = ctx.createClientFactory(clientBootstraps) + .createClient(TestUtils.getLocalHost(), server.getPort()); + } catch (Exception e) { + close(); + throw e; + } + + this.encrypt = encrypt; + this.disableClientEncryption = disableClientEncryption; + } + + void close() { + if (!disableClientEncryption) { + assertEquals(encrypt, checker.foundEncryptionHandler); + } + if (client != null) { + client.close(); + } + if (server != null) { + server.close(); + } + } + + } + + private static class EncryptionCheckerBootstrap extends ChannelOutboundHandlerAdapter + implements TransportServerBootstrap { + + boolean foundEncryptionHandler; + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) + throws Exception { + if (!foundEncryptionHandler) { + foundEncryptionHandler = + ctx.channel().pipeline().get(SaslEncryption.ENCRYPTION_HANDLER_NAME) != null; + } + ctx.write(msg, promise); + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + super.handlerRemoved(ctx); + } + + @Override + public RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler) { + channel.pipeline().addFirst("encryptionChecker", this); + return rpcHandler; + } + + } + + private static class EncryptionDisablerBootstrap implements TransportClientBootstrap { + + @Override + public void doBootstrap(TransportClient client, Channel channel) { + channel.pipeline().remove(SaslEncryption.ENCRYPTION_HANDLER_NAME); + } + + } + } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index 6e8018b723dc6..612bce571a493 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.util.List; +import com.google.common.base.Preconditions; import com.google.common.collect.Lists; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -46,6 +47,7 @@ public class ExternalShuffleClient extends ShuffleClient { private final TransportConf conf; private final boolean saslEnabled; + private final boolean saslEncryptionEnabled; private final SecretKeyHolder secretKeyHolder; private TransportClientFactory clientFactory; @@ -58,10 +60,15 @@ public class ExternalShuffleClient extends ShuffleClient { public ExternalShuffleClient( TransportConf conf, SecretKeyHolder secretKeyHolder, - boolean saslEnabled) { + boolean saslEnabled, + boolean saslEncryptionEnabled) { + Preconditions.checkArgument( + !saslEncryptionEnabled || saslEnabled, + "SASL encryption can only be enabled if SASL is also enabled."); this.conf = conf; this.secretKeyHolder = secretKeyHolder; this.saslEnabled = saslEnabled; + this.saslEncryptionEnabled = saslEncryptionEnabled; } @Override @@ -70,7 +77,7 @@ public void init(String appId) { TransportContext context = new TransportContext(conf, new NoOpRpcHandler()); List bootstraps = Lists.newArrayList(); if (saslEnabled) { - bootstraps.add(new SaslClientBootstrap(conf, appId, secretKeyHolder)); + bootstraps.add(new SaslClientBootstrap(conf, appId, secretKeyHolder, saslEncryptionEnabled)); } clientFactory = context.createClientFactory(bootstraps); } diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index d25283e46ef96..382f613ecbb1b 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -18,6 +18,7 @@ package org.apache.spark.network.sasl; import java.io.IOException; +import java.util.Arrays; import com.google.common.collect.Lists; import org.junit.After; @@ -37,6 +38,7 @@ import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.server.TransportServerBootstrap; import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler; import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; @@ -72,10 +74,11 @@ public String getSecretKey(String appId) { @BeforeClass public static void beforeAll() throws IOException { SecretKeyHolder secretKeyHolder = new TestSecretKeyHolder("good-key"); - SaslRpcHandler handler = new SaslRpcHandler(new TestRpcHandler(), secretKeyHolder); conf = new TransportConf(new SystemPropertyConfigProvider()); - context = new TransportContext(conf, handler); - server = context.createServer(); + context = new TransportContext(conf, new TestRpcHandler()); + + TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf, secretKeyHolder); + server = context.createServer(Arrays.asList(bootstrap)); } diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index 02c10bcb7b261..39aa49911d9cb 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -136,7 +136,7 @@ private FetchResult fetchBlocks(String execId, String[] blockIds, int port) thro final Semaphore requestsRemaining = new Semaphore(0); - ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false); + ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false, false); client.init(APP_ID); client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds, new BlockFetchingListener() { @@ -274,7 +274,7 @@ public void testFetchNoServer() throws Exception { private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo) throws IOException { - ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false); + ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false, false); client.init(APP_ID); client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), executorId, executorInfo); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java index 759a12910c94d..d4ec1956c1e29 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java @@ -18,6 +18,7 @@ package org.apache.spark.network.shuffle; import java.io.IOException; +import java.util.Arrays; import org.junit.After; import org.junit.Before; @@ -27,10 +28,11 @@ import org.apache.spark.network.TestUtils; import org.apache.spark.network.TransportContext; -import org.apache.spark.network.sasl.SaslRpcHandler; +import org.apache.spark.network.sasl.SaslServerBootstrap; import org.apache.spark.network.sasl.SecretKeyHolder; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.server.TransportServerBootstrap; import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; @@ -42,10 +44,10 @@ public class ExternalShuffleSecuritySuite { @Before public void beforeEach() { - RpcHandler handler = new SaslRpcHandler(new ExternalShuffleBlockHandler(conf), - new TestSecretKeyHolder("my-app-id", "secret")); - TransportContext context = new TransportContext(conf, handler); - this.server = context.createServer(); + TransportContext context = new TransportContext(conf, new ExternalShuffleBlockHandler(conf)); + TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf, + new TestSecretKeyHolder("my-app-id", "secret")); + this.server = context.createServer(Arrays.asList(bootstrap)); } @After @@ -58,13 +60,13 @@ public void afterEach() { @Test public void testValid() throws IOException { - validate("my-app-id", "secret"); + validate("my-app-id", "secret", false); } @Test public void testBadAppId() { try { - validate("wrong-app-id", "secret"); + validate("wrong-app-id", "secret", false); } catch (Exception e) { assertTrue(e.getMessage(), e.getMessage().contains("Wrong appId!")); } @@ -73,16 +75,21 @@ public void testBadAppId() { @Test public void testBadSecret() { try { - validate("my-app-id", "bad-secret"); + validate("my-app-id", "bad-secret", false); } catch (Exception e) { assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response")); } } + @Test + public void testEncryption() throws IOException { + validate("my-app-id", "secret", true); + } + /** Creates an ExternalShuffleClient and attempts to register with the server. */ - private void validate(String appId, String secretKey) throws IOException { + private void validate(String appId, String secretKey, boolean encrypt) throws IOException { ExternalShuffleClient client = - new ExternalShuffleClient(conf, new TestSecretKeyHolder(appId, secretKey), true); + new ExternalShuffleClient(conf, new TestSecretKeyHolder(appId, secretKey), true, encrypt); client.init(appId); // Registration either succeeds or throws an exception. client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), "exec0", diff --git a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java index 63b21222e7b77..463f99ef3352d 100644 --- a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java +++ b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java @@ -17,9 +17,10 @@ package org.apache.spark.network.yarn; -import java.lang.Override; import java.nio.ByteBuffer; +import java.util.List; +import com.google.common.collect.Lists; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.yarn.api.records.ApplicationId; import org.apache.hadoop.yarn.api.records.ContainerId; @@ -32,10 +33,11 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.TransportContext; -import org.apache.spark.network.sasl.SaslRpcHandler; +import org.apache.spark.network.sasl.SaslServerBootstrap; import org.apache.spark.network.sasl.ShuffleSecretManager; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.server.TransportServerBootstrap; import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler; import org.apache.spark.network.util.TransportConf; import org.apache.spark.network.yarn.util.HadoopConfigProvider; @@ -103,16 +105,17 @@ protected void serviceInit(Configuration conf) { // special RPC handler that filters out unauthenticated fetch requests boolean authEnabled = conf.getBoolean(SPARK_AUTHENTICATE_KEY, DEFAULT_SPARK_AUTHENTICATE); blockHandler = new ExternalShuffleBlockHandler(transportConf); - RpcHandler rpcHandler = blockHandler; + + List bootstraps = Lists.newArrayList(); if (authEnabled) { secretManager = new ShuffleSecretManager(); - rpcHandler = new SaslRpcHandler(rpcHandler, secretManager); + bootstraps.add(new SaslServerBootstrap(transportConf, secretManager)); } int port = conf.getInt( SPARK_SHUFFLE_SERVICE_PORT_KEY, DEFAULT_SPARK_SHUFFLE_SERVICE_PORT); - TransportContext transportContext = new TransportContext(transportConf, rpcHandler); - shuffleServer = transportContext.createServer(port); + TransportContext transportContext = new TransportContext(transportConf, blockHandler); + shuffleServer = transportContext.createServer(port, bootstraps); String authEnabledString = authEnabled ? "enabled" : "not enabled"; logger.info("Started YARN shuffle service for Spark on port {}. " + "Authentication is {}.", port, authEnabledString); From b79aeb95b45ab4ae811039d452cf028d7b844132 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Fri, 1 May 2015 21:23:42 -0700 Subject: [PATCH 24/91] [SPARK-7317] [Shuffle] Expose shuffle handle Details in JIRA, in a nut-shell, all machinary for custom RDD's to leverage spark shuffle directly (without exposing impl details of shuffle) exists - except for this small piece. Exposing this will allow for custom dependencies to get a handle to ShuffleHandle - which they can then leverage on reduce side. Author: Mridul Muralidharan Closes #5857 from mridulm/expose_shuffle_handle and squashes the following commits: d8b6bd4 [Mridul Muralidharan] Expose ShuffleHandle --- .../main/scala/org/apache/spark/shuffle/ShuffleHandle.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleHandle.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleHandle.scala index 13c7115f88afa..e04c97fe61894 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleHandle.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleHandle.scala @@ -17,9 +17,12 @@ package org.apache.spark.shuffle +import org.apache.spark.annotation.DeveloperApi + /** * An opaque handle to a shuffle, used by a ShuffleManager to pass information about it to tasks. * * @param shuffleId ID of the shuffle */ -private[spark] abstract class ShuffleHandle(val shuffleId: Int) extends Serializable {} +@DeveloperApi +abstract class ShuffleHandle(val shuffleId: Int) extends Serializable {} From 2e0f3579f1fa7139c2e79bde656cbac049abbc33 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Fri, 1 May 2015 23:43:24 -0700 Subject: [PATCH 25/91] [SPARK-7242] added python api for freqItems in DataFrames The python api for DataFrame's plus addressed your comments from previous PR. rxin Author: Burak Yavuz Closes #5859 from brkyvz/df-freq-py2 and squashes the following commits: f9aa9ce [Burak Yavuz] addressed comments v0.1 4b25056 [Burak Yavuz] added python api for freqItems --- python/pyspark/sql/dataframe.py | 25 +++++++++++++++++++ python/pyspark/sql/tests.py | 7 ++++++ .../spark/sql/DataFrameStatFunctions.scala | 9 ++++--- 3 files changed, 38 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 5ff49cac5522b..e9fd17ed4ce94 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -889,6 +889,26 @@ def cov(self, col1, col2): raise ValueError("col2 should be a string.") return self._jdf.stat().cov(col1, col2) + def freqItems(self, cols, support=None): + """ + Finding frequent items for columns, possibly with false positives. Using the + frequent element count algorithm described in + "http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou". + :func:`DataFrame.freqItems` and :func:`DataFrameStatFunctions.freqItems` are aliases. + + :param cols: Names of the columns to calculate frequent items for as a list or tuple of + strings. + :param support: The frequency with which to consider an item 'frequent'. Default is 1%. + The support must be greater than 1e-4. + """ + if isinstance(cols, tuple): + cols = list(cols) + if not isinstance(cols, list): + raise ValueError("cols must be a list or tuple of column names as strings.") + if not support: + support = 0.01 + return DataFrame(self._jdf.stat().freqItems(_to_seq(self._sc, cols), support), self.sql_ctx) + @ignore_unicode_prefix def withColumn(self, colName, col): """Returns a new :class:`DataFrame` by adding a column. @@ -1344,6 +1364,11 @@ def cov(self, col1, col2): cov.__doc__ = DataFrame.cov.__doc__ + def freqItems(self, cols, support=None): + return self.df.freqItems(cols, support) + + freqItems.__doc__ = DataFrame.freqItems.__doc__ + def _test(): import doctest diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 44c8b6a1aac13..613efc0ac029d 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -375,6 +375,13 @@ def test_column_select(self): self.assertEqual(self.testData, df.select(df.key, df.value).collect()) self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect()) + def test_freqItems(self): + vals = [Row(a=1, b=-2.0) if i % 2 == 0 else Row(a=i, b=i * 1.0) for i in range(100)] + df = self.sc.parallelize(vals).toDF() + items = df.stat.freqItems(("a", "b"), 0.4).collect()[0] + self.assertTrue(1 in items[0]) + self.assertTrue(-2.0 in items[1]) + def test_aggregator(self): df = self.df g = df.groupBy() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 23652aeb7c7bc..e8fa82947759b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -43,7 +43,10 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { } /** - * Runs `freqItems` with a default `support` of 1%. + * Finding frequent items for columns, possibly with false positives. Using the + * frequent element count algorithm described in + * [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]]. + * Uses a `default` support of 1%. * * @param cols the names of the columns to search frequent items in. * @return A Local DataFrame with the Array of frequent items for each column. @@ -55,14 +58,14 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { /** * Python friendly implementation for `freqItems` */ - def freqItems(cols: List[String], support: Double): DataFrame = { + def freqItems(cols: Seq[String], support: Double): DataFrame = { FrequentItems.singlePassFreqItems(df, cols, support) } /** * Python friendly implementation for `freqItems` with a default `support` of 1%. */ - def freqItems(cols: List[String]): DataFrame = { + def freqItems(cols: Seq[String]): DataFrame = { FrequentItems.singlePassFreqItems(df, cols, 0.01) } From 7394e7adeb03df159978f1d10061d9ec6a913968 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 1 May 2015 23:57:58 -0700 Subject: [PATCH 26/91] [SPARK-7120] [SPARK-7121] Closure cleaner nesting + documentation + tests Note: ~600 lines of this is test code, and ~100 lines documentation. **[SPARK-7121]** ClosureCleaner does not handle nested closures properly. For instance, in SparkContext, I tried to do the following: ``` def scope[T](body: => T): T = body // no-op def myCoolMethod(path: String): RDD[String] = scope { parallelize(1 to 10).map { _ => path } } ``` and I got an exception complaining that SparkContext is not serializable. The issue here is that the inner closure is getting its path from the outer closure (the scope), but the outer closure references the SparkContext object itself to get the `parallelize` method. Note, however, that the inner closure doesn't actually need the SparkContext; it just needs a field from the outer closure. If we modify ClosureCleaner to clean the outer closure recursively using only the fields accessed by the inner closure, then we can serialize the inner closure. **[SPARK-7120]** Also, the other thing is that this file is one of the least understood, partly because it is very low level and is written a long time ago. This patch attempts to change that by adding the missing documentation. This is blocking my effort on a separate task #5729. Author: Andrew Or Closes #5685 from andrewor14/closure-cleaner and squashes the following commits: cd46230 [Andrew Or] Revert a small change that affected streaming 0bbe77f [Andrew Or] Fix style ea874bc [Andrew Or] Fix tests 26c5072 [Andrew Or] Address comments 16fbcfd [Andrew Or] Merge branch 'master' of github.com:apache/spark into closure-cleaner 26c7aba [Andrew Or] Revert "In sc.runJob, actually clean the inner closure" 6f75784 [Andrew Or] Revert "Guard against NPE if CC is used outside of an application" e909a42 [Andrew Or] Guard against NPE if CC is used outside of an application 3998168 [Andrew Or] In sc.runJob, actually clean the inner closure 9187066 [Andrew Or] Merge branch 'master' of github.com:apache/spark into closure-cleaner d889950 [Andrew Or] Revert "Bypass SerializationDebugger for now (SPARK-7180)" 9419efe [Andrew Or] Bypass SerializationDebugger for now (SPARK-7180) 6d4d3f1 [Andrew Or] Fix scala style? 4aab379 [Andrew Or] Merge branch 'master' of github.com:apache/spark into closure-cleaner e45e904 [Andrew Or] More minor updates (wording, renaming etc.) 8b71cdb [Andrew Or] Update a few comments eb127e5 [Andrew Or] Use private method tester for a few things a3aa465 [Andrew Or] Add more tests for individual closure cleaner operations e672170 [Andrew Or] Guard against potential infinite cycles in method visitor 6d36f38 [Andrew Or] Fix closure cleaner visibility 2106f12 [Andrew Or] Merge branch 'master' of github.com:apache/spark into closure-cleaner 263593d [Andrew Or] Finalize tests 06fd668 [Andrew Or] Make closure cleaning idempotent a4866e3 [Andrew Or] Add tests (still WIP) 438c68f [Andrew Or] Minor changes 2390a60 [Andrew Or] Feature flag this new behavior 86f7823 [Andrew Or] Implement transitive cleaning + add missing documentation --- .../apache/spark/util/ClosureCleaner.scala | 305 ++++++++-- .../spark/util/ClosureCleanerSuite.scala | 13 +- .../spark/util/ClosureCleanerSuite2.scala | 571 ++++++++++++++++++ 3 files changed, 831 insertions(+), 58 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index e3f52f6ff1e63..4ac0382d80815 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -19,17 +19,20 @@ package org.apache.spark.util import java.io.{ByteArrayInputStream, ByteArrayOutputStream} -import scala.collection.mutable.Map -import scala.collection.mutable.Set +import scala.collection.mutable.{Map, Set} import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type} import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._ import org.apache.spark.{Logging, SparkEnv, SparkException} +/** + * A cleaner that renders closures serializable if they can be done so safely. + */ private[spark] object ClosureCleaner extends Logging { + // Get an ASM class reader for a given class from the JAR that loaded it - private def getClassReader(cls: Class[_]): ClassReader = { + private[util] def getClassReader(cls: Class[_]): ClassReader = { // Copy data over, before delegating to ClassReader - else we can run out of open file handles. val className = cls.getName.replaceFirst("^.*\\.", "") + ".class" val resourceStream = cls.getResourceAsStream(className) @@ -55,10 +58,14 @@ private[spark] object ClosureCleaner extends Logging { private def getOuterClasses(obj: AnyRef): List[Class[_]] = { for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") { f.setAccessible(true) - if (isClosure(f.getType)) { - return f.getType :: getOuterClasses(f.get(obj)) - } else { - return f.getType :: Nil // Stop at the first $outer that is not a closure + val outer = f.get(obj) + // The outer pointer may be null if we have cleaned this closure before + if (outer != null) { + if (isClosure(f.getType)) { + return f.getType :: getOuterClasses(outer) + } else { + return f.getType :: Nil // Stop at the first $outer that is not a closure + } } } Nil @@ -68,16 +75,23 @@ private[spark] object ClosureCleaner extends Logging { private def getOuterObjects(obj: AnyRef): List[AnyRef] = { for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") { f.setAccessible(true) - if (isClosure(f.getType)) { - return f.get(obj) :: getOuterObjects(f.get(obj)) - } else { - return f.get(obj) :: Nil // Stop at the first $outer that is not a closure + val outer = f.get(obj) + // The outer pointer may be null if we have cleaned this closure before + if (outer != null) { + if (isClosure(f.getType)) { + return outer :: getOuterObjects(outer) + } else { + return outer :: Nil // Stop at the first $outer that is not a closure + } } } Nil } - private def getInnerClasses(obj: AnyRef): List[Class[_]] = { + /** + * Return a list of classes that represent closures enclosed in the given closure object. + */ + private def getInnerClosureClasses(obj: AnyRef): List[Class[_]] = { val seen = Set[Class[_]](obj.getClass) var stack = List[Class[_]](obj.getClass) while (!stack.isEmpty) { @@ -90,7 +104,7 @@ private[spark] object ClosureCleaner extends Logging { stack = cls :: stack } } - return (seen - obj.getClass).toList + (seen - obj.getClass).toList } private def createNullValue(cls: Class[_]): AnyRef = { @@ -101,21 +115,124 @@ private[spark] object ClosureCleaner extends Logging { } } - def clean(func: AnyRef, checkSerializable: Boolean = true) { + /** + * Clean the given closure in place. + * + * More specifically, this renders the given closure serializable as long as it does not + * explicitly reference unserializable objects. + * + * @param closure the closure to clean + * @param checkSerializable whether to verify that the closure is serializable after cleaning + * @param cleanTransitively whether to clean enclosing closures transitively + */ + def clean( + closure: AnyRef, + checkSerializable: Boolean = true, + cleanTransitively: Boolean = true): Unit = { + clean(closure, checkSerializable, cleanTransitively, Map.empty) + } + + /** + * Helper method to clean the given closure in place. + * + * The mechanism is to traverse the hierarchy of enclosing closures and null out any + * references along the way that are not actually used by the starting closure, but are + * nevertheless included in the compiled anonymous classes. Note that it is unsafe to + * simply mutate the enclosing closures in place, as other code paths may depend on them. + * Instead, we clone each enclosing closure and set the parent pointers accordingly. + * + * By default, closures are cleaned transitively. This means we detect whether enclosing + * objects are actually referenced by the starting one, either directly or transitively, + * and, if not, sever these closures from the hierarchy. In other words, in addition to + * nulling out unused field references, we also null out any parent pointers that refer + * to enclosing objects not actually needed by the starting closure. We determine + * transitivity by tracing through the tree of all methods ultimately invoked by the + * inner closure and record all the fields referenced in the process. + * + * For instance, transitive cleaning is necessary in the following scenario: + * + * class SomethingNotSerializable { + * def someValue = 1 + * def scope(name: String)(body: => Unit) = body + * def someMethod(): Unit = scope("one") { + * def x = someValue + * def y = 2 + * scope("two") { println(y + 1) } + * } + * } + * + * In this example, scope "two" is not serializable because it references scope "one", which + * references SomethingNotSerializable. Note that, however, the body of scope "two" does not + * actually depend on SomethingNotSerializable. This means we can safely null out the parent + * pointer of a cloned scope "one" and set it the parent of scope "two", such that scope "two" + * no longer references SomethingNotSerializable transitively. + * + * @param func the starting closure to clean + * @param checkSerializable whether to verify that the closure is serializable after cleaning + * @param cleanTransitively whether to clean enclosing closures transitively + * @param accessedFields a map from a class to a set of its fields that are accessed by + * the starting closure + */ + private def clean( + func: AnyRef, + checkSerializable: Boolean, + cleanTransitively: Boolean, + accessedFields: Map[Class[_], Set[String]]): Unit = { + + // TODO: clean all inner closures first. This requires us to find the inner objects. // TODO: cache outerClasses / innerClasses / accessedFields + + if (func == null) { + return + } + + logDebug(s"+++ Cleaning closure $func (${func.getClass.getName}}) +++") + + // A list of classes that represents closures enclosed in the given one + val innerClasses = getInnerClosureClasses(func) + + // A list of enclosing objects and their respective classes, from innermost to outermost + // An outer object at a given index is of type outer class at the same index val outerClasses = getOuterClasses(func) - val innerClasses = getInnerClasses(func) val outerObjects = getOuterObjects(func) - val accessedFields = Map[Class[_], Set[String]]() - + // For logging purposes only + val declaredFields = func.getClass.getDeclaredFields + val declaredMethods = func.getClass.getDeclaredMethods + + logDebug(" + declared fields: " + declaredFields.size) + declaredFields.foreach { f => logDebug(" " + f) } + logDebug(" + declared methods: " + declaredMethods.size) + declaredMethods.foreach { m => logDebug(" " + m) } + logDebug(" + inner classes: " + innerClasses.size) + innerClasses.foreach { c => logDebug(" " + c.getName) } + logDebug(" + outer classes: " + outerClasses.size) + outerClasses.foreach { c => logDebug(" " + c.getName) } + logDebug(" + outer objects: " + outerObjects.size) + outerObjects.foreach { o => logDebug(" " + o) } + + // Fail fast if we detect return statements in closures getClassReader(func.getClass).accept(new ReturnStatementFinder(), 0) - - for (cls <- outerClasses) - accessedFields(cls) = Set[String]() - for (cls <- func.getClass :: innerClasses) - getClassReader(cls).accept(new FieldAccessFinder(accessedFields), 0) - // logInfo("accessedFields: " + accessedFields) + + // If accessed fields is not populated yet, we assume that + // the closure we are trying to clean is the starting one + if (accessedFields.isEmpty) { + logDebug(s" + populating accessed fields because this is the starting closure") + // Initialize accessed fields with the outer classes first + // This step is needed to associate the fields to the correct classes later + for (cls <- outerClasses) { + accessedFields(cls) = Set[String]() + } + // Populate accessed fields by visiting all fields and methods accessed by this and + // all of its inner closures. If transitive cleaning is enabled, this may recursively + // visits methods that belong to other classes in search of transitively referenced fields. + for (cls <- func.getClass :: innerClasses) { + getClassReader(cls).accept(new FieldAccessFinder(accessedFields, cleanTransitively), 0) + } + } + + logDebug(s" + fields accessed by starting closure: " + accessedFields.size) + accessedFields.foreach { f => logDebug(" " + f) } val inInterpreter = { try { @@ -126,34 +243,68 @@ private[spark] object ClosureCleaner extends Logging { } } + // List of outer (class, object) pairs, ordered from outermost to innermost + // Note that all outer objects but the outermost one (first one in this list) must be closures var outerPairs: List[(Class[_], AnyRef)] = (outerClasses zip outerObjects).reverse - var outer: AnyRef = null + var parent: AnyRef = null if (outerPairs.size > 0 && !isClosure(outerPairs.head._1)) { // The closure is ultimately nested inside a class; keep the object of that // class without cloning it since we don't want to clone the user's objects. - outer = outerPairs.head._2 + // Note that we still need to keep around the outermost object itself because + // we need it to clone its child closure later (see below). + logDebug(s" + outermost object is not a closure, so do not clone it: ${outerPairs.head}") + parent = outerPairs.head._2 // e.g. SparkContext outerPairs = outerPairs.tail + } else if (outerPairs.size > 0) { + logDebug(s" + outermost object is a closure, so we just keep it: ${outerPairs.head}") + } else { + logDebug(" + there are no enclosing objects!") } + // Clone the closure objects themselves, nulling out any fields that are not // used in the closure we're working on or any of its inner closures. for ((cls, obj) <- outerPairs) { - outer = instantiateClass(cls, outer, inInterpreter) + logDebug(s" + cloning the object $obj of class ${cls.getName}") + // We null out these unused references by cloning each object and then filling in all + // required fields from the original object. We need the parent here because the Java + // language specification requires the first constructor parameter of any closure to be + // its enclosing object. + val clone = instantiateClass(cls, parent, inInterpreter) for (fieldName <- accessedFields(cls)) { val field = cls.getDeclaredField(fieldName) field.setAccessible(true) val value = field.get(obj) - // logInfo("1: Setting " + fieldName + " on " + cls + " to " + value); - field.set(outer, value) + field.set(clone, value) + } + // If transitive cleaning is enabled, we recursively clean any enclosing closure using + // the already populated accessed fields map of the starting closure + if (cleanTransitively && isClosure(clone.getClass)) { + logDebug(s" + cleaning cloned closure $clone recursively (${cls.getName})") + // No need to check serializable here for the outer closures because we're + // only interested in the serializability of the starting closure + clean(clone, checkSerializable = false, cleanTransitively, accessedFields) } + parent = clone } - if (outer != null) { - // logInfo("2: Setting $outer on " + func.getClass + " to " + outer); + // Update the parent pointer ($outer) of this closure + if (parent != null) { val field = func.getClass.getDeclaredField("$outer") field.setAccessible(true) - field.set(func, outer) + // If the starting closure doesn't actually need our enclosing object, then just null it out + if (accessedFields.contains(func.getClass) && + !accessedFields(func.getClass).contains("$outer")) { + logDebug(s" + the starting closure doesn't actually need $parent, so we null it out") + field.set(func, null) + } else { + // Update this closure's parent pointer to point to our enclosing object, + // which could either be a cloned closure or the original user object + field.set(func, parent) + } } - + + logDebug(s" +++ closure $func (${func.getClass.getName}) is now cleaned +++") + if (checkSerializable) { ensureSerializable(func) } @@ -167,15 +318,17 @@ private[spark] object ClosureCleaner extends Logging { } } - private def instantiateClass(cls: Class[_], outer: AnyRef, inInterpreter: Boolean): AnyRef = { - // logInfo("Creating a " + cls + " with outer = " + outer) + private def instantiateClass( + cls: Class[_], + enclosingObject: AnyRef, + inInterpreter: Boolean): AnyRef = { if (!inInterpreter) { // This is a bona fide closure class, whose constructor has no effects // other than to set its fields, so use its constructor val cons = cls.getConstructors()(0) val params = cons.getParameterTypes.map(createNullValue).toArray - if (outer != null) { - params(0) = outer // First param is always outer object + if (enclosingObject != null) { + params(0) = enclosingObject // First param is always enclosing object } return cons.newInstance(params: _*).asInstanceOf[AnyRef] } else { @@ -184,19 +337,17 @@ private[spark] object ClosureCleaner extends Logging { val parentCtor = classOf[java.lang.Object].getDeclaredConstructor() val newCtor = rf.newConstructorForSerialization(cls, parentCtor) val obj = newCtor.newInstance().asInstanceOf[AnyRef] - if (outer != null) { - // logInfo("3: Setting $outer on " + cls + " to " + outer); + if (enclosingObject != null) { val field = cls.getDeclaredField("$outer") field.setAccessible(true) - field.set(obj, outer) + field.set(obj, enclosingObject) } obj } } } -private[spark] -class ReturnStatementFinder extends ClassVisitor(ASM4) { +private class ReturnStatementFinder extends ClassVisitor(ASM4) { override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { if (name.contains("apply")) { @@ -213,26 +364,65 @@ class ReturnStatementFinder extends ClassVisitor(ASM4) { } } -private[spark] -class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends ClassVisitor(ASM4) { - override def visitMethod(access: Int, name: String, desc: String, - sig: String, exceptions: Array[String]): MethodVisitor = { +/** Helper class to identify a method. */ +private case class MethodIdentifier[T](cls: Class[T], name: String, desc: String) + +/** + * Find the fields accessed by a given class. + * + * The resulting fields are stored in the mutable map passed in through the constructor. + * This map is assumed to have its keys already populated with the classes of interest. + * + * @param fields the mutable map that stores the fields to return + * @param findTransitively if true, find fields indirectly referenced through method calls + * @param specificMethod if not empty, visit only this specific method + * @param visitedMethods a set of visited methods to avoid cycles + */ +private[util] class FieldAccessFinder( + fields: Map[Class[_], Set[String]], + findTransitively: Boolean, + specificMethod: Option[MethodIdentifier[_]] = None, + visitedMethods: Set[MethodIdentifier[_]] = Set.empty) + extends ClassVisitor(ASM4) { + + override def visitMethod( + access: Int, + name: String, + desc: String, + sig: String, + exceptions: Array[String]): MethodVisitor = { + + // If we are told to visit only a certain method and this is not the one, ignore it + if (specificMethod.isDefined && + (specificMethod.get.name != name || specificMethod.get.desc != desc)) { + return null + } + new MethodVisitor(ASM4) { override def visitFieldInsn(op: Int, owner: String, name: String, desc: String) { if (op == GETFIELD) { - for (cl <- output.keys if cl.getName == owner.replace('/', '.')) { - output(cl) += name + for (cl <- fields.keys if cl.getName == owner.replace('/', '.')) { + fields(cl) += name } } } - override def visitMethodInsn(op: Int, owner: String, name: String, - desc: String) { - // Check for calls a getter method for a variable in an interpreter wrapper object. - // This means that the corresponding field will be accessed, so we should save it. - if (op == INVOKEVIRTUAL && owner.endsWith("$iwC") && !name.endsWith("$outer")) { - for (cl <- output.keys if cl.getName == owner.replace('/', '.')) { - output(cl) += name + override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) { + for (cl <- fields.keys if cl.getName == owner.replace('/', '.')) { + // Check for calls a getter method for a variable in an interpreter wrapper object. + // This means that the corresponding field will be accessed, so we should save it. + if (op == INVOKEVIRTUAL && owner.endsWith("$iwC") && !name.endsWith("$outer")) { + fields(cl) += name + } + // Optionally visit other methods to find fields that are transitively referenced + if (findTransitively) { + val m = MethodIdentifier(cl, name, desc) + if (!visitedMethods.contains(m)) { + // Keep track of visited methods to avoid potential infinite cycles + visitedMethods += m + ClosureCleaner.getClassReader(cl).accept( + new FieldAccessFinder(fields, findTransitively, Some(m), visitedMethods), 0) + } } } } @@ -240,9 +430,14 @@ class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends ClassVisitor } } -private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM4) { +private class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM4) { var myName: String = null + // TODO: Recursively find inner closures that we indirectly reference, e.g. + // val closure1 = () = { () => 1 } + // val closure2 = () => { (1 to 5).map(closure1) } + // The second closure technically has two inner closures, but this finder only finds one + override def visit(version: Int, access: Int, name: String, sig: String, superName: String, interfaces: Array[String]) { myName = name diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala index c47162779bbba..ff1bfe0774a2f 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala @@ -50,7 +50,7 @@ class ClosureCleanerSuite extends FunSuite { val obj = new TestClassWithNesting(1) assert(obj.run() === 96) // 4 * (1+2+3+4) + 4 * (1+2+3+4) + 16 * 1 } - + test("toplevel return statements in closures are identified at cleaning time") { val ex = intercept[SparkException] { TestObjectWithBogusReturns.run() @@ -61,13 +61,20 @@ class ClosureCleanerSuite extends FunSuite { test("return statements from named functions nested in closures don't raise exceptions") { val result = TestObjectWithNestedReturns.run() - assert(result == 1) + assert(result === 1) } } // A non-serializable class we create in closures to make sure that we aren't // keeping references to unneeded variables from our outer closures. -class NonSerializable {} +class NonSerializable(val id: Int = -1) { + override def equals(other: Any): Boolean = { + other match { + case o: NonSerializable => id == o.id + case _ => false + } + } +} object TestObject { def run(): Int = { diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala new file mode 100644 index 0000000000000..59456790e89f0 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala @@ -0,0 +1,571 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.io.NotSerializableException + +import scala.collection.mutable + +import org.scalatest.{BeforeAndAfterAll, FunSuite, PrivateMethodTester} + +import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.serializer.SerializerInstance + +/** + * Another test suite for the closure cleaner that is finer-grained. + * For tests involving end-to-end Spark jobs, see {{ClosureCleanerSuite}}. + */ +class ClosureCleanerSuite2 extends FunSuite with BeforeAndAfterAll with PrivateMethodTester { + + // Start a SparkContext so that the closure serializer is accessible + // We do not actually use this explicitly otherwise + private var sc: SparkContext = null + private var closureSerializer: SerializerInstance = null + + override def beforeAll(): Unit = { + sc = new SparkContext("local", "test") + closureSerializer = sc.env.closureSerializer.newInstance() + } + + override def afterAll(): Unit = { + sc.stop() + sc = null + closureSerializer = null + } + + // Some fields and methods to reference in inner closures later + private val someSerializableValue = 1 + private val someNonSerializableValue = new NonSerializable + private def someSerializableMethod() = 1 + private def someNonSerializableMethod() = new NonSerializable + + /** Assert that the given closure is serializable (or not). */ + private def assertSerializable(closure: AnyRef, serializable: Boolean): Unit = { + if (serializable) { + closureSerializer.serialize(closure) + } else { + intercept[NotSerializableException] { + closureSerializer.serialize(closure) + } + } + } + + /** + * Helper method for testing whether closure cleaning works as expected. + * This cleans the given closure twice, with and without transitive cleaning. + * + * @param closure closure to test cleaning with + * @param serializableBefore if true, verify that the closure is serializable + * before cleaning, otherwise assert that it is not + * @param serializableAfter if true, assert that the closure is serializable + * after cleaning otherwise assert that it is not + */ + private def verifyCleaning( + closure: AnyRef, + serializableBefore: Boolean, + serializableAfter: Boolean): Unit = { + verifyCleaning(closure, serializableBefore, serializableAfter, transitive = true) + verifyCleaning(closure, serializableBefore, serializableAfter, transitive = false) + } + + /** Helper method for testing whether closure cleaning works as expected. */ + private def verifyCleaning( + closure: AnyRef, + serializableBefore: Boolean, + serializableAfter: Boolean, + transitive: Boolean): Unit = { + assertSerializable(closure, serializableBefore) + // If the resulting closure is not serializable even after + // cleaning, we expect ClosureCleaner to throw a SparkException + if (serializableAfter) { + ClosureCleaner.clean(closure, checkSerializable = true, transitive) + } else { + intercept[SparkException] { + ClosureCleaner.clean(closure, checkSerializable = true, transitive) + } + } + assertSerializable(closure, serializableAfter) + } + + /** + * Return the fields accessed by the given closure by class. + * This also optionally finds the fields transitively referenced through methods invocations. + */ + private def findAccessedFields( + closure: AnyRef, + outerClasses: Seq[Class[_]], + findTransitively: Boolean): Map[Class[_], Set[String]] = { + val fields = new mutable.HashMap[Class[_], mutable.Set[String]] + outerClasses.foreach { c => fields(c) = new mutable.HashSet[String] } + ClosureCleaner.getClassReader(closure.getClass) + .accept(new FieldAccessFinder(fields, findTransitively), 0) + fields.mapValues(_.toSet).toMap + } + + // Accessors for private methods + private val _isClosure = PrivateMethod[Boolean]('isClosure) + private val _getInnerClosureClasses = PrivateMethod[List[Class[_]]]('getInnerClosureClasses) + private val _getOuterClasses = PrivateMethod[List[Class[_]]]('getOuterClasses) + private val _getOuterObjects = PrivateMethod[List[AnyRef]]('getOuterObjects) + + private def isClosure(obj: AnyRef): Boolean = { + ClosureCleaner invokePrivate _isClosure(obj) + } + + private def getInnerClosureClasses(closure: AnyRef): List[Class[_]] = { + ClosureCleaner invokePrivate _getInnerClosureClasses(closure) + } + + private def getOuterClasses(closure: AnyRef): List[Class[_]] = { + ClosureCleaner invokePrivate _getOuterClasses(closure) + } + + private def getOuterObjects(closure: AnyRef): List[AnyRef] = { + ClosureCleaner invokePrivate _getOuterObjects(closure) + } + + test("get inner closure classes") { + val closure1 = () => 1 + val closure2 = () => { () => 1 } + val closure3 = (i: Int) => { + (1 to i).map { x => x + 1 }.filter { x => x > 5 } + } + val closure4 = (j: Int) => { + (1 to j).flatMap { x => + (1 to x).flatMap { y => + (1 to y).map { z => z + 1 } + } + } + } + val inner1 = getInnerClosureClasses(closure1) + val inner2 = getInnerClosureClasses(closure2) + val inner3 = getInnerClosureClasses(closure3) + val inner4 = getInnerClosureClasses(closure4) + assert(inner1.isEmpty) + assert(inner2.size === 1) + assert(inner3.size === 2) + assert(inner4.size === 3) + assert(inner2.forall(isClosure)) + assert(inner3.forall(isClosure)) + assert(inner4.forall(isClosure)) + } + + test("get outer classes and objects") { + val localValue = someSerializableValue + val closure1 = () => 1 + val closure2 = () => localValue + val closure3 = () => someSerializableValue + val closure4 = () => someSerializableMethod() + val outerClasses1 = getOuterClasses(closure1) + val outerClasses2 = getOuterClasses(closure2) + val outerClasses3 = getOuterClasses(closure3) + val outerClasses4 = getOuterClasses(closure4) + val outerObjects1 = getOuterObjects(closure1) + val outerObjects2 = getOuterObjects(closure2) + val outerObjects3 = getOuterObjects(closure3) + val outerObjects4 = getOuterObjects(closure4) + + // The classes and objects should have the same size + assert(outerClasses1.size === outerObjects1.size) + assert(outerClasses2.size === outerObjects2.size) + assert(outerClasses3.size === outerObjects3.size) + assert(outerClasses4.size === outerObjects4.size) + + // These do not have $outer pointers because they reference only local variables + assert(outerClasses1.isEmpty) + assert(outerClasses2.isEmpty) + + // These closures do have $outer pointers because they ultimately reference `this` + // The first $outer pointer refers to the closure defines this test (see FunSuite#test) + // The second $outer pointer refers to ClosureCleanerSuite2 + assert(outerClasses3.size === 2) + assert(outerClasses4.size === 2) + assert(isClosure(outerClasses3(0))) + assert(isClosure(outerClasses4(0))) + assert(outerClasses3(0) === outerClasses4(0)) // part of the same "FunSuite#test" scope + assert(outerClasses3(1) === this.getClass) + assert(outerClasses4(1) === this.getClass) + assert(outerObjects3(1) === this) + assert(outerObjects4(1) === this) + } + + test("get outer classes and objects with nesting") { + val localValue = someSerializableValue + + val test1 = () => { + val x = 1 + val closure1 = () => 1 + val closure2 = () => x + val outerClasses1 = getOuterClasses(closure1) + val outerClasses2 = getOuterClasses(closure2) + val outerObjects1 = getOuterObjects(closure1) + val outerObjects2 = getOuterObjects(closure2) + assert(outerClasses1.size === outerObjects1.size) + assert(outerClasses2.size === outerObjects2.size) + // These inner closures only reference local variables, and so do not have $outer pointers + assert(outerClasses1.isEmpty) + assert(outerClasses2.isEmpty) + } + + val test2 = () => { + def y = 1 + val closure1 = () => 1 + val closure2 = () => y + val closure3 = () => localValue + val outerClasses1 = getOuterClasses(closure1) + val outerClasses2 = getOuterClasses(closure2) + val outerClasses3 = getOuterClasses(closure3) + val outerObjects1 = getOuterObjects(closure1) + val outerObjects2 = getOuterObjects(closure2) + val outerObjects3 = getOuterObjects(closure3) + assert(outerClasses1.size === outerObjects1.size) + assert(outerClasses2.size === outerObjects2.size) + assert(outerClasses3.size === outerObjects3.size) + // Same as above, this closure only references local variables + assert(outerClasses1.isEmpty) + // This closure references the "test2" scope because it needs to find the method `y` + // Scope hierarchy: "test2" < "FunSuite#test" < ClosureCleanerSuite2 + assert(outerClasses2.size === 3) + // This closure references the "test2" scope because it needs to find the `localValue` + // defined outside of this scope + assert(outerClasses3.size === 3) + assert(isClosure(outerClasses2(0))) + assert(isClosure(outerClasses3(0))) + assert(isClosure(outerClasses2(1))) + assert(isClosure(outerClasses3(1))) + assert(outerClasses2(0) === outerClasses3(0)) // part of the same "test2" scope + assert(outerClasses2(1) === outerClasses3(1)) // part of the same "FunSuite#test" scope + assert(outerClasses2(2) === this.getClass) + assert(outerClasses3(2) === this.getClass) + assert(outerObjects2(2) === this) + assert(outerObjects3(2) === this) + } + + test1() + test2() + } + + test("find accessed fields") { + val localValue = someSerializableValue + val closure1 = () => 1 + val closure2 = () => localValue + val closure3 = () => someSerializableValue + val outerClasses1 = getOuterClasses(closure1) + val outerClasses2 = getOuterClasses(closure2) + val outerClasses3 = getOuterClasses(closure3) + + val fields1 = findAccessedFields(closure1, outerClasses1, findTransitively = false) + val fields2 = findAccessedFields(closure2, outerClasses2, findTransitively = false) + val fields3 = findAccessedFields(closure3, outerClasses3, findTransitively = false) + assert(fields1.isEmpty) + assert(fields2.isEmpty) + assert(fields3.size === 2) + // This corresponds to the "FunSuite#test" closure. This is empty because the + // `someSerializableValue` belongs to its parent (i.e. ClosureCleanerSuite2). + assert(fields3(outerClasses3(0)).isEmpty) + // This corresponds to the ClosureCleanerSuite2. This is also empty, however, + // because accessing a `ClosureCleanerSuite2#someSerializableValue` actually involves a + // method call. Since we do not find fields transitively, we will not recursively trace + // through the fields referenced by this method. + assert(fields3(outerClasses3(1)).isEmpty) + + val fields1t = findAccessedFields(closure1, outerClasses1, findTransitively = true) + val fields2t = findAccessedFields(closure2, outerClasses2, findTransitively = true) + val fields3t = findAccessedFields(closure3, outerClasses3, findTransitively = true) + assert(fields1t.isEmpty) + assert(fields2t.isEmpty) + assert(fields3t.size === 2) + // Because we find fields transitively now, we are able to detect that we need the + // $outer pointer to get the field from the ClosureCleanerSuite2 + assert(fields3t(outerClasses3(0)).size === 1) + assert(fields3t(outerClasses3(0)).head === "$outer") + assert(fields3t(outerClasses3(1)).size === 1) + assert(fields3t(outerClasses3(1)).head.contains("someSerializableValue")) + } + + test("find accessed fields with nesting") { + val localValue = someSerializableValue + + val test1 = () => { + def a = localValue + 1 + val closure1 = () => 1 + val closure2 = () => a + val closure3 = () => localValue + val closure4 = () => someSerializableValue + val outerClasses1 = getOuterClasses(closure1) + val outerClasses2 = getOuterClasses(closure2) + val outerClasses3 = getOuterClasses(closure3) + val outerClasses4 = getOuterClasses(closure4) + + // First, find only fields accessed directly, not transitively, by these closures + val fields1 = findAccessedFields(closure1, outerClasses1, findTransitively = false) + val fields2 = findAccessedFields(closure2, outerClasses2, findTransitively = false) + val fields3 = findAccessedFields(closure3, outerClasses3, findTransitively = false) + val fields4 = findAccessedFields(closure4, outerClasses4, findTransitively = false) + assert(fields1.isEmpty) + // Note that the size here represents the number of outer classes, not the number of fields + // "test1" < parameter of "FunSuite#test" < ClosureCleanerSuite2 + assert(fields2.size === 3) + // Since we do not find fields transitively here, we do not look into what `def a` references + assert(fields2(outerClasses2(0)).isEmpty) // This corresponds to the "test1" scope + assert(fields2(outerClasses2(1)).isEmpty) // This corresponds to the "FunSuite#test" scope + assert(fields2(outerClasses2(2)).isEmpty) // This corresponds to the ClosureCleanerSuite2 + assert(fields3.size === 3) + // Note that `localValue` is a field of the "test1" scope because `def a` references it, + // but NOT a field of the "FunSuite#test" scope because it is only a local variable there + assert(fields3(outerClasses3(0)).size === 1) + assert(fields3(outerClasses3(0)).head.contains("localValue")) + assert(fields3(outerClasses3(1)).isEmpty) + assert(fields3(outerClasses3(2)).isEmpty) + assert(fields4.size === 3) + // Because `val someSerializableValue` is an instance variable, even an explicit reference + // here actually involves a method call to access the underlying value of the variable. + // Because we are not finding fields transitively here, we do not consider the fields + // accessed by this "method" (i.e. the val's accessor). + assert(fields4(outerClasses4(0)).isEmpty) + assert(fields4(outerClasses4(1)).isEmpty) + assert(fields4(outerClasses4(2)).isEmpty) + + // Now do the same, but find fields that the closures transitively reference + val fields1t = findAccessedFields(closure1, outerClasses1, findTransitively = true) + val fields2t = findAccessedFields(closure2, outerClasses2, findTransitively = true) + val fields3t = findAccessedFields(closure3, outerClasses3, findTransitively = true) + val fields4t = findAccessedFields(closure4, outerClasses4, findTransitively = true) + assert(fields1t.isEmpty) + assert(fields2t.size === 3) + assert(fields2t(outerClasses2(0)).size === 1) // `def a` references `localValue` + assert(fields2t(outerClasses2(0)).head.contains("localValue")) + assert(fields2t(outerClasses2(1)).isEmpty) + assert(fields2t(outerClasses2(2)).isEmpty) + assert(fields3t.size === 3) + assert(fields3t(outerClasses3(0)).size === 1) // as before + assert(fields3t(outerClasses3(0)).head.contains("localValue")) + assert(fields3t(outerClasses3(1)).isEmpty) + assert(fields3t(outerClasses3(2)).isEmpty) + assert(fields4t.size === 3) + // Through a series of method calls, we are able to detect that we ultimately access + // ClosureCleanerSuite2's field `someSerializableValue`. Along the way, we also accessed + // a few $outer parent pointers to get to the outermost object. + assert(fields4t(outerClasses4(0)) === Set("$outer")) + assert(fields4t(outerClasses4(1)) === Set("$outer")) + assert(fields4t(outerClasses4(2)).size === 1) + assert(fields4t(outerClasses4(2)).head.contains("someSerializableValue")) + } + + test1() + } + + test("clean basic serializable closures") { + val localValue = someSerializableValue + val closure1 = () => 1 + val closure2 = () => Array[String]("a", "b", "c") + val closure3 = (s: String, arr: Array[Long]) => s + arr.mkString(", ") + val closure4 = () => localValue + val closure5 = () => new NonSerializable(5) // we're just serializing the class information + val closure1r = closure1() + val closure2r = closure2() + val closure3r = closure3("g", Array(1, 5, 8)) + val closure4r = closure4() + val closure5r = closure5() + + verifyCleaning(closure1, serializableBefore = true, serializableAfter = true) + verifyCleaning(closure2, serializableBefore = true, serializableAfter = true) + verifyCleaning(closure3, serializableBefore = true, serializableAfter = true) + verifyCleaning(closure4, serializableBefore = true, serializableAfter = true) + verifyCleaning(closure5, serializableBefore = true, serializableAfter = true) + + // Verify that closures can still be invoked and the result still the same + assert(closure1() === closure1r) + assert(closure2() === closure2r) + assert(closure3("g", Array(1, 5, 8)) === closure3r) + assert(closure4() === closure4r) + assert(closure5() === closure5r) + } + + test("clean basic non-serializable closures") { + val closure1 = () => this // ClosureCleanerSuite2 is not serializable + val closure5 = () => someSerializableValue + val closure3 = () => someSerializableMethod() + val closure4 = () => someNonSerializableValue + val closure2 = () => someNonSerializableMethod() + + // These are not cleanable because they ultimately reference the ClosureCleanerSuite2 + verifyCleaning(closure1, serializableBefore = false, serializableAfter = false) + verifyCleaning(closure2, serializableBefore = false, serializableAfter = false) + verifyCleaning(closure3, serializableBefore = false, serializableAfter = false) + verifyCleaning(closure4, serializableBefore = false, serializableAfter = false) + verifyCleaning(closure5, serializableBefore = false, serializableAfter = false) + } + + test("clean basic nested serializable closures") { + val localValue = someSerializableValue + val closure1 = (i: Int) => { + (1 to i).map { x => x + localValue } // 1 level of nesting + } + val closure2 = (j: Int) => { + (1 to j).flatMap { x => + (1 to x).map { y => y + localValue } // 2 levels + } + } + val closure3 = (k: Int, l: Int, m: Int) => { + (1 to k).flatMap(closure2) ++ // 4 levels + (1 to l).flatMap(closure1) ++ // 3 levels + (1 to m).map { x => x + 1 } // 2 levels + } + val closure1r = closure1(1) + val closure2r = closure2(2) + val closure3r = closure3(3, 4, 5) + + verifyCleaning(closure1, serializableBefore = true, serializableAfter = true) + verifyCleaning(closure2, serializableBefore = true, serializableAfter = true) + verifyCleaning(closure3, serializableBefore = true, serializableAfter = true) + + // Verify that closures can still be invoked and the result still the same + assert(closure1(1) === closure1r) + assert(closure2(2) === closure2r) + assert(closure3(3, 4, 5) === closure3r) + } + + test("clean basic nested non-serializable closures") { + def localSerializableMethod(): Int = someSerializableValue + val localNonSerializableValue = someNonSerializableValue + // These closures ultimately reference the ClosureCleanerSuite2 + // Note that even accessing `val` that is an instance variable involves a method call + val closure1 = (i: Int) => { (1 to i).map { x => x + someSerializableValue } } + val closure2 = (j: Int) => { (1 to j).map { x => x + someSerializableMethod() } } + val closure4 = (k: Int) => { (1 to k).map { x => x + localSerializableMethod() } } + // This closure references a local non-serializable value + val closure3 = (l: Int) => { (1 to l).map { x => localNonSerializableValue } } + // This is non-serializable no matter how many levels we nest it + val closure5 = (m: Int) => { + (1 to m).foreach { x => + (1 to x).foreach { y => + (1 to y).foreach { z => + someSerializableValue + } + } + } + } + + verifyCleaning(closure1, serializableBefore = false, serializableAfter = false) + verifyCleaning(closure2, serializableBefore = false, serializableAfter = false) + verifyCleaning(closure3, serializableBefore = false, serializableAfter = false) + verifyCleaning(closure4, serializableBefore = false, serializableAfter = false) + verifyCleaning(closure5, serializableBefore = false, serializableAfter = false) + } + + test("clean complicated nested serializable closures") { + val localValue = someSerializableValue + + // Here we assume that if the outer closure is serializable, + // then all inner closures must also be serializable + + // Reference local fields from all levels + val closure1 = (i: Int) => { + val a = 1 + (1 to i).flatMap { x => + val b = a + 1 + (1 to x).map { y => + y + a + b + localValue + } + } + } + + // Reference local fields and methods from all levels within the outermost closure + val closure2 = (i: Int) => { + val a1 = 1 + def a2 = 2 + (1 to i).flatMap { x => + val b1 = a1 + 1 + def b2 = a2 + 1 + (1 to x).map { y => + // If this references a method outside the outermost closure, then it will try to pull + // in the ClosureCleanerSuite2. This is why `localValue` here must be a local `val`. + y + a1 + a2 + b1 + b2 + localValue + } + } + } + + val closure1r = closure1(1) + val closure2r = closure2(2) + verifyCleaning(closure1, serializableBefore = true, serializableAfter = true) + verifyCleaning(closure2, serializableBefore = true, serializableAfter = true) + assert(closure1(1) == closure1r) + assert(closure2(2) == closure2r) + } + + test("clean complicated nested non-serializable closures") { + val localValue = someSerializableValue + + // Note that we are not interested in cleaning the outer closures here (they are not cleanable) + // The only reason why they exist is to nest the inner closures + + val test1 = () => { + val a = localValue + val b = sc + val inner1 = (x: Int) => x + a + b.hashCode() + val inner2 = (x: Int) => x + a + + // This closure explicitly references a non-serializable field + // There is no way to clean it + verifyCleaning(inner1, serializableBefore = false, serializableAfter = false) + + // This closure is serializable to begin with since it does not need a pointer to + // the outer closure (it only references local variables) + verifyCleaning(inner2, serializableBefore = true, serializableAfter = true) + } + + // Same as above, but the `val a` becomes `def a` + // The difference here is that all inner closures now have pointers to the outer closure + val test2 = () => { + def a = localValue + val b = sc + val inner1 = (x: Int) => x + a + b.hashCode() + val inner2 = (x: Int) => x + a + + // As before, this closure is neither serializable nor cleanable + verifyCleaning(inner1, serializableBefore = false, serializableAfter = false) + + // This closure is no longer serializable because it now has a pointer to the outer closure, + // which is itself not serializable because it has a pointer to the ClosureCleanerSuite2. + // If we do not clean transitively, we will not null out this indirect reference. + verifyCleaning( + inner2, serializableBefore = false, serializableAfter = false, transitive = false) + + // If we clean transitively, we will find that method `a` does not actually reference the + // outer closure's parent (i.e. the ClosureCleanerSuite), so we can additionally null out + // the outer closure's parent pointer. This will make `inner2` serializable. + verifyCleaning( + inner2, serializableBefore = false, serializableAfter = true, transitive = true) + } + + // Same as above, but with more levels of nesting + val test3 = () => { () => test1() } + val test4 = () => { () => test2() } + val test5 = () => { () => { () => test3() } } + val test6 = () => { () => { () => test4() } } + + test1() + test2() + test3()() + test4()() + test5()()() + test6()()() + } + +} From ecc6eb50a59172dc132bd8f97957734f6f009024 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sat, 2 May 2015 01:53:14 -0700 Subject: [PATCH 27/91] [SPARK-7315] [STREAMING] [TEST] Fix flaky WALBackedBlockRDDSuite `FileUtils.getTempDirectoryPath()` path may or may not exist. We want to make sure that it does not exist. Author: Tathagata Das Closes #5853 from tdas/SPARK-7315 and squashes the following commits: 141afd5 [Tathagata Das] Removed use of FileUtils b08d4f1 [Tathagata Das] Fix flaky WALBackedBlockRDDSuite --- .../streaming/rdd/WriteAheadLogBackedBlockRDD.scala | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala index ebdf418f4ab6a..f4c8046e8a1a8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala @@ -16,13 +16,13 @@ */ package org.apache.spark.streaming.rdd +import java.io.File import java.nio.ByteBuffer +import java.util.UUID import scala.reflect.ClassTag import scala.util.control.NonFatal -import org.apache.commons.io.FileUtils - import org.apache.spark._ import org.apache.spark.rdd.BlockRDD import org.apache.spark.storage.{BlockId, StorageLevel} @@ -108,9 +108,13 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( // writing log data. However, the directory is not needed if data needs to be read, hence // a dummy path is provided to satisfy the method parameter requirements. // FileBasedWriteAheadLog will not create any file or directory at that path. - val dummyDirectory = FileUtils.getTempDirectoryPath() + // FileBasedWriteAheadLog will not create any file or directory at that path. Also, + // this dummy directory should not already exist otherwise the WAL will try to recover + // past events from the directory and throw errors. + val nonExistentDirectory = new File( + System.getProperty("java.io.tmpdir"), UUID.randomUUID().toString).getAbsolutePath writeAheadLog = WriteAheadLogUtils.createLogForReceiver( - SparkEnv.get.conf, dummyDirectory, hadoopConf) + SparkEnv.get.conf, nonExistentDirectory, hadoopConf) dataRead = writeAheadLog.read(partition.walRecordHandle) } catch { case NonFatal(e) => From 856a571ef4f0338ccefb3b792e1e96b8f15a0884 Mon Sep 17 00:00:00 2001 From: Dean Chen Date: Sat, 2 May 2015 23:04:13 +0100 Subject: [PATCH 28/91] [SPARK-3444] Fix typo in Dataframes.py introduced in [] Author: Dean Chen Closes #5866 from deanchen/patch-1 and squashes the following commits: 0028bc4 [Dean Chen] Fix typo in Dataframes.py introduced in [SPARK-3444] --- python/pyspark/sql/dataframe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index e9fd17ed4ce94..8ddcff8fcdf98 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1,6 +1,6 @@ # # Licensed to the Apache Software Foundation (ASF) under one or more -# contir[butor license agreements. See the NOTICE file distributed with +# contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. # The ASF licenses this file to You under the Apache License, Version 2.0 # (the "License"); you may not use this file except in compliance with From da303526e54e9a0adfedb49417f383cde7870a69 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Sat, 2 May 2015 23:05:51 +0100 Subject: [PATCH 29/91] [SPARK-7323] [SPARK CORE] Use insertAll instead of insert while merging combiners in reducer Author: Mridul Muralidharan Closes #5862 from mridulm/optimize_aggregator and squashes the following commits: 61cf43a [Mridul Muralidharan] Use insertAll instead of insert - much more expensive to do it per tuple --- core/src/main/scala/org/apache/spark/Aggregator.scala | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala index 3b684bbeceaf2..af9765d313e9e 100644 --- a/core/src/main/scala/org/apache/spark/Aggregator.scala +++ b/core/src/main/scala/org/apache/spark/Aggregator.scala @@ -88,10 +88,7 @@ case class Aggregator[K, V, C] ( combiners.iterator } else { val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners) - while (iter.hasNext) { - val pair = iter.next() - combiners.insert(pair._1, pair._2) - } + combiners.insertAll(iter) // Update task metrics if context is not null // TODO: Make context non-optional in a future release Option(context).foreach { c => From bfcd528d6f5a5ebe61e0fcca890143e9a3c7f7f9 Mon Sep 17 00:00:00 2001 From: Ye Xianjin Date: Sat, 2 May 2015 23:08:09 +0100 Subject: [PATCH 30/91] [SPARK-6030] [CORE] Using simulated field layout method to compute class shellSize SizeEstimator gives wrong result for Integer on 64bit JVM with UseCompressedOops on, this pr fixes that. For more details, please refer [SPARK-6030](https://issues.apache.org/jira/browse/SPARK-6030) sryza, I noticed there is a pr to expose SizeEstimator, maybe that should be waited by this pr get merged if we confirm this problem. And shivaram would you mind to review this pr since you contribute related code. Also cc to srowen and mateiz Author: Ye Xianjin Closes #4783 from advancedxy/SPARK-6030 and squashes the following commits: c4dcb41 [Ye Xianjin] Add super.beforeEach in the beforeEach method to make the trait stackable.. Remove useless leading whitespace. 3f80640 [Ye Xianjin] The size of Integer class changes from 24 to 16 on a 64-bit JVM with -UseCompressedOops flag on after the fix. I don't how 100000 was originally calculated, It looks like 100000 is the magic number which makes sure spilling. Because of the size change, It fails because there is no spilling at all. Change the number to a slightly larger number fixes that. e849d2d [Ye Xianjin] Merge two shellSize assignments into one. Add some explanation to alignSizeUp method. 85a0b51 [Ye Xianjin] Fix typos and update wording in comments. Using alignSizeUp to compute alignSize. d27eb77 [Ye Xianjin] Add some detailed comments in the code. Add some test cases. It's very difficult to design test cases as the final object alignment will hide a lot of filed layout details if we just considering the whole size. 842aed1 [Ye Xianjin] primitiveSize(cls) can just return Int. Use a simplified class field layout method to calculate class instance size. Will add more documents and test cases. Add a new alignSizeUp function which uses bitwise operators to speedup. 62e8ab4 [Ye Xianjin] Don't alignSize for objects' shellSize, alignSize when added to state.size. Add some primitive wrapper objects size tests. --- .../org/apache/spark/util/SizeEstimator.scala | 61 ++++++++++++++++--- .../spark/util/SizeEstimatorSuite.scala | 47 +++++++++++++- .../util/collection/ExternalSorterSuite.scala | 10 +-- 3 files changed, 100 insertions(+), 18 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index 4dd7ab9e0767b..d91c3294ddb8b 100644 --- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -47,6 +47,11 @@ private[spark] object SizeEstimator extends Logging { private val FLOAT_SIZE = 4 private val DOUBLE_SIZE = 8 + // Fields can be primitive types, sizes are: 1, 2, 4, 8. Or fields can be pointers. The size of + // a pointer is 4 or 8 depending on the JVM (32-bit or 64-bit) and UseCompressedOops flag. + // The sizes should be in descending order, as we will use that information for fields placement. + private val fieldSizes = List(8, 4, 2, 1) + // Alignment boundary for objects // TODO: Is this arch dependent ? private val ALIGN_SIZE = 8 @@ -171,7 +176,7 @@ private[spark] object SizeEstimator extends Logging { // general all ClassLoaders and Classes will be shared between objects anyway. } else { val classInfo = getClassInfo(cls) - state.size += classInfo.shellSize + state.size += alignSize(classInfo.shellSize) for (field <- classInfo.pointerFields) { state.enqueue(field.get(obj)) } @@ -237,8 +242,8 @@ private[spark] object SizeEstimator extends Logging { } size } - - private def primitiveSize(cls: Class[_]): Long = { + + private def primitiveSize(cls: Class[_]): Int = { if (cls == classOf[Byte]) { BYTE_SIZE } else if (cls == classOf[Boolean]) { @@ -274,21 +279,50 @@ private[spark] object SizeEstimator extends Logging { val parent = getClassInfo(cls.getSuperclass) var shellSize = parent.shellSize var pointerFields = parent.pointerFields + val sizeCount = Array.fill(fieldSizes.max + 1)(0) + // iterate through the fields of this class and gather information. for (field <- cls.getDeclaredFields) { if (!Modifier.isStatic(field.getModifiers)) { val fieldClass = field.getType if (fieldClass.isPrimitive) { - shellSize += primitiveSize(fieldClass) + sizeCount(primitiveSize(fieldClass)) += 1 } else { field.setAccessible(true) // Enable future get()'s on this field - shellSize += pointerSize + sizeCount(pointerSize) += 1 pointerFields = field :: pointerFields } } } - shellSize = alignSize(shellSize) + // Based on the simulated field layout code in Aleksey Shipilev's report: + // http://cr.openjdk.java.net/~shade/papers/2013-shipilev-fieldlayout-latest.pdf + // The code is in Figure 9. + // The simplified idea of field layout consists of 4 parts (see more details in the report): + // + // 1. field alignment: HotSpot lays out the fields aligned by their size. + // 2. object alignment: HotSpot rounds instance size up to 8 bytes + // 3. consistent fields layouts throughout the hierarchy: This means we should layout + // superclass first. And we can use superclass's shellSize as a starting point to layout the + // other fields in this class. + // 4. class alignment: HotSpot rounds field blocks up to to HeapOopSize not 4 bytes, confirmed + // with Aleksey. see https://bugs.openjdk.java.net/browse/CODETOOLS-7901322 + // + // The real world field layout is much more complicated. There are three kinds of fields + // order in Java 8. And we don't consider the @contended annotation introduced by Java 8. + // see the HotSpot classloader code, layout_fields method for more details. + // hg.openjdk.java.net/jdk8/jdk8/hotspot/file/tip/src/share/vm/classfile/classFileParser.cpp + var alignedSize = shellSize + for (size <- fieldSizes if sizeCount(size) > 0) { + val count = sizeCount(size) + // If there are internal gaps, smaller field can fit in. + alignedSize = math.max(alignedSize, alignSizeUp(shellSize, size) + size * count) + shellSize += size * count + } + + // Should choose a larger size to be new shellSize and clearly alignedSize >= shellSize, and + // round up the instance filed blocks + shellSize = alignSizeUp(alignedSize, pointerSize) // Create and cache a new ClassInfo val newInfo = new ClassInfo(shellSize, pointerFields) @@ -296,8 +330,15 @@ private[spark] object SizeEstimator extends Logging { newInfo } - private def alignSize(size: Long): Long = { - val rem = size % ALIGN_SIZE - if (rem == 0) size else (size + ALIGN_SIZE - rem) - } + private def alignSize(size: Long): Long = alignSizeUp(size, ALIGN_SIZE) + + /** + * Compute aligned size. The alignSize must be 2^n, otherwise the result will be wrong. + * When alignSize = 2^n, alignSize - 1 = 2^n - 1. The binary representation of (alignSize - 1) + * will only have n trailing 1s(0b00...001..1). ~(alignSize - 1) will be 0b11..110..0. Hence, + * (size + alignSize - 1) & ~(alignSize - 1) will set the last n bits to zeros, which leads to + * multiple of alignSize. + */ + private def alignSizeUp(size: Long, alignSize: Int): Long = + (size + alignSize - 1) & ~(alignSize - 1) } diff --git a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala index 28915bd53354e..133a76f28e000 100644 --- a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala @@ -36,6 +36,15 @@ class DummyClass4(val d: DummyClass3) { val x: Int = 0 } +// dummy class to show class field blocks alignment. +class DummyClass5 extends DummyClass1 { + val x: Boolean = true +} + +class DummyClass6 extends DummyClass5 { + val y: Boolean = true +} + object DummyString { def apply(str: String) : DummyString = new DummyString(str.toArray) } @@ -50,6 +59,7 @@ class SizeEstimatorSuite override def beforeEach() { // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case + super.beforeEach() System.setProperty("os.arch", "amd64") System.setProperty("spark.test.useCompressedOops", "true") } @@ -62,6 +72,22 @@ class SizeEstimatorSuite assertResult(48)(SizeEstimator.estimate(new DummyClass4(new DummyClass3))) } + test("primitive wrapper objects") { + assertResult(16)(SizeEstimator.estimate(new java.lang.Boolean(true))) + assertResult(16)(SizeEstimator.estimate(new java.lang.Byte("1"))) + assertResult(16)(SizeEstimator.estimate(new java.lang.Character('1'))) + assertResult(16)(SizeEstimator.estimate(new java.lang.Short("1"))) + assertResult(16)(SizeEstimator.estimate(new java.lang.Integer(1))) + assertResult(24)(SizeEstimator.estimate(new java.lang.Long(1))) + assertResult(16)(SizeEstimator.estimate(new java.lang.Float(1.0))) + assertResult(24)(SizeEstimator.estimate(new java.lang.Double(1.0d))) + } + + test("class field blocks rounding") { + assertResult(16)(SizeEstimator.estimate(new DummyClass5)) + assertResult(24)(SizeEstimator.estimate(new DummyClass6)) + } + // NOTE: The String class definition varies across JDK versions (1.6 vs. 1.7) and vendors // (Sun vs IBM). Use a DummyString class to make tests deterministic. test("strings") { @@ -102,18 +128,18 @@ class SizeEstimatorSuite val arr = new Array[Char](100000) assertResult(200016)(SizeEstimator.estimate(arr)) assertResult(480032)(SizeEstimator.estimate(Array.fill(10000)(new DummyString(arr)))) - + val buf = new ArrayBuffer[DummyString]() for (i <- 0 until 5000) { buf.append(new DummyString(new Array[Char](10))) } assertResult(340016)(SizeEstimator.estimate(buf.toArray)) - + for (i <- 0 until 5000) { buf.append(new DummyString(arr)) } assertResult(683912)(SizeEstimator.estimate(buf.toArray)) - + // If an array contains the *same* element many times, we should only count it once. val d1 = new DummyClass1 // 10 pointers plus 8-byte object @@ -155,5 +181,20 @@ class SizeEstimatorSuite assertResult(64)(SizeEstimator.estimate(DummyString("a"))) assertResult(64)(SizeEstimator.estimate(DummyString("ab"))) assertResult(72)(SizeEstimator.estimate(DummyString("abcdefgh"))) + + // primitive wrapper classes + assertResult(24)(SizeEstimator.estimate(new java.lang.Boolean(true))) + assertResult(24)(SizeEstimator.estimate(new java.lang.Byte("1"))) + assertResult(24)(SizeEstimator.estimate(new java.lang.Character('1'))) + assertResult(24)(SizeEstimator.estimate(new java.lang.Short("1"))) + assertResult(24)(SizeEstimator.estimate(new java.lang.Integer(1))) + assertResult(24)(SizeEstimator.estimate(new java.lang.Long(1))) + assertResult(24)(SizeEstimator.estimate(new java.lang.Float(1.0))) + assertResult(24)(SizeEstimator.estimate(new java.lang.Double(1.0d))) + } + + test("class field blocks rounding on 64-bit VM without useCompressedOops") { + assertResult(24)(SizeEstimator.estimate(new DummyClass5)) + assertResult(32)(SizeEstimator.estimate(new DummyClass6)) } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index 20fd22b78ef5d..7a98723bc6472 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -377,7 +377,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe val sorter = new ExternalSorter[Int, Int, Int]( None, Some(new HashPartitioner(3)), Some(ord), None) assertDidNotBypassMergeSort(sorter) - sorter.insertAll((0 until 100000).iterator.map(i => (i, i))) + sorter.insertAll((0 until 120000).iterator.map(i => (i, i))) assert(diskBlockManager.getAllFiles().length > 0) sorter.stop() assert(diskBlockManager.getAllBlocks().length === 0) @@ -385,9 +385,9 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe val sorter2 = new ExternalSorter[Int, Int, Int]( None, Some(new HashPartitioner(3)), Some(ord), None) assertDidNotBypassMergeSort(sorter2) - sorter2.insertAll((0 until 100000).iterator.map(i => (i, i))) + sorter2.insertAll((0 until 120000).iterator.map(i => (i, i))) assert(diskBlockManager.getAllFiles().length > 0) - assert(sorter2.iterator.toSet === (0 until 100000).map(i => (i, i)).toSet) + assert(sorter2.iterator.toSet === (0 until 120000).map(i => (i, i)).toSet) sorter2.stop() assert(diskBlockManager.getAllBlocks().length === 0) } @@ -428,8 +428,8 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe None, Some(new HashPartitioner(3)), Some(ord), None) assertDidNotBypassMergeSort(sorter) intercept[SparkException] { - sorter.insertAll((0 until 100000).iterator.map(i => { - if (i == 99990) { + sorter.insertAll((0 until 120000).iterator.map(i => { + if (i == 119990) { throw new SparkException("Intentional failure") } (i, i) From 82c8c37c098e5886da65cea3108737744e270b91 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Sat, 2 May 2015 23:10:35 +0100 Subject: [PATCH 31/91] [MINOR] [HIVE] Fix QueryPartitionSuite. At least in the version of Hive I tested on, the test was deleting a temp directory generated by Hive instead of one containing partition data. So fix the filter to only consider partition directories when deciding what to delete. Author: Marcelo Vanzin Closes #5854 from vanzin/hive-test-fix and squashes the following commits: 7594ae9 [Marcelo Vanzin] Fix typo. 729fa80 [Marcelo Vanzin] [minor] [hive] Fix QueryPartitionSuite. --- .../org/apache/spark/sql/hive/QueryPartitionSuite.scala | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala index a787fa5546e76..4990092df6a99 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.util.Utils class QueryPartitionSuite extends QueryTest { import org.apache.spark.sql.hive.test.TestHive.implicits._ - test("SPARK-5068: query data when path doesn't exists"){ + test("SPARK-5068: query data when path doesn't exist"){ val testData = TestHive.sparkContext.parallelize( (1 to 10).map(i => TestData(i, i.toString))).toDF() testData.registerTempTable("testData") @@ -52,8 +52,9 @@ class QueryPartitionSuite extends QueryTest { ++ testData.toSchemaRDD.collect ++ testData.toSchemaRDD.collect) // delete the path of one partition - val folders = tmpDir.listFiles.filter(_.isDirectory) - Utils.deleteRecursively(folders(0)) + tmpDir.listFiles + .find { f => f.isDirectory && f.getName().startsWith("ds=") } + .foreach { f => Utils.deleteRecursively(f) } // test for after delete the path checkAnswer(sql("select key,value from table_with_partition"), From 5d6b90d939d281130c786be38fd1794c74391b08 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Sat, 2 May 2015 15:20:07 -0700 Subject: [PATCH 32/91] [SPARK-5213] [SQL] Pluggable SQL Parser Support based on #4015, we should not delete `sqlParser` from sqlcontext, that leads to mima failed. Users implement dialect to give a fallback for `sqlParser` and we should construct `sqlParser` in sqlcontext according to the dialect `protected[sql] val sqlParser = new SparkSQLParser(getSQLDialect().parse(_))` Author: Cheng Hao Author: scwf Closes #5827 from scwf/sqlparser1 and squashes the following commits: 81b9737 [scwf] comment fix 0878bd1 [scwf] remove comments c19780b [scwf] fix mima tests c2895cf [scwf] Merge branch 'master' of https://github.com/apache/spark into sqlparser1 493775c [Cheng Hao] update the code as feedback 81a731f [Cheng Hao] remove the unecessary comment aab0b0b [Cheng Hao] polish the code a little bit 49b9d81 [Cheng Hao] shrink the comment for rebasing --- .../sql/catalyst/AbstractSparkSQLParser.scala | 11 ++- .../apache/spark/sql/catalyst/Dialect.scala | 33 ++++++++ .../spark/sql/catalyst/errors/package.scala | 2 + .../org/apache/spark/sql/SQLContext.scala | 76 ++++++++++++++++--- .../org/apache/spark/sql/sources/ddl.scala | 6 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 22 ++++++ .../apache/spark/sql/hive/HiveContext.scala | 41 ++++++---- .../apache/spark/sql/hive/test/TestHive.scala | 5 +- .../sql/hive/execution/SQLQuerySuite.scala | 39 +++++++++- 9 files changed, 196 insertions(+), 39 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/Dialect.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala index 1f3c02478bd68..2eb3e167baad5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala @@ -25,10 +25,6 @@ import scala.util.parsing.input.CharArrayReader.EofCh import org.apache.spark.sql.catalyst.plans.logical._ -private[sql] object KeywordNormalizer { - def apply(str: String): String = str.toLowerCase() -} - private[sql] abstract class AbstractSparkSQLParser extends StandardTokenParsers with PackratParsers { @@ -42,7 +38,7 @@ private[sql] abstract class AbstractSparkSQLParser } protected case class Keyword(str: String) { - def normalize: String = KeywordNormalizer(str) + def normalize: String = lexical.normalizeKeyword(str) def parser: Parser[String] = normalize } @@ -90,13 +86,16 @@ class SqlLexical extends StdLexical { reserved ++= keywords } + /* Normal the keyword string */ + def normalizeKeyword(str: String): String = str.toLowerCase + delimiters += ( "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")", ",", ";", "%", "{", "}", ":", "[", "]", ".", "&", "|", "^", "~", "<=>" ) protected override def processIdent(name: String) = { - val token = KeywordNormalizer(name) + val token = normalizeKeyword(name) if (reserved contains token) Keyword(token) else Identifier(name) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/Dialect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/Dialect.scala new file mode 100644 index 0000000000000..977003493d471 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/Dialect.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + +/** + * Root class of SQL Parser Dialect, and we don't guarantee the binary + * compatibility for the future release, let's keep it as the internal + * interface for advanced user. + * + */ +@DeveloperApi +abstract class Dialect { + // this is the main function that will be implemented by sql parser. + def parse(sqlText: String): LogicalPlan +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala index bdeb660b1ecb7..0fd4f9b374ee0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala @@ -38,6 +38,8 @@ package object errors { } } + class DialectException(msg: String, cause: Throwable) extends Exception(msg, cause) + /** * Wraps any exceptions that are thrown while executing `f` in a * [[catalyst.errors.TreeNodeException TreeNodeException]], attaching the provided `tree`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 5116fcefd4bf2..7eabb93c1e3d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -24,6 +24,7 @@ import scala.collection.JavaConversions._ import scala.collection.immutable import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag +import scala.util.control.NonFatal import com.google.common.reflect.TypeToken @@ -32,9 +33,11 @@ import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.Dialect import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, expressions} import org.apache.spark.sql.execution.{Filter, _} import org.apache.spark.sql.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} @@ -44,6 +47,42 @@ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils import org.apache.spark.{Partition, SparkContext} +/** + * Currently we support the default dialect named "sql", associated with the class + * [[DefaultDialect]] + * + * And we can also provide custom SQL Dialect, for example in Spark SQL CLI: + * {{{ + *-- switch to "hiveql" dialect + * spark-sql>SET spark.sql.dialect=hiveql; + * spark-sql>SELECT * FROM src LIMIT 1; + * + *-- switch to "sql" dialect + * spark-sql>SET spark.sql.dialect=sql; + * spark-sql>SELECT * FROM src LIMIT 1; + * + *-- register the new SQL dialect + * spark-sql> SET spark.sql.dialect=com.xxx.xxx.SQL99Dialect; + * spark-sql> SELECT * FROM src LIMIT 1; + * + *-- register the non-exist SQL dialect + * spark-sql> SET spark.sql.dialect=NotExistedClass; + * spark-sql> SELECT * FROM src LIMIT 1; + * + *-- Exception will be thrown and switch to dialect + *-- "sql" (for SQLContext) or + *-- "hiveql" (for HiveContext) + * }}} + */ +private[spark] class DefaultDialect extends Dialect { + @transient + protected val sqlParser = new catalyst.SqlParser + + override def parse(sqlText: String): LogicalPlan = { + sqlParser.parse(sqlText) + } +} + /** * The entry point for working with structured data (rows and columns) in Spark. Allows the * creation of [[DataFrame]] objects as well as the execution of SQL queries. @@ -135,14 +174,27 @@ class SQLContext(@transient val sparkContext: SparkContext) protected[sql] val ddlParser = new DDLParser(sqlParser.parse(_)) @transient - protected[sql] val sqlParser = { - val fallback = new catalyst.SqlParser - new SparkSQLParser(fallback.parse(_)) + protected[sql] val sqlParser = new SparkSQLParser(getSQLDialect().parse(_)) + + protected[sql] def getSQLDialect(): Dialect = { + try { + val clazz = Utils.classForName(dialectClassName) + clazz.newInstance().asInstanceOf[Dialect] + } catch { + case NonFatal(e) => + // Since we didn't find the available SQL Dialect, it will fail even for SET command: + // SET spark.sql.dialect=sql; Let's reset as default dialect automatically. + val dialect = conf.dialect + // reset the sql dialect + conf.unsetConf(SQLConf.DIALECT) + // throw out the exception, and the default sql dialect will take effect for next query. + throw new DialectException( + s"""Instantiating dialect '$dialect' failed. + |Reverting to default dialect '${conf.dialect}'""".stripMargin, e) + } } - protected[sql] def parseSql(sql: String): LogicalPlan = { - ddlParser.parse(sql, false).getOrElse(sqlParser.parse(sql)) - } + protected[sql] def parseSql(sql: String): LogicalPlan = ddlParser.parse(sql, false) protected[sql] def executeSql(sql: String): this.QueryExecution = executePlan(parseSql(sql)) @@ -156,6 +208,12 @@ class SQLContext(@transient val sparkContext: SparkContext) @transient protected[sql] val defaultSession = createSession() + protected[sql] def dialectClassName = if (conf.dialect == "sql") { + classOf[DefaultDialect].getCanonicalName + } else { + conf.dialect + } + sparkContext.getConf.getAll.foreach { case (key, value) if key.startsWith("spark.sql") => setConf(key, value) case _ => @@ -931,11 +989,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * @group basic */ def sql(sqlText: String): DataFrame = { - if (conf.dialect == "sql") { - DataFrame(this, parseSql(sqlText)) - } else { - sys.error(s"Unsupported SQL dialect: ${conf.dialect}") - } + DataFrame(this, parseSql(sqlText)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index e7a0685e013d8..1abf3aa51cb25 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -38,12 +38,12 @@ private[sql] class DDLParser( parseQuery: String => LogicalPlan) extends AbstractSparkSQLParser with DataTypeParser with Logging { - def parse(input: String, exceptionOnError: Boolean): Option[LogicalPlan] = { + def parse(input: String, exceptionOnError: Boolean): LogicalPlan = { try { - Some(parse(input)) + parse(input) } catch { case ddlException: DDLException => throw ddlException - case _ if !exceptionOnError => None + case _ if !exceptionOnError => parseQuery(input) case x: Throwable => throw x } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index d8e7cdbd3a94e..0ab8558c1db13 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -19,13 +19,18 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfterAll +import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.execution.GeneratedAggregate import org.apache.spark.sql.functions._ import org.apache.spark.sql.TestData._ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext.{udf => _, _} + import org.apache.spark.sql.types._ +/** A SQL Dialect for testing purpose, and it can not be nested type */ +class MyDialect extends DefaultDialect + class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { // Make sure the tables are loaded. TestData @@ -74,6 +79,23 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil) } + test("SQL Dialect Switching to a new SQL parser") { + val newContext = new SQLContext(TestSQLContext.sparkContext) + newContext.setConf("spark.sql.dialect", classOf[MyDialect].getCanonicalName()) + assert(newContext.getSQLDialect().getClass === classOf[MyDialect]) + assert(newContext.sql("SELECT 1").collect() === Array(Row(1))) + } + + test("SQL Dialect Switch to an invalid parser with alias") { + val newContext = new SQLContext(TestSQLContext.sparkContext) + newContext.sql("SET spark.sql.dialect=MyTestClass") + intercept[DialectException] { + newContext.sql("SELECT 1") + } + // test if the dialect set back to DefaultSQLDialect + assert(newContext.getSQLDialect().getClass === classOf[DefaultDialect]) + } + test("SPARK-4625 support SORT BY in SimpleSQLParser & DSL") { checkAnswer( sql("SELECT a FROM testData2 SORT BY a"), diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index dd06b2620c5ee..1d8d0b5c322ad 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -20,6 +20,9 @@ package org.apache.spark.sql.hive import java.io.{BufferedReader, InputStreamReader, PrintStream} import java.sql.Timestamp +import org.apache.hadoop.hive.ql.parse.VariableSubstitution +import org.apache.spark.sql.catalyst.Dialect + import scala.collection.JavaConversions._ import scala.language.implicitConversions @@ -42,6 +45,15 @@ import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNative import org.apache.spark.sql.sources.{DDLParser, DataSourceStrategy} import org.apache.spark.sql.types._ +/** + * This is the HiveQL Dialect, this dialect is strongly bind with HiveContext + */ +private[hive] class HiveQLDialect extends Dialect { + override def parse(sqlText: String): LogicalPlan = { + HiveQl.parseSql(sqlText) + } +} + /** * An instance of the Spark SQL execution engine that integrates with data stored in Hive. * Configuration for Hive is read from hive-site.xml on the classpath. @@ -81,25 +93,16 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { protected[sql] def convertCTAS: Boolean = getConf("spark.sql.hive.convertCTAS", "false").toBoolean - override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = - new this.QueryExecution(plan) - @transient - protected[sql] val ddlParserWithHiveQL = new DDLParser(HiveQl.parseSql(_)) - - override def sql(sqlText: String): DataFrame = { - val substituted = new VariableSubstitution().substitute(hiveconf, sqlText) - // TODO: Create a framework for registering parsers instead of just hardcoding if statements. - if (conf.dialect == "sql") { - super.sql(substituted) - } else if (conf.dialect == "hiveql") { - val ddlPlan = ddlParserWithHiveQL.parse(sqlText, exceptionOnError = false) - DataFrame(this, ddlPlan.getOrElse(HiveQl.parseSql(substituted))) - } else { - sys.error(s"Unsupported SQL dialect: ${conf.dialect}. Try 'sql' or 'hiveql'") - } + protected[sql] lazy val substitutor = new VariableSubstitution() + + protected[sql] override def parseSql(sql: String): LogicalPlan = { + super.parseSql(substitutor.substitute(hiveconf, sql)) } + override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = + new this.QueryExecution(plan) + /** * Invalidate and refresh all the cached the metadata of the given table. For performance reasons, * Spark SQL or the external data source library it uses might cache certain metadata about a @@ -356,6 +359,12 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { } } + override protected[sql] def dialectClassName = if (conf.dialect == "hiveql") { + classOf[HiveQLDialect].getCanonicalName + } else { + super.dialectClassName + } + @transient private val hivePlanner = new SparkPlanner with HiveStrategies { val hiveContext = self diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 9f17bca083d13..edeab5158df62 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -107,7 +107,10 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { /** Fewer partitions to speed up testing. */ protected[sql] override lazy val conf: SQLConf = new SQLConf { override def numShufflePartitions: Int = getConf(SQLConf.SHUFFLE_PARTITIONS, "5").toInt - override def dialect: String = getConf(SQLConf.DIALECT, "hiveql") + + // TODO as in unit test, conf.clear() probably be called, all of the value will be cleared. + // The super.getConf(SQLConf.DIALECT) is "sql" by default, we need to set it as "hiveql" + override def dialect: String = super.getConf(SQLConf.DIALECT, "hiveql") } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 4f8d0ac0e7656..630dec8fa05a0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -18,14 +18,17 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries -import org.apache.spark.sql.hive.{MetastoreRelation, HiveShim} +import org.apache.spark.sql.catalyst.errors.DialectException +import org.apache.spark.sql.DefaultDialect +import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SQLConf} +import org.apache.spark.sql.hive.MetastoreRelation import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.hive.{HiveQLDialect, HiveShim} import org.apache.spark.sql.parquet.ParquetRelation2 import org.apache.spark.sql.sources.LogicalRelation import org.apache.spark.sql.types._ -import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SQLConf} case class Nested1(f1: Nested2) case class Nested2(f2: Nested3) @@ -45,6 +48,9 @@ case class Order( state: String, month: Int) +/** A SQL Dialect for testing purpose, and it can not be nested type */ +class MyDialect extends DefaultDialect + /** * A collection of hive query tests where we generate the answers ourselves instead of depending on * Hive to generate them (in contrast to HiveQuerySuite). Often this is because the query is @@ -229,6 +235,35 @@ class SQLQuerySuite extends QueryTest { setConf("spark.sql.hive.convertCTAS", originalConf) } + test("SQL Dialect Switching") { + assert(getSQLDialect().getClass === classOf[HiveQLDialect]) + setConf("spark.sql.dialect", classOf[MyDialect].getCanonicalName()) + assert(getSQLDialect().getClass === classOf[MyDialect]) + assert(sql("SELECT 1").collect() === Array(Row(1))) + + // set the dialect back to the DefaultSQLDialect + sql("SET spark.sql.dialect=sql") + assert(getSQLDialect().getClass === classOf[DefaultDialect]) + sql("SET spark.sql.dialect=hiveql") + assert(getSQLDialect().getClass === classOf[HiveQLDialect]) + + // set invalid dialect + sql("SET spark.sql.dialect.abc=MyTestClass") + sql("SET spark.sql.dialect=abc") + intercept[Exception] { + sql("SELECT 1") + } + // test if the dialect set back to HiveQLDialect + getSQLDialect().getClass === classOf[HiveQLDialect] + + sql("SET spark.sql.dialect=MyTestClass") + intercept[DialectException] { + sql("SELECT 1") + } + // test if the dialect set back to HiveQLDialect + assert(getSQLDialect().getClass === classOf[HiveQLDialect]) + } + test("CTAS with serde") { sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value").collect() sql( From ea841efc5a67e9a64a4ec803d31e4023b565c327 Mon Sep 17 00:00:00 2001 From: BenFradet Date: Sat, 2 May 2015 23:41:14 +0100 Subject: [PATCH 33/91] [SPARK-7255] [STREAMING] [DOCUMENTATION] Added documentation for spark.streaming.kafka.maxRetries Added documentation for spark.streaming.kafka.maxRetries Author: BenFradet Closes #5808 from BenFradet/master and squashes the following commits: cc72e7a [BenFradet] updated doc for spark.streaming.kafka.maxRetries to explain the default value 18f823e [BenFradet] Added "consecutive" to the spark.streaming.kafka.maxRetries doc 597fdeb [BenFradet] Mention that spark.streaming.kafka.maxRetries only applies to the direct kafka api 0efad39 [BenFradet] Added documentation for spark.streaming.kafka.maxRetries --- docs/configuration.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/docs/configuration.md b/docs/configuration.md index 7239b252a9fcc..64066bc0d70cd 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1464,6 +1464,16 @@ Apart from these, the following properties are also available, and may be useful for more details. + + spark.streaming.kafka.maxRetries + 1 + + Maximum number of consecutive retries the driver will make in order to find + the latest offsets on the leader of each partition (a default value of 1 + means that the driver will make a maximum of 2 attempts). Only applies to + the new Kafka direct stream API. + + #### Cluster Managers From 49549d5a1a867c3ba25f5e4aec351d4102444bc0 Mon Sep 17 00:00:00 2001 From: WangTaoTheTonic Date: Sun, 3 May 2015 00:47:47 +0100 Subject: [PATCH 34/91] [SPARK-7031] [THRIFTSERVER] let thrift server take SPARK_DAEMON_MEMORY and SPARK_DAEMON_JAVA_OPTS We should let Thrift Server take these two parameters as it is a daemon. And it is better to read driver-related configs as an app submited by spark-submit. https://issues.apache.org/jira/browse/SPARK-7031 Author: WangTaoTheTonic Closes #5609 from WangTaoTheTonic/SPARK-7031 and squashes the following commits: 8d3fc16 [WangTaoTheTonic] indent 035069b [WangTaoTheTonic] better code style d3ddfb6 [WangTaoTheTonic] revert the unnecessary changes in suite 624e652 [WangTaoTheTonic] fix break tests 0565831 [WangTaoTheTonic] fix failed tests 4fb25ed [WangTaoTheTonic] let thrift server take SPARK_DAEMON_MEMORY and SPARK_DAEMON_JAVA_OPTS --- .../launcher/SparkSubmitCommandBuilder.java | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index a73c9c87e3126..7d387d406edae 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -190,6 +190,10 @@ private List buildSparkSubmitCommand(Map env) throws IOE firstNonEmptyValue(SparkLauncher.DRIVER_EXTRA_CLASSPATH, conf, props) : null; List cmd = buildJavaCommand(extraClassPath); + // Take Thrift Server as daemon + if (isThriftServer(mainClass)) { + addOptionString(cmd, System.getenv("SPARK_DAEMON_JAVA_OPTS")); + } addOptionString(cmd, System.getenv("SPARK_SUBMIT_OPTS")); addOptionString(cmd, System.getenv("SPARK_JAVA_OPTS")); @@ -201,7 +205,11 @@ private List buildSparkSubmitCommand(Map env) throws IOE // - SPARK_DRIVER_MEMORY env variable // - SPARK_MEM env variable // - default value (512m) - String memory = firstNonEmpty(firstNonEmptyValue(SparkLauncher.DRIVER_MEMORY, conf, props), + // Take Thrift Server as daemon + String tsMemory = + isThriftServer(mainClass) ? System.getenv("SPARK_DAEMON_MEMORY") : null; + String memory = firstNonEmpty(tsMemory, + firstNonEmptyValue(SparkLauncher.DRIVER_MEMORY, conf, props), System.getenv("SPARK_DRIVER_MEMORY"), System.getenv("SPARK_MEM"), DEFAULT_MEM); cmd.add("-Xms" + memory); cmd.add("-Xmx" + memory); @@ -292,6 +300,15 @@ private boolean isClientMode(Properties userProps) { (!userMaster.equals("yarn-cluster") && deployMode == null); } + /** + * Return whether the given main class represents a thrift server. + */ + private boolean isThriftServer(String mainClass) { + return (mainClass != null && + mainClass.equals("org.apache.spark.sql.hive.thriftserver.HiveThriftServer2")); + } + + private class OptionParser extends SparkSubmitOptionParser { @Override From f4af92550cb90e47a12d4625fa615dd2b1587d42 Mon Sep 17 00:00:00 2001 From: Omede Firouz Date: Sun, 3 May 2015 11:42:02 -0700 Subject: [PATCH 35/91] [SPARK-7022] [PYSPARK] [ML] Add ML.Tuning.ParamGridBuilder to PySpark Author: Omede Firouz Author: Omede Closes #5601 from oefirouz/paramgrid and squashes the following commits: c9e2481 [Omede Firouz] Make test a doctest 9a8ce22 [Omede] Fix linter issues 8b8a6d2 [Omede Firouz] [SPARK-7022][PySpark][ML] Add ML.Tuning.ParamGridBuilder to PySpark --- python/pyspark/ml/tuning.py | 94 +++++++++++++++++++++++++++++++++++++ python/run-tests | 1 + 2 files changed, 95 insertions(+) create mode 100644 python/pyspark/ml/tuning.py diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py new file mode 100644 index 0000000000000..a383bd0c0d26f --- /dev/null +++ b/python/pyspark/ml/tuning.py @@ -0,0 +1,94 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +__all__ = ['ParamGridBuilder'] + + +class ParamGridBuilder(object): + """ + Builder for a param grid used in grid search-based model selection. + + >>> from classification import LogisticRegression + >>> lr = LogisticRegression() + >>> output = ParamGridBuilder().baseOn({lr.labelCol: 'l'}) \ + .baseOn([lr.predictionCol, 'p']) \ + .addGrid(lr.regParam, [1.0, 2.0, 3.0]) \ + .addGrid(lr.maxIter, [1, 5]) \ + .addGrid(lr.featuresCol, ['f']) \ + .build() + >>> expected = [ \ +{lr.regParam: 1.0, lr.featuresCol: 'f', lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'}, \ +{lr.regParam: 2.0, lr.featuresCol: 'f', lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'}, \ +{lr.regParam: 3.0, lr.featuresCol: 'f', lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'}, \ +{lr.regParam: 1.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}, \ +{lr.regParam: 2.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}, \ +{lr.regParam: 3.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}] + >>> fail_count = 0 + >>> for e in expected: + ... if e not in output: + ... fail_count += 1 + >>> if len(expected) != len(output): + ... fail_count += 1 + >>> fail_count + 0 + """ + + def __init__(self): + self._param_grid = {} + + def addGrid(self, param, values): + """ + Sets the given parameters in this grid to fixed values. + """ + self._param_grid[param] = values + + return self + + def baseOn(self, *args): + """ + Sets the given parameters in this grid to fixed values. + Accepts either a parameter dictionary or a list of (parameter, value) pairs. + """ + if isinstance(args[0], dict): + self.baseOn(*args[0].items()) + else: + for (param, value) in args: + self.addGrid(param, [value]) + + return self + + def build(self): + """ + Builds and returns all combinations of parameters specified + by the param grid. + """ + param_maps = [{}] + for (param, values) in self._param_grid.items(): + new_param_maps = [] + for value in values: + for old_map in param_maps: + copied_map = old_map.copy() + copied_map[param] = value + new_param_maps.append(copied_map) + param_maps = new_param_maps + + return param_maps + + +if __name__ == "__main__": + import doctest + doctest.testmod() diff --git a/python/run-tests b/python/run-tests index 88b63b84fdc27..0e0eee3564e7c 100755 --- a/python/run-tests +++ b/python/run-tests @@ -98,6 +98,7 @@ function run_ml_tests() { echo "Run ml tests ..." run_test "pyspark/ml/feature.py" run_test "pyspark/ml/classification.py" + run_test "pyspark/ml/tuning.py" run_test "pyspark/ml/tests.py" } From daa70bf135f23381f5f410aa95a1c0e5a2888568 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sun, 3 May 2015 13:12:50 -0700 Subject: [PATCH 36/91] [SPARK-6907] [SQL] Isolated client for HiveMetastore This PR adds initial support for loading multiple versions of Hive in a single JVM and provides a common interface for extracting metadata from the `HiveMetastoreClient` for a given version. This is accomplished by creating an isolated `ClassLoader` that operates according to the following rules: - __Shared Classes__: Java, Scala, logging, and Spark classes are delegated to `baseClassLoader` allowing the results of calls to the `ClientInterface` to be visible externally. - __Hive Classes__: new instances are loaded from `execJars`. These classes are not accessible externally due to their custom loading. - __Barrier Classes__: Classes such as `ClientWrapper` are defined in Spark but must link to a specific version of Hive. As a result, the bytecode is acquired from the Spark `ClassLoader` but a new copy is created for each instance of `IsolatedClientLoader`. This new instance is able to see a specific version of hive without using reflection where ever hive is consistent across versions. Since this is a unique instance, it is not visible externally other than as a generic `ClientInterface`, unless `isolationOn` is set to `false`. In addition to the unit tests, I have also tested this locally against mysql instances of the Hive Metastore. I've also successfully ported Spark SQL to run with this client, but due to the size of the changes, that will come in a follow-up PR. By default, Hive jars are currently downloaded from Maven automatically for a given version to ease packaging and testing. However, there is also support for specifying their location manually for deployments without internet. Author: Michael Armbrust Closes #5851 from marmbrus/isolatedClient and squashes the following commits: c72f6ac [Michael Armbrust] rxins comments 1e271fa [Michael Armbrust] [SPARK-6907][SQL] Isolated client for HiveMetastore --- .../org/apache/spark/deploy/SparkSubmit.scala | 2 +- .../spark/sql/catalyst/analysis/Catalog.scala | 2 + .../spark/sql/catalyst/util/package.scala | 37 +- .../sql/hive/client/ClientInterface.scala | 149 +++++++ .../spark/sql/hive/client/ClientWrapper.scala | 395 ++++++++++++++++++ .../hive/client/IsolatedClientLoader.scala | 172 ++++++++ .../sql/hive/client/ReflectionMagic.scala | 200 +++++++++ .../spark/sql/hive/client/package.scala | 33 ++ .../spark/sql/hive/client/VersionsSuite.scala | 105 +++++ 9 files changed, 1088 insertions(+), 7 deletions(-) create mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala create mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala create mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala create mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ReflectionMagic.scala create mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 42b5d41b7b526..8a0327984e195 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -701,7 +701,7 @@ object SparkSubmit { } /** Provides utility functions to be used inside SparkSubmit. */ -private[deploy] object SparkSubmitUtils { +private[spark] object SparkSubmitUtils { // Exposed for testing var printStream = SparkSubmit.printStream diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index b2f8157a1a61f..18c24b651921a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -27,6 +27,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery} */ class NoSuchTableException extends Exception +class NoSuchDatabaseException extends Exception + /** * An interface for looking up relations by name. Used by an [[Analyzer]]. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index c86214a2aa944..9d613a940ee86 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -17,12 +17,31 @@ package org.apache.spark.sql.catalyst -import java.io.{PrintWriter, ByteArrayOutputStream, FileInputStream, File} +import java.io._ import org.apache.spark.util.Utils package object util { + /** Silences output to stderr or stdout for the duration of f */ + def quietly[A](f: => A): A = { + val origErr = System.err + val origOut = System.out + try { + System.setErr(new PrintStream(new OutputStream { + def write(b: Int) = {} + })) + System.setOut(new PrintStream(new OutputStream { + def write(b: Int) = {} + })) + + f + } finally { + System.setErr(origErr) + System.setOut(origOut) + } + } + def fileToString(file: File, encoding: String = "UTF-8"): String = { val inStream = new FileInputStream(file) val outStream = new ByteArrayOutputStream @@ -42,10 +61,9 @@ package object util { new String(outStream.toByteArray, encoding) } - def resourceToString( - resource:String, - encoding: String = "UTF-8", - classLoader: ClassLoader = Utils.getSparkClassLoader): String = { + def resourceToBytes( + resource: String, + classLoader: ClassLoader = Utils.getSparkClassLoader): Array[Byte] = { val inStream = classLoader.getResourceAsStream(resource) val outStream = new ByteArrayOutputStream try { @@ -61,7 +79,14 @@ package object util { finally { inStream.close() } - new String(outStream.toByteArray, encoding) + outStream.toByteArray + } + + def resourceToString( + resource:String, + encoding: String = "UTF-8", + classLoader: ClassLoader = Utils.getSparkClassLoader): String = { + new String(resourceToBytes(resource, classLoader), encoding) } def stringToFile(file: File, str: String): File = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala new file mode 100644 index 0000000000000..a863aa77cb7e0 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.client + +import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchTableException} + +case class HiveDatabase( + name: String, + location: String) + +abstract class TableType { val name: String } +case object ExternalTable extends TableType { override val name = "EXTERNAL_TABLE" } +case object IndexTable extends TableType { override val name = "INDEX_TABLE" } +case object ManagedTable extends TableType { override val name = "MANAGED_TABLE" } +case object VirtualView extends TableType { override val name = "VIRTUAL_VIEW" } + +case class HiveStorageDescriptor( + location: String, + inputFormat: String, + outputFormat: String, + serde: String) + +case class HivePartition( + values: Seq[String], + storage: HiveStorageDescriptor) + +case class HiveColumn(name: String, hiveType: String, comment: String) +case class HiveTable( + specifiedDatabase: Option[String], + name: String, + schema: Seq[HiveColumn], + partitionColumns: Seq[HiveColumn], + properties: Map[String, String], + serdeProperties: Map[String, String], + tableType: TableType, + location: Option[String] = None, + inputFormat: Option[String] = None, + outputFormat: Option[String] = None, + serde: Option[String] = None) { + + @transient + private[client] var client: ClientInterface = _ + + private[client] def withClient(ci: ClientInterface): this.type = { + client = ci + this + } + + def database: String = specifiedDatabase.getOrElse(sys.error("database not resolved")) + + def isPartitioned: Boolean = partitionColumns.nonEmpty + + def getAllPartitions: Seq[HivePartition] = client.getAllPartitions(this) + + // Hive does not support backticks when passing names to the client. + def qualifiedName: String = s"$database.$name" +} + +/** + * An externally visible interface to the Hive client. This interface is shared across both the + * internal and external classloaders for a given version of Hive and thus must expose only + * shared classes. + */ +trait ClientInterface { + /** + * Runs a HiveQL command using Hive, returning the results as a list of strings. Each row will + * result in one string. + */ + def runSqlHive(sql: String): Seq[String] + + /** Returns the names of all tables in the given database. */ + def listTables(dbName: String): Seq[String] + + /** Returns the name of the active database. */ + def currentDatabase: String + + /** Returns the metadata for specified database, throwing an exception if it doesn't exist */ + def getDatabase(name: String): HiveDatabase = { + getDatabaseOption(name).getOrElse(throw new NoSuchDatabaseException) + } + + /** Returns the metadata for a given database, or None if it doesn't exist. */ + def getDatabaseOption(name: String): Option[HiveDatabase] + + /** Returns the specified table, or throws [[NoSuchTableException]]. */ + def getTable(dbName: String, tableName: String): HiveTable = { + getTableOption(dbName, tableName).getOrElse(throw new NoSuchTableException) + } + + /** Returns the metadata for the specified table or None if it doens't exist. */ + def getTableOption(dbName: String, tableName: String): Option[HiveTable] + + /** Creates a table with the given metadata. */ + def createTable(table: HiveTable): Unit + + /** Updates the given table with new metadata. */ + def alterTable(table: HiveTable): Unit + + /** Creates a new database with the given name. */ + def createDatabase(database: HiveDatabase): Unit + + /** Returns all partitions for the given table. */ + def getAllPartitions(hTable: HiveTable): Seq[HivePartition] + + /** Loads a static partition into an existing table. */ + def loadPartition( + loadPath: String, + tableName: String, + partSpec: java.util.LinkedHashMap[String, String], // Hive relies on LinkedHashMap ordering + replace: Boolean, + holdDDLTime: Boolean, + inheritTableSpecs: Boolean, + isSkewedStoreAsSubdir: Boolean): Unit + + /** Loads data into an existing table. */ + def loadTable( + loadPath: String, // TODO URI + tableName: String, + replace: Boolean, + holdDDLTime: Boolean): Unit + + /** Loads new dynamic partitions into an existing table. */ + def loadDynamicPartitions( + loadPath: String, + tableName: String, + partSpec: java.util.LinkedHashMap[String, String], // Hive relies on LinkedHashMap ordering + replace: Boolean, + numDP: Int, + holdDDLTime: Boolean, + listBucketingEnabled: Boolean): Unit + + /** Used for testing only. Removes all metadata from this instance of Hive. */ + def reset(): Unit +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala new file mode 100644 index 0000000000000..ea52fea037f1f --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -0,0 +1,395 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.client + +import java.io.{BufferedReader, InputStreamReader, File, PrintStream} +import java.net.URI +import java.util.{ArrayList => JArrayList} + +import scala.collection.JavaConversions._ +import scala.language.reflectiveCalls + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.metastore.api.Database +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.metastore.api +import org.apache.hadoop.hive.metastore.api.FieldSchema +import org.apache.hadoop.hive.ql.metadata +import org.apache.hadoop.hive.ql.metadata.Hive +import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.hive.ql.processors._ +import org.apache.hadoop.hive.ql.Driver + +import org.apache.spark.Logging +import org.apache.spark.sql.execution.QueryExecutionException + + +/** + * A class that wraps the HiveClient and converts its responses to externally visible classes. + * Note that this class is typically loaded with an internal classloader for each instantiation, + * allowing it to interact directly with a specific isolated version of Hive. Loading this class + * with the isolated classloader however will result in it only being visible as a ClientInterface, + * not a ClientWrapper. + * + * This class needs to interact with multiple versions of Hive, but will always be compiled with + * the 'native', execution version of Hive. Therefore, any places where hive breaks compatibility + * must use reflection after matching on `version`. + * + * @param version the version of hive used when pick function calls that are not compatible. + * @param config a collection of configuration options that will be added to the hive conf before + * opening the hive client. + */ +class ClientWrapper( + version: HiveVersion, + config: Map[String, String]) + extends ClientInterface + with Logging + with ReflectionMagic { + + private val conf = new HiveConf(classOf[SessionState]) + config.foreach { case (k, v) => + logDebug(s"Hive Config: $k=$v") + conf.set(k, v) + } + + // Circular buffer to hold what hive prints to STDOUT and ERR. Only printed when failures occur. + private val outputBuffer = new java.io.OutputStream { + var pos: Int = 0 + var buffer = new Array[Int](10240) + def write(i: Int): Unit = { + buffer(pos) = i + pos = (pos + 1) % buffer.size + } + + override def toString: String = { + val (end, start) = buffer.splitAt(pos) + val input = new java.io.InputStream { + val iterator = (start ++ end).iterator + + def read(): Int = if (iterator.hasNext) iterator.next() else -1 + } + val reader = new BufferedReader(new InputStreamReader(input)) + val stringBuilder = new StringBuilder + var line = reader.readLine() + while(line != null) { + stringBuilder.append(line) + stringBuilder.append("\n") + line = reader.readLine() + } + stringBuilder.toString() + } + } + + val state = { + val original = Thread.currentThread().getContextClassLoader + Thread.currentThread().setContextClassLoader(getClass.getClassLoader) + val ret = try { + val newState = new SessionState(conf) + SessionState.start(newState) + newState.out = new PrintStream(outputBuffer, true, "UTF-8") + newState.err = new PrintStream(outputBuffer, true, "UTF-8") + newState + } finally { + Thread.currentThread().setContextClassLoader(original) + } + ret + } + + private val client = Hive.get(conf) + + /** + * Runs `f` with ThreadLocal session state and classloaders configured for this version of hive. + */ + private def withHiveState[A](f: => A): A = synchronized { + val original = Thread.currentThread().getContextClassLoader + Thread.currentThread().setContextClassLoader(getClass.getClassLoader) + Hive.set(client) + version match { + case hive.v12 => + classOf[SessionState] + .callStatic[SessionState, SessionState]("start", state) + case hive.v13 => + classOf[SessionState] + .callStatic[SessionState, SessionState]("setCurrentSessionState", state) + } + val ret = try f finally { + Thread.currentThread().setContextClassLoader(original) + } + ret + } + + override def currentDatabase: String = withHiveState { + state.getCurrentDatabase + } + + override def createDatabase(database: HiveDatabase): Unit = withHiveState { + client.createDatabase( + new Database( + database.name, + "", + new File(database.location).toURI.toString, + new java.util.HashMap), + true) + } + + override def getDatabaseOption(name: String): Option[HiveDatabase] = withHiveState { + Option(client.getDatabase(name)).map { d => + HiveDatabase( + name = d.getName, + location = d.getLocationUri) + } + } + + override def getTableOption( + dbName: String, + tableName: String): Option[HiveTable] = withHiveState { + + logDebug(s"Looking up $dbName.$tableName") + + val hiveTable = Option(client.getTable(dbName, tableName, false)) + val converted = hiveTable.map { h => + + HiveTable( + name = h.getTableName, + specifiedDatabase = Option(h.getDbName), + schema = h.getCols.map(f => HiveColumn(f.getName, f.getType, f.getComment)), + partitionColumns = h.getPartCols.map(f => HiveColumn(f.getName, f.getType, f.getComment)), + properties = h.getParameters.toMap, + serdeProperties = h.getTTable.getSd.getSerdeInfo.getParameters.toMap, + tableType = ManagedTable, // TODO + location = version match { + case hive.v12 => Option(h.call[URI]("getDataLocation")).map(_.toString) + case hive.v13 => Option(h.call[Path]("getDataLocation")).map(_.toString) + }, + inputFormat = Option(h.getInputFormatClass).map(_.getName), + outputFormat = Option(h.getOutputFormatClass).map(_.getName), + serde = Option(h.getSerializationLib)).withClient(this) + } + converted + } + + private def toInputFormat(name: String) = + Class.forName(name).asInstanceOf[Class[_ <: org.apache.hadoop.mapred.InputFormat[_, _]]] + + private def toOutputFormat(name: String) = + Class.forName(name) + .asInstanceOf[Class[_ <: org.apache.hadoop.hive.ql.io.HiveOutputFormat[_, _]]] + + private def toQlTable(table: HiveTable): metadata.Table = { + val qlTable = new metadata.Table(table.database, table.name) + + qlTable.setFields(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment))) + qlTable.setPartCols( + table.partitionColumns.map(c => new FieldSchema(c.name, c.hiveType, c.comment))) + table.properties.foreach { case (k, v) => qlTable.setProperty(k, v) } + table.serdeProperties.foreach { case (k, v) => qlTable.setSerdeParam(k, v) } + version match { + case hive.v12 => + table.location.map(new URI(_)).foreach(u => qlTable.call[URI, Unit]("setDataLocation", u)) + case hive.v13 => + table.location + .map(new org.apache.hadoop.fs.Path(_)) + .foreach(qlTable.call[Path, Unit]("setDataLocation", _)) + } + table.inputFormat.map(toInputFormat).foreach(qlTable.setInputFormatClass) + table.outputFormat.map(toOutputFormat).foreach(qlTable.setOutputFormatClass) + table.serde.foreach(qlTable.setSerializationLib) + + qlTable + } + + override def createTable(table: HiveTable): Unit = withHiveState { + val qlTable = toQlTable(table) + client.createTable(qlTable) + } + + override def alterTable(table: HiveTable): Unit = withHiveState { + val qlTable = toQlTable(table) + client.alterTable(table.qualifiedName, qlTable) + } + + override def getAllPartitions(hTable: HiveTable): Seq[HivePartition] = withHiveState { + val qlTable = toQlTable(hTable) + val qlPartitions = version match { + case hive.v12 => + client.call[metadata.Table, Set[metadata.Partition]]("getAllPartitionsForPruner", qlTable) + case hive.v13 => + client.call[metadata.Table, Set[metadata.Partition]]("getAllPartitionsOf", qlTable) + } + qlPartitions.map(_.getTPartition).map { p => + HivePartition( + values = Option(p.getValues).map(_.toSeq).getOrElse(Seq.empty), + storage = HiveStorageDescriptor( + location = p.getSd.getLocation, + inputFormat = p.getSd.getInputFormat, + outputFormat = p.getSd.getOutputFormat, + serde = p.getSd.getSerdeInfo.getSerializationLib)) + }.toSeq + } + + override def listTables(dbName: String): Seq[String] = withHiveState { + client.getAllTables + } + + /** + * Runs the specified SQL query using Hive. + */ + override def runSqlHive(sql: String): Seq[String] = { + val maxResults = 100000 + val results = runHive(sql, maxResults) + // It is very confusing when you only get back some of the results... + if (results.size == maxResults) sys.error("RESULTS POSSIBLY TRUNCATED") + results + } + + /** + * Execute the command using Hive and return the results as a sequence. Each element + * in the sequence is one row. + */ + protected def runHive(cmd: String, maxRows: Int = 1000): Seq[String] = withHiveState { + logDebug(s"Running hiveql '$cmd'") + if (cmd.toLowerCase.startsWith("set")) { logDebug(s"Changing config: $cmd") } + try { + val cmd_trimmed: String = cmd.trim() + val tokens: Array[String] = cmd_trimmed.split("\\s+") + val cmd_1: String = cmd_trimmed.substring(tokens(0).length()).trim() + val proc: CommandProcessor = version match { + case hive.v12 => + classOf[CommandProcessorFactory] + .callStatic[String, HiveConf, CommandProcessor]("get", cmd_1, conf) + case hive.v13 => + classOf[CommandProcessorFactory] + .callStatic[Array[String], HiveConf, CommandProcessor]("get", Array(tokens(0)), conf) + } + + proc match { + case driver: Driver => + val response: CommandProcessorResponse = driver.run(cmd) + // Throw an exception if there is an error in query processing. + if (response.getResponseCode != 0) { + driver.close() + throw new QueryExecutionException(response.getErrorMessage) + } + driver.setMaxRows(maxRows) + + val results = version match { + case hive.v12 => + val res = new JArrayList[String] + driver.call[JArrayList[String], Boolean]("getResults", res) + res.toSeq + case hive.v13 => + val res = new JArrayList[Object] + driver.call[JArrayList[Object], Boolean]("getResults", res) + res.map { r => + r match { + case s: String => s + case a: Array[Object] => a(0).asInstanceOf[String] + } + } + } + driver.close() + results + + case _ => + if (state.out != null) { + state.out.println(tokens(0) + " " + cmd_1) + } + Seq(proc.run(cmd_1).getResponseCode.toString) + } + } catch { + case e: Exception => + logError( + s""" + |====================== + |HIVE FAILURE OUTPUT + |====================== + |${outputBuffer.toString} + |====================== + |END HIVE FAILURE OUTPUT + |====================== + """.stripMargin) + throw e + } + } + + def loadPartition( + loadPath: String, + tableName: String, + partSpec: java.util.LinkedHashMap[String, String], + replace: Boolean, + holdDDLTime: Boolean, + inheritTableSpecs: Boolean, + isSkewedStoreAsSubdir: Boolean): Unit = withHiveState { + + client.loadPartition( + new Path(loadPath), // TODO: Use URI + tableName, + partSpec, + replace, + holdDDLTime, + inheritTableSpecs, + isSkewedStoreAsSubdir) + } + + def loadTable( + loadPath: String, // TODO URI + tableName: String, + replace: Boolean, + holdDDLTime: Boolean): Unit = withHiveState { + client.loadTable( + new Path(loadPath), + tableName, + replace, + holdDDLTime) + } + + def loadDynamicPartitions( + loadPath: String, + tableName: String, + partSpec: java.util.LinkedHashMap[String, String], + replace: Boolean, + numDP: Int, + holdDDLTime: Boolean, + listBucketingEnabled: Boolean): Unit = withHiveState { + client.loadDynamicPartitions( + new Path(loadPath), + tableName, + partSpec, + replace, + numDP, + holdDDLTime, + listBucketingEnabled) + } + + def reset(): Unit = withHiveState { + client.getAllTables("default").foreach { t => + logDebug(s"Deleting table $t") + val table = client.getTable("default", t) + client.getIndexes("default", t, 255).foreach { index => + client.dropIndex("default", t, index.getIndexName, true) + } + if (!table.isIndexTable) { + client.dropTable("default", t) + } + } + client.getAllDatabases.filterNot(_ == "default").foreach { db => + logDebug(s"Dropping Database: $db") + client.dropDatabase(db, true, false, true) + } + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala new file mode 100644 index 0000000000000..710dbca6e3c66 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.client + +import java.io.File +import java.net.URLClassLoader +import java.util + +import scala.language.reflectiveCalls +import scala.util.Try + +import org.apache.commons.io.{FileUtils, IOUtils} + +import org.apache.spark.Logging +import org.apache.spark.deploy.SparkSubmitUtils + +import org.apache.spark.sql.catalyst.util.quietly + +/** Factory for `IsolatedClientLoader` with specific versions of hive. */ +object IsolatedClientLoader { + /** + * Creates isolated Hive client loaders by downloading the requested version from maven. + */ + def forVersion( + version: String, + config: Map[String, String] = Map.empty): IsolatedClientLoader = synchronized { + val resolvedVersion = hiveVersion(version) + val files = resolvedVersions.getOrElseUpdate(resolvedVersion, downloadVersion(resolvedVersion)) + new IsolatedClientLoader(hiveVersion(version), files, config) + } + + def hiveVersion(version: String): HiveVersion = version match { + case "12" | "0.12" | "0.12.0" => hive.v12 + case "13" | "0.13" | "0.13.0" | "0.13.1" => hive.v13 + } + + private def downloadVersion(version: HiveVersion): Seq[File] = { + val hiveArtifacts = + (Seq("hive-metastore", "hive-exec", "hive-common", "hive-serde") ++ + (if (version.hasBuiltinsJar) "hive-builtins" :: Nil else Nil)) + .map(a => s"org.apache.hive:$a:${version.fullVersion}") :+ + "com.google.guava:guava:14.0.1" :+ + "org.apache.hadoop:hadoop-client:2.4.0" :+ + "mysql:mysql-connector-java:5.1.12" + + val classpath = quietly { + SparkSubmitUtils.resolveMavenCoordinates( + hiveArtifacts.mkString(","), + Some("http://www.datanucleus.org/downloads/maven2"), + None) + } + val allFiles = classpath.split(",").map(new File(_)).toSet + + // TODO: Remove copy logic. + val tempDir = File.createTempFile("hive", "v" + version.toString) + tempDir.delete() + tempDir.mkdir() + + allFiles.foreach(f => FileUtils.copyFileToDirectory(f, tempDir)) + tempDir.listFiles() + } + + private def resolvedVersions = new scala.collection.mutable.HashMap[HiveVersion, Seq[File]] +} + +/** + * Creates a Hive `ClientInterface` using a classloader that works according to the following rules: + * - Shared classes: Java, Scala, logging, and Spark classes are delegated to `baseClassLoader` + * allowing the results of calls to the `ClientInterface` to be visible externally. + * - Hive classes: new instances are loaded from `execJars`. These classes are not + * accessible externally due to their custom loading. + * - ClientWrapper: a new copy is created for each instance of `IsolatedClassLoader`. + * This new instance is able to see a specific version of hive without using reflection. Since + * this is a unique instance, it is not visible externally other than as a generic + * `ClientInterface`, unless `isolationOn` is set to `false`. + * + * @param version The version of hive on the classpath. used to pick specific function signatures + * that are not compatibile accross versions. + * @param execJars A collection of jar files that must include hive and hadoop. + * @param config A set of options that will be added to the HiveConf of the constructed client. + * @param isolationOn When true, custom versions of barrier classes will be constructed. Must be + * true unless loading the version of hive that is on Sparks classloader. + * @param rootClassLoader The system root classloader. Must not know about hive classes. + * @param baseClassLoader The spark classloader that is used to load shared classes. + * + */ +class IsolatedClientLoader( + val version: HiveVersion, + val execJars: Seq[File] = Seq.empty, + val config: Map[String, String] = Map.empty, + val isolationOn: Boolean = true, + val rootClassLoader: ClassLoader = ClassLoader.getSystemClassLoader.getParent.getParent, + val baseClassLoader: ClassLoader = Thread.currentThread().getContextClassLoader) + extends Logging { + + // Check to make sure that the root classloader does not know about Hive. + assert(Try(baseClassLoader.loadClass("org.apache.hive.HiveConf")).isFailure) + + /** All jars used by the hive specific classloader. */ + protected def allJars = execJars.map(_.toURI.toURL).toArray + + protected def isSharedClass(name: String): Boolean = + name.contains("slf4j") || + name.contains("log4j") || + name.startsWith("org.apache.spark.") || + name.startsWith("scala.") || + name.startsWith("com.google") || + name.startsWith("java.lang.") || + name.startsWith("java.net") + + /** True if `name` refers to a spark class that must see specific version of Hive. */ + protected def isBarrierClass(name: String): Boolean = + name.startsWith("org.apache.spark.sql.hive.execution.PairSerDe") || + name.startsWith(classOf[ClientWrapper].getName) || + name.startsWith(classOf[ReflectionMagic].getName) + + protected def classToPath(name: String): String = + name.replaceAll("\\.", "/") + ".class" + + /** The classloader that is used to load an isolated version of Hive. */ + protected val classLoader: ClassLoader = new URLClassLoader(allJars, rootClassLoader) { + override def loadClass(name: String, resolve: Boolean): Class[_] = { + val loaded = findLoadedClass(name) + if (loaded == null) doLoadClass(name, resolve) else loaded + } + + def doLoadClass(name: String, resolve: Boolean): Class[_] = { + val classFileName = name.replaceAll("\\.", "/") + ".class" + if (isBarrierClass(name) && isolationOn) { + val bytes = IOUtils.toByteArray(baseClassLoader.getResourceAsStream(classFileName)) + logDebug(s"custom defining: $name - ${util.Arrays.hashCode(bytes)}") + defineClass(name, bytes, 0, bytes.length) + } else if (!isSharedClass(name)) { + logDebug(s"hive class: $name - ${getResource(classToPath(name))}") + super.loadClass(name, resolve) + } else { + logDebug(s"shared class: $name") + baseClassLoader.loadClass(name) + } + } + } + + // Pre-reflective instantiation setup. + logDebug("Initializing the logger to avoid disaster...") + Thread.currentThread.setContextClassLoader(classLoader) + + /** The isolated client interface to Hive. */ + val client: ClientInterface = try { + classLoader + .loadClass(classOf[ClientWrapper].getName) + .getConstructors.head + .newInstance(version, config) + .asInstanceOf[ClientInterface] + } finally { + Thread.currentThread.setContextClassLoader(baseClassLoader) + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ReflectionMagic.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ReflectionMagic.scala new file mode 100644 index 0000000000000..90d03049356b5 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ReflectionMagic.scala @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.client + +import scala.reflect._ + +/** + * Provides implicit functions on any object for calling methods reflectively. + */ +protected trait ReflectionMagic { + /** code for InstanceMagic + println( + (1 to 22).map { n => + def repeat(str: String => String) = (1 to n).map(i => str(i.toString)).mkString(", ") + val types = repeat(n => s"A$n <: AnyRef : ClassTag") + val inArgs = repeat(n => s"a$n: A$n") + val erasure = repeat(n => s"classTag[A$n].erasure") + val outArgs = repeat(n => s"a$n") + s"""|def call[$types, R](name: String, $inArgs): R = { + | clazz.getMethod(name, $erasure).invoke(a, $outArgs).asInstanceOf[R] + |}""".stripMargin + }.mkString("\n") + ) + */ + + // scalastyle:off + protected implicit class InstanceMagic(a: Any) { + private val clazz = a.getClass + + def call[R](name: String): R = { + clazz.getMethod(name).invoke(a).asInstanceOf[R] + } + def call[A1 <: AnyRef : ClassTag, R](name: String, a1: A1): R = { + clazz.getMethod(name, classTag[A1].erasure).invoke(a, a1).asInstanceOf[R] + } + def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2): R = { + clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure).invoke(a, a1, a2).asInstanceOf[R] + } + def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3): R = { + clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure).invoke(a, a1, a2, a3).asInstanceOf[R] + } + def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4): R = { + clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure).invoke(a, a1, a2, a3, a4).asInstanceOf[R] + } + def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5): R = { + clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure).invoke(a, a1, a2, a3, a4, a5).asInstanceOf[R] + } + def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6): R = { + clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure).invoke(a, a1, a2, a3, a4, a5, a6).asInstanceOf[R] + } + def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7): R = { + clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7).asInstanceOf[R] + } + def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8): R = { + clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8).asInstanceOf[R] + } + def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9): R = { + clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9).asInstanceOf[R] + } + def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10): R = { + clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10).asInstanceOf[R] + } + def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11): R = { + clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11).asInstanceOf[R] + } + def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12): R = { + clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12).asInstanceOf[R] + } + def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13): R = { + clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13).asInstanceOf[R] + } + def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14): R = { + clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14).asInstanceOf[R] + } + def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15): R = { + clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15).asInstanceOf[R] + } + def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16): R = { + clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16).asInstanceOf[R] + } + def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17): R = { + clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17).asInstanceOf[R] + } + def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, A18 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17, a18: A18): R = { + clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure, classTag[A18].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18).asInstanceOf[R] + } + def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, A18 <: AnyRef : ClassTag, A19 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17, a18: A18, a19: A19): R = { + clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure, classTag[A18].erasure, classTag[A19].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19).asInstanceOf[R] + } + def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, A18 <: AnyRef : ClassTag, A19 <: AnyRef : ClassTag, A20 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17, a18: A18, a19: A19, a20: A20): R = { + clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure, classTag[A18].erasure, classTag[A19].erasure, classTag[A20].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20).asInstanceOf[R] + } + def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, A18 <: AnyRef : ClassTag, A19 <: AnyRef : ClassTag, A20 <: AnyRef : ClassTag, A21 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17, a18: A18, a19: A19, a20: A20, a21: A21): R = { + clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure, classTag[A18].erasure, classTag[A19].erasure, classTag[A20].erasure, classTag[A21].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20, a21).asInstanceOf[R] + } + def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, A18 <: AnyRef : ClassTag, A19 <: AnyRef : ClassTag, A20 <: AnyRef : ClassTag, A21 <: AnyRef : ClassTag, A22 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17, a18: A18, a19: A19, a20: A20, a21: A21, a22: A22): R = { + clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure, classTag[A18].erasure, classTag[A19].erasure, classTag[A20].erasure, classTag[A21].erasure, classTag[A22].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20, a21, a22).asInstanceOf[R] + } + } + + /** code for StaticMagic + println( + (1 to 22).map { n => + def repeat(str: String => String) = (1 to n).map(i => str(i.toString)).mkString(", ") + val types = repeat(n => s"A$n <: AnyRef : ClassTag") + val inArgs = repeat(n => s"a$n: A$n") + val erasure = repeat(n => s"classTag[A$n].erasure") + val outArgs = repeat(n => s"a$n") + s"""|def callStatic[$types, R](name: String, $inArgs): R = { + | c.getDeclaredMethod(name, $erasure).invoke(c, $outArgs).asInstanceOf[R] + |}""".stripMargin + }.mkString("\n") + ) + */ + + protected implicit class StaticMagic(c: Class[_]) { + def callStatic[A1 <: AnyRef : ClassTag, R](name: String, a1: A1): R = { + c.getDeclaredMethod(name, classTag[A1].erasure).invoke(c, a1).asInstanceOf[R] + } + def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2): R = { + c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure).invoke(c, a1, a2).asInstanceOf[R] + } + def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3): R = { + c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure).invoke(c, a1, a2, a3).asInstanceOf[R] + } + def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4): R = { + c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure).invoke(c, a1, a2, a3, a4).asInstanceOf[R] + } + def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5): R = { + c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure).invoke(c, a1, a2, a3, a4, a5).asInstanceOf[R] + } + def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6): R = { + c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure).invoke(c, a1, a2, a3, a4, a5, a6).asInstanceOf[R] + } + def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7): R = { + c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7).asInstanceOf[R] + } + def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8): R = { + c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8).asInstanceOf[R] + } + def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9): R = { + c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9).asInstanceOf[R] + } + def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10): R = { + c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10).asInstanceOf[R] + } + def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11): R = { + c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11).asInstanceOf[R] + } + def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12): R = { + c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12).asInstanceOf[R] + } + def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13): R = { + c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13).asInstanceOf[R] + } + def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14): R = { + c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14).asInstanceOf[R] + } + def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15): R = { + c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15).asInstanceOf[R] + } + def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16): R = { + c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16).asInstanceOf[R] + } + def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17): R = { + c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17).asInstanceOf[R] + } + def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, A18 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17, a18: A18): R = { + c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure, classTag[A18].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18).asInstanceOf[R] + } + def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, A18 <: AnyRef : ClassTag, A19 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17, a18: A18, a19: A19): R = { + c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure, classTag[A18].erasure, classTag[A19].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19).asInstanceOf[R] + } + def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, A18 <: AnyRef : ClassTag, A19 <: AnyRef : ClassTag, A20 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17, a18: A18, a19: A19, a20: A20): R = { + c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure, classTag[A18].erasure, classTag[A19].erasure, classTag[A20].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20).asInstanceOf[R] + } + def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, A18 <: AnyRef : ClassTag, A19 <: AnyRef : ClassTag, A20 <: AnyRef : ClassTag, A21 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17, a18: A18, a19: A19, a20: A20, a21: A21): R = { + c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure, classTag[A18].erasure, classTag[A19].erasure, classTag[A20].erasure, classTag[A21].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20, a21).asInstanceOf[R] + } + def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, A18 <: AnyRef : ClassTag, A19 <: AnyRef : ClassTag, A20 <: AnyRef : ClassTag, A21 <: AnyRef : ClassTag, A22 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17, a18: A18, a19: A19, a20: A20, a21: A21, a22: A22): R = { + c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure, classTag[A18].erasure, classTag[A19].erasure, classTag[A20].erasure, classTag[A21].erasure, classTag[A22].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20, a21, a22).asInstanceOf[R] + } + } + // scalastyle:on +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala new file mode 100644 index 0000000000000..7db9200d47440 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +/** Support for interacting with different versions of the HiveMetastoreClient */ +package object client { + private[client] abstract class HiveVersion(val fullVersion: String, val hasBuiltinsJar: Boolean) + + // scalastyle:off + private[client] object hive { + case object v10 extends HiveVersion("0.10.0", true) + case object v11 extends HiveVersion("0.11.0", false) + case object v12 extends HiveVersion("0.12.0", false) + case object v13 extends HiveVersion("0.13.1", false) + } + // scalastyle:on + +} \ No newline at end of file diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala new file mode 100644 index 0000000000000..81e77ba257bf1 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.client + +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.util.Utils +import org.scalatest.FunSuite + +class VersionsSuite extends FunSuite with Logging { + val testType = "derby" + + private def buildConf() = { + lazy val warehousePath = Utils.createTempDir() + lazy val metastorePath = Utils.createTempDir() + metastorePath.delete() + Map( + "javax.jdo.option.ConnectionURL" -> s"jdbc:derby:;databaseName=$metastorePath;create=true", + "hive.metastore.warehouse.dir" -> warehousePath.toString) + } + + test("success sanity check") { + val badClient = IsolatedClientLoader.forVersion("13", buildConf()).client + val db = new HiveDatabase("default", "") + badClient.createDatabase(db) + } + + private def getNestedMessages(e: Throwable): String = { + var causes = "" + var lastException = e + while (lastException != null) { + causes += lastException.toString + "\n" + lastException = lastException.getCause + } + causes + } + + // Its actually pretty easy to mess things up and have all of your tests "pass" by accidentally + // connecting to an auto-populated, in-process metastore. Let's make sure we are getting the + // versions right by forcing a known compatibility failure. + // TODO: currently only works on mysql where we manually create the schema... + ignore("failure sanity check") { + val e = intercept[Throwable] { + val badClient = quietly { IsolatedClientLoader.forVersion("13", buildConf()).client } + } + assert(getNestedMessages(e) contains "Unknown column 'A0.OWNER_NAME' in 'field list'") + } + + private val versions = Seq("12", "13") + + private var client: ClientInterface = null + + versions.foreach { version => + test(s"$version: listTables") { + client = null + client = IsolatedClientLoader.forVersion(version, buildConf()).client + client.listTables("default") + } + + test(s"$version: createDatabase") { + val db = HiveDatabase("default", "") + client.createDatabase(db) + } + + test(s"$version: createTable") { + val table = + HiveTable( + specifiedDatabase = Option("default"), + name = "src", + schema = Seq(HiveColumn("key", "int", "")), + partitionColumns = Seq.empty, + properties = Map.empty, + serdeProperties = Map.empty, + tableType = ManagedTable, + location = None, + inputFormat = + Some(classOf[org.apache.hadoop.mapred.TextInputFormat].getName), + outputFormat = + Some(classOf[org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat[_, _]].getName), + serde = + Some(classOf[org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe].getName())) + + client.createTable(table) + } + + test(s"$version: getTable") { + client.getTable("default", "src") + } + } +} From 9e25b09f8809378777ae8bbe75dca12d2c45ff4c Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sun, 3 May 2015 21:22:31 +0100 Subject: [PATCH 37/91] [SPARK-7302] [DOCS] SPARK building documentation still mentions building for yarn 0.23 Remove references to Hadoop 0.23 CC tgravescs Is this what you had in mind? basically all refs to 0.23? We don't support YARN 0.23, but also don't support Hadoop 0.23 anymore AFAICT. There are no builds or releases for it. In fact, on a related note, refs to CDH3 (Hadoop 0.20.2) should be removed as this certainly isn't supported either. Author: Sean Owen Closes #5863 from srowen/SPARK-7302 and squashes the following commits: 42f5d1e [Sean Owen] Remove CDH3 (Hadoop 0.20.2) refs too dad02e3 [Sean Owen] Remove references to Hadoop 0.23 --- docs/building-spark.md | 4 ---- docs/hadoop-third-party-distributions.md | 3 --- make-distribution.sh | 2 +- pom.xml | 14 -------------- 4 files changed, 1 insertion(+), 22 deletions(-) diff --git a/docs/building-spark.md b/docs/building-spark.md index ea79c5bc276d3..287fcd3c4034f 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -66,7 +66,6 @@ Because HDFS is not protocol-compatible across versions, if you want to read fro Hadoop versionProfile required - 0.23.xhadoop-0.23 1.x to 2.1.x(none) 2.2.xhadoop-2.2 2.3.xhadoop-2.3 @@ -82,9 +81,6 @@ mvn -Dhadoop.version=1.2.1 -DskipTests clean package # Cloudera CDH 4.2.0 with MapReduce v1 mvn -Dhadoop.version=2.0.0-mr1-cdh4.2.0 -DskipTests clean package - -# Apache Hadoop 0.23.x -mvn -Phadoop-0.23 -Dhadoop.version=0.23.7 -DskipTests clean package {% endhighlight %} You can enable the "yarn" profile and optionally set the "yarn.version" property if it is different from "hadoop.version". Spark only supports YARN versions 2.2.0 and later. diff --git a/docs/hadoop-third-party-distributions.md b/docs/hadoop-third-party-distributions.md index 87dcc58feb494..96bd69ca3b33b 100644 --- a/docs/hadoop-third-party-distributions.md +++ b/docs/hadoop-third-party-distributions.md @@ -29,9 +29,6 @@ the _exact_ Hadoop version you are running to avoid any compatibility errors. ReleaseVersion code CDH 4.X.X (YARN mode)2.0.0-cdh4.X.X CDH 4.X.X2.0.0-mr1-cdh4.X.X - CDH 3u60.20.2-cdh3u6 - CDH 3u50.20.2-cdh3u5 - CDH 3u40.20.2-cdh3u4 diff --git a/make-distribution.sh b/make-distribution.sh index 92177e19fe6be..1bfa9acb1fe6e 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -58,7 +58,7 @@ while (( "$#" )); do --hadoop) echo "Error: '--hadoop' is no longer supported:" echo "Error: use Maven profiles and options -Dhadoop.version and -Dyarn.version instead." - echo "Error: Related profiles include hadoop-0.23, hdaoop-2.2, hadoop-2.3 and hadoop-2.4." + echo "Error: Related profiles include hadoop-2.2, hadoop-2.3 and hadoop-2.4." exit_with_usage ;; --with-yarn) diff --git a/pom.xml b/pom.xml index 4313f940036c8..de18741feae3a 100644 --- a/pom.xml +++ b/pom.xml @@ -1614,20 +1614,6 @@ http://hadoop.apache.org/docs/ra.b.c/hadoop-project-dist/hadoop-common/dependency-analysis.html --> - - hadoop-0.23 - - - - org.apache.avro - avro - - - - 0.23.10 - - - hadoop-2.2 From 1ffa8cb91f8badf12a8aa190dc25920715a00db7 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Sun, 3 May 2015 18:06:48 -0700 Subject: [PATCH 38/91] [SPARK-7329] [MLLIB] simplify ParamGridBuilder impl as suggested by justinuang on #5601. Author: Xiangrui Meng Closes #5873 from mengxr/SPARK-7329 and squashes the following commits: d08f9cf [Xiangrui Meng] simplify tests b7a7b9b [Xiangrui Meng] simplify grid build --- python/pyspark/ml/tuning.py | 28 +++++++++------------------- 1 file changed, 9 insertions(+), 19 deletions(-) diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index a383bd0c0d26f..1773ab5bdcdb1 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -15,6 +15,8 @@ # limitations under the License. # +import itertools + __all__ = ['ParamGridBuilder'] @@ -37,14 +39,10 @@ class ParamGridBuilder(object): {lr.regParam: 1.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}, \ {lr.regParam: 2.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}, \ {lr.regParam: 3.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}] - >>> fail_count = 0 - >>> for e in expected: - ... if e not in output: - ... fail_count += 1 - >>> if len(expected) != len(output): - ... fail_count += 1 - >>> fail_count - 0 + >>> len(output) == len(expected) + True + >>> all([m in expected for m in output]) + True """ def __init__(self): @@ -76,17 +74,9 @@ def build(self): Builds and returns all combinations of parameters specified by the param grid. """ - param_maps = [{}] - for (param, values) in self._param_grid.items(): - new_param_maps = [] - for value in values: - for old_map in param_maps: - copied_map = old_map.copy() - copied_map[param] = value - new_param_maps.append(copied_map) - param_maps = new_param_maps - - return param_maps + keys = self._param_grid.keys() + grid_values = self._param_grid.values() + return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)] if __name__ == "__main__": From 9646018bb4466433521b4e602b808f16e8d0ffdb Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Sun, 3 May 2015 21:44:39 -0700 Subject: [PATCH 39/91] [SPARK-7241] Pearson correlation for DataFrames submitting this PR from a phone, excuse the brevity. adds Pearson correlation to Dataframes, reusing the covariance calculation code cc mengxr rxin Author: Burak Yavuz Closes #5858 from brkyvz/df-corr and squashes the following commits: 285b838 [Burak Yavuz] addressed comments v2.0 d10babb [Burak Yavuz] addressed comments v0.2 4b74b24 [Burak Yavuz] Merge branch 'master' of github.com:apache/spark into df-corr 4fe693b [Burak Yavuz] addressed comments v0.1 a682d06 [Burak Yavuz] ready for PR --- python/pyspark/sql/dataframe.py | 26 +++++++++ python/pyspark/sql/tests.py | 6 ++ .../spark/sql/DataFrameStatFunctions.scala | 26 +++++++++ .../sql/execution/stat/StatFunctions.scala | 58 ++++++++++++------- .../apache/spark/sql/JavaDataFrameSuite.java | 7 +++ .../apache/spark/sql/DataFrameStatSuite.scala | 33 +++++++++-- 6 files changed, 130 insertions(+), 26 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 8ddcff8fcdf98..aac5b8c4c5770 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -875,6 +875,27 @@ def fillna(self, value, subset=None): return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx) + def corr(self, col1, col2, method=None): + """ + Calculates the correlation of two columns of a DataFrame as a double value. Currently only + supports the Pearson Correlation Coefficient. + :func:`DataFrame.corr` and :func:`DataFrameStatFunctions.corr` are aliases. + + :param col1: The name of the first column + :param col2: The name of the second column + :param method: The correlation method. Currently only supports "pearson" + """ + if not isinstance(col1, str): + raise ValueError("col1 should be a string.") + if not isinstance(col2, str): + raise ValueError("col2 should be a string.") + if not method: + method = "pearson" + if not method == "pearson": + raise ValueError("Currently only the calculation of the Pearson Correlation " + + "coefficient is supported.") + return self._jdf.stat().corr(col1, col2, method) + def cov(self, col1, col2): """ Calculate the sample covariance for the given columns, specified by their names, as a @@ -1359,6 +1380,11 @@ class DataFrameStatFunctions(object): def __init__(self, df): self.df = df + def corr(self, col1, col2, method=None): + return self.df.corr(col1, col2, method) + + corr.__doc__ = DataFrame.corr.__doc__ + def cov(self, col1, col2): return self.df.cov(col1, col2) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 613efc0ac029d..d652c302a54ba 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -394,6 +394,12 @@ def test_aggregator(self): self.assertTrue(95 < g.agg(functions.approxCountDistinct(df.key)).first()[0]) self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0]) + def test_corr(self): + import math + df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF() + corr = df.stat.corr("a", "b") + self.assertTrue(abs(corr - 0.95734012) < 1e-6) + def test_cov(self): df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF() cov = df.stat.cov("a", "b") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index e8fa82947759b..903532105284e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -27,6 +27,32 @@ import org.apache.spark.sql.execution.stat._ @Experimental final class DataFrameStatFunctions private[sql](df: DataFrame) { + /** + * Calculates the correlation of two columns of a DataFrame. Currently only supports the Pearson + * Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in + * MLlib's Statistics. + * + * @param col1 the name of the column + * @param col2 the name of the column to calculate the correlation against + * @return The Pearson Correlation Coefficient as a Double. + */ + def corr(col1: String, col2: String, method: String): Double = { + require(method == "pearson", "Currently only the calculation of the Pearson Correlation " + + "coefficient is supported.") + StatFunctions.pearsonCorrelation(df, Seq(col1, col2)) + } + + /** + * Calculates the Pearson Correlation Coefficient of two columns of a DataFrame. + * + * @param col1 the name of the column + * @param col2 the name of the column to calculate the correlation against + * @return The Pearson Correlation Coefficient as a Double. + */ + def corr(col1: String, col2: String): Double = { + corr(col1, col2, "pearson") + } + /** * Finding frequent items for columns, possibly with false positives. Using the * frequent element count algorithm described in diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index d4a94c24d9866..67b48e58b17ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -23,29 +23,43 @@ import org.apache.spark.sql.types.{DoubleType, NumericType} private[sql] object StatFunctions { + /** Calculate the Pearson Correlation Coefficient for the given columns */ + private[sql] def pearsonCorrelation(df: DataFrame, cols: Seq[String]): Double = { + val counts = collectStatisticalData(df, cols) + counts.Ck / math.sqrt(counts.MkX * counts.MkY) + } + /** Helper class to simplify tracking and merging counts. */ private class CovarianceCounter extends Serializable { - var xAvg = 0.0 - var yAvg = 0.0 - var Ck = 0.0 - var count = 0L + var xAvg = 0.0 // the mean of all examples seen so far in col1 + var yAvg = 0.0 // the mean of all examples seen so far in col2 + var Ck = 0.0 // the co-moment after k examples + var MkX = 0.0 // sum of squares of differences from the (current) mean for col1 + var MkY = 0.0 // sum of squares of differences from the (current) mean for col1 + var count = 0L // count of observed examples // add an example to the calculation def add(x: Double, y: Double): this.type = { - val oldX = xAvg + val deltaX = x - xAvg + val deltaY = y - yAvg count += 1 - xAvg += (x - xAvg) / count - yAvg += (y - yAvg) / count - Ck += (y - yAvg) * (x - oldX) + xAvg += deltaX / count + yAvg += deltaY / count + Ck += deltaX * (y - yAvg) + MkX += deltaX * (x - xAvg) + MkY += deltaY * (y - yAvg) this } // merge counters from other partitions. Formula can be found at: - // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Covariance + // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance def merge(other: CovarianceCounter): this.type = { val totalCount = count + other.count - Ck += other.Ck + - (xAvg - other.xAvg) * (yAvg - other.yAvg) * count / totalCount * other.count + val deltaX = xAvg - other.xAvg + val deltaY = yAvg - other.yAvg + Ck += other.Ck + deltaX * deltaY * count / totalCount * other.count xAvg = (xAvg * count + other.xAvg * other.count) / totalCount yAvg = (yAvg * count + other.yAvg * other.count) / totalCount + MkX += other.MkX + deltaX * deltaX * count / totalCount * other.count + MkY += other.MkY + deltaY * deltaY * count / totalCount * other.count count = totalCount this } @@ -53,13 +67,7 @@ private[sql] object StatFunctions { def cov: Double = Ck / (count - 1) } - /** - * Calculate the covariance of two numerical columns of a DataFrame. - * @param df The DataFrame - * @param cols the column names - * @return the covariance of the two columns. - */ - private[sql] def calculateCov(df: DataFrame, cols: Seq[String]): Double = { + private def collectStatisticalData(df: DataFrame, cols: Seq[String]): CovarianceCounter = { require(cols.length == 2, "Currently cov supports calculating the covariance " + "between two columns.") cols.map(name => (name, df.schema.fields.find(_.name == name))).foreach { case (name, data) => @@ -68,13 +76,23 @@ private[sql] object StatFunctions { s"with dataType ${data.get.dataType} not supported.") } val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType))) - val counts = df.select(columns:_*).rdd.aggregate(new CovarianceCounter)( + df.select(columns: _*).rdd.aggregate(new CovarianceCounter)( seqOp = (counter, row) => { counter.add(row.getDouble(0), row.getDouble(1)) }, combOp = (baseCounter, other) => { baseCounter.merge(other) - }) + }) + } + + /** + * Calculate the covariance of two numerical columns of a DataFrame. + * @param df The DataFrame + * @param cols the column names + * @return the covariance of the two columns. + */ + private[sql] def calculateCov(df: DataFrame, cols: Seq[String]): Double = { + val counts = collectStatisticalData(df, cols) counts.cov } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 96fe66d0b84a6..78e847239f405 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -187,6 +187,13 @@ public void testFrequentItems() { Assert.assertTrue(results.collect()[0].getSeq(0).contains(1)); } + @Test + public void testCorrelation() { + DataFrame df = context.table("testData2"); + Double pearsonCorr = df.stat().corr("a", "b", "pearson"); + Assert.assertTrue(Math.abs(pearsonCorr) < 1e-6); + } + @Test public void testCovariance() { DataFrame df = context.table("testData2"); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 4f5a2ff696789..06764d2a122f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -30,10 +30,10 @@ class DataFrameStatSuite extends FunSuite { def toLetter(i: Int): String = (i + 97).toChar.toString test("Frequent Items") { - val rows = Array.tabulate(1000) { i => + val rows = Seq.tabulate(1000) { i => if (i % 3 == 0) (1, toLetter(1), -1.0) else (i, toLetter(i), i * -1.0) } - val df = sqlCtx.sparkContext.parallelize(rows).toDF("numbers", "letters", "negDoubles") + val df = rows.toDF("numbers", "letters", "negDoubles") val results = df.stat.freqItems(Array("numbers", "letters"), 0.1) val items = results.collect().head @@ -43,19 +43,40 @@ class DataFrameStatSuite extends FunSuite { val singleColResults = df.stat.freqItems(Array("negDoubles"), 0.1) val items2 = singleColResults.collect().head items2.getSeq[Double](0) should contain (-1.0) + } + test("pearson correlation") { + val df = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c") + val corr1 = df.stat.corr("a", "b", "pearson") + assert(math.abs(corr1 - 1.0) < 1e-12) + val corr2 = df.stat.corr("a", "c", "pearson") + assert(math.abs(corr2 + 1.0) < 1e-12) + // non-trivial example. To reproduce in python, use: + // >>> from scipy.stats import pearsonr + // >>> import numpy as np + // >>> a = np.array(range(20)) + // >>> b = np.array([x * x - 2 * x + 3.5 for x in range(20)]) + // >>> pearsonr(a, b) + // (0.95723391394758572, 3.8902121417802199e-11) + // In R, use: + // > a <- 0:19 + // > b <- mapply(function(x) x * x - 2 * x + 3.5, a) + // > cor(a, b) + // [1] 0.957233913947585835 + val df2 = Seq.tabulate(20)(x => (x, x * x - 2 * x + 3.5)).toDF("a", "b") + val corr3 = df2.stat.corr("a", "b", "pearson") + assert(math.abs(corr3 - 0.95723391394758572) < 1e-12) } test("covariance") { - val rows = Array.tabulate(10)(i => (i, 2.0 * i, toLetter(i))) - val df = sqlCtx.sparkContext.parallelize(rows).toDF("singles", "doubles", "letters") + val df = Seq.tabulate(10)(i => (i, 2.0 * i, toLetter(i))).toDF("singles", "doubles", "letters") val results = df.stat.cov("singles", "doubles") - assert(math.abs(results - 55.0 / 3) < 1e-6) + assert(math.abs(results - 55.0 / 3) < 1e-12) intercept[IllegalArgumentException] { df.stat.cov("singles", "letters") // doesn't accept non-numerical dataTypes } val decimalRes = decimalData.stat.cov("a", "b") - assert(math.abs(decimalRes) < 1e-6) + assert(math.abs(decimalRes) < 1e-12) } } From 3539cb7d20f5f878132407ec3b854011b183b2ad Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Mon, 4 May 2015 00:06:25 -0700 Subject: [PATCH 40/91] [SPARK-5563] [MLLIB] LDA with online variational inference MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit JIRA: https://issues.apache.org/jira/browse/SPARK-5563 The PR contains the implementation for [Online LDA] (https://www.cs.princeton.edu/~blei/papers/HoffmanBleiBach2010b.pdf) based on the research of Matt Hoffman and David M. Blei, which provides an efficient option for LDA users. Major advantages for the algorithm are the stream compatibility and economic time/memory consumption due to the corpus split. For more details, please refer to the jira. Online LDA can act as a fast option for LDA, and will be especially helpful for the users who needs a quick result or with large corpus. Correctness test. I have tested current PR with https://github.com/Blei-Lab/onlineldavb and the results are identical. I've uploaded the result and code to https://github.com/hhbyyh/LDACrossValidation. Author: Yuhao Yang Author: Joseph K. Bradley Closes #4419 from hhbyyh/ldaonline and squashes the following commits: 1045eec [Yuhao Yang] Merge pull request #2 from jkbradley/hhbyyh-ldaonline2 cf376ff [Joseph K. Bradley] For private vars needed for testing, I made them private and added accessors. Java doesn’t understand package-private tags, so this minimizes the issues Java users might encounter. 6149ca6 [Yuhao Yang] fix for setOptimizer cf0007d [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaonline 54cf8da [Yuhao Yang] some style change 68c2318 [Yuhao Yang] add a java ut 4041723 [Yuhao Yang] add ut 138bfed [Yuhao Yang] Merge pull request #1 from jkbradley/hhbyyh-ldaonline-update 9e910d9 [Joseph K. Bradley] small fix 61d60df [Joseph K. Bradley] Minor cleanups: * Update *Concentration parameter documentation * EM Optimizer: createVertices() does not need to be a function * OnlineLDAOptimizer: typos in doc * Clean up the core code for online LDA (Scala style) a996a82 [Yuhao Yang] respond to comments b1178cf [Yuhao Yang] fit into the optimizer framework dbe3cff [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaonline 15be071 [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaonline b29193b [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaonline d19ef55 [Yuhao Yang] change OnlineLDA to class 97b9e1a [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaonline e7bf3b0 [Yuhao Yang] move to seperate file f367cc9 [Yuhao Yang] change to optimization 8cb16a6 [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaonline 62405cc [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaonline 02d0373 [Yuhao Yang] fix style in comment f6d47ca [Yuhao Yang] Merge branch 'ldaonline' of https://github.com/hhbyyh/spark into ldaonline d86cdec [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaonline a570c9a [Yuhao Yang] use sample to pick up batch 4a3f27e [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaonline e271eb1 [Yuhao Yang] remove non ascii 581c623 [Yuhao Yang] seperate API and adjust batch split 37af91a [Yuhao Yang] iMerge remote-tracking branch 'upstream/master' into ldaonline 20328d1 [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaonline i aa365d1 [Yuhao Yang] merge upstream master 3a06526 [Yuhao Yang] merge with new example 0dd3947 [Yuhao Yang] kMerge remote-tracking branch 'upstream/master' into ldaonline 0d0f3ee [Yuhao Yang] replace random split with sliding fa408a8 [Yuhao Yang] ssMerge remote-tracking branch 'upstream/master' into ldaonline 45884ab [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaonline s f41c5ca [Yuhao Yang] style fix 26dca1b [Yuhao Yang] style fix and make class private 043e786 [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaonline s Conflicts: mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala d640d9c [Yuhao Yang] online lda initial checkin --- .../apache/spark/mllib/clustering/LDA.scala | 65 ++-- .../spark/mllib/clustering/LDAOptimizer.scala | 320 ++++++++++++++++-- .../spark/mllib/clustering/JavaLDASuite.java | 38 ++- .../spark/mllib/clustering/LDASuite.scala | 89 ++++- 4 files changed, 438 insertions(+), 74 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index 37bf88b73b911..c8daa2388e868 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -78,35 +78,29 @@ class LDA private ( * * This is the parameter to a symmetric Dirichlet distribution. */ - def getDocConcentration: Double = { - if (this.docConcentration == -1) { - (50.0 / k) + 1.0 - } else { - this.docConcentration - } - } + def getDocConcentration: Double = this.docConcentration /** * Concentration parameter (commonly named "alpha") for the prior placed on documents' * distributions over topics ("theta"). * - * This is the parameter to a symmetric Dirichlet distribution. + * This is the parameter to a symmetric Dirichlet distribution, where larger values + * mean more smoothing (more regularization). * - * This value should be > 1.0, where larger values mean more smoothing (more regularization). * If set to -1, then docConcentration is set automatically. * (default = -1 = automatic) * - * Automatic setting of parameter: - * - For EM: default = (50 / k) + 1. - * - The 50/k is common in LDA libraries. - * - The +1 follows Asuncion et al. (2009), who recommend a +1 adjustment for EM. - * - * Note: The restriction > 1.0 may be relaxed in the future (allowing sparse solutions), - * but values in (0,1) are not yet supported. + * Optimizer-specific parameter settings: + * - EM + * - Value should be > 1.0 + * - default = (50 / k) + 1, where 50/k is common in LDA libraries and +1 follows + * Asuncion et al. (2009), who recommend a +1 adjustment for EM. + * - Online + * - Value should be >= 0 + * - default = (1.0 / k), following the implementation from + * [[https://github.com/Blei-Lab/onlineldavb]]. */ def setDocConcentration(docConcentration: Double): this.type = { - require(docConcentration > 1.0 || docConcentration == -1.0, - s"LDA docConcentration must be > 1.0 (or -1 for auto), but was set to $docConcentration") this.docConcentration = docConcentration this } @@ -126,13 +120,7 @@ class LDA private ( * Note: The topics' distributions over terms are called "beta" in the original LDA paper * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009. */ - def getTopicConcentration: Double = { - if (this.topicConcentration == -1) { - 1.1 - } else { - this.topicConcentration - } - } + def getTopicConcentration: Double = this.topicConcentration /** * Concentration parameter (commonly named "beta" or "eta") for the prior placed on topics' @@ -143,21 +131,20 @@ class LDA private ( * Note: The topics' distributions over terms are called "beta" in the original LDA paper * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009. * - * This value should be > 0.0. * If set to -1, then topicConcentration is set automatically. * (default = -1 = automatic) * - * Automatic setting of parameter: - * - For EM: default = 0.1 + 1. - * - The 0.1 gives a small amount of smoothing. - * - The +1 follows Asuncion et al. (2009), who recommend a +1 adjustment for EM. - * - * Note: The restriction > 1.0 may be relaxed in the future (allowing sparse solutions), - * but values in (0,1) are not yet supported. + * Optimizer-specific parameter settings: + * - EM + * - Value should be > 1.0 + * - default = 0.1 + 1, where 0.1 gives a small amount of smoothing and +1 follows + * Asuncion et al. (2009), who recommend a +1 adjustment for EM. + * - Online + * - Value should be >= 0 + * - default = (1.0 / k), following the implementation from + * [[https://github.com/Blei-Lab/onlineldavb]]. */ def setTopicConcentration(topicConcentration: Double): this.type = { - require(topicConcentration > 1.0 || topicConcentration == -1.0, - s"LDA topicConcentration must be > 1.0 (or -1 for auto), but was set to $topicConcentration") this.topicConcentration = topicConcentration this } @@ -223,14 +210,15 @@ class LDA private ( /** * Set the LDAOptimizer used to perform the actual calculation by algorithm name. - * Currently "em" is supported. + * Currently "em", "online" is supported. */ def setOptimizer(optimizerName: String): this.type = { this.ldaOptimizer = optimizerName.toLowerCase match { case "em" => new EMLDAOptimizer + case "online" => new OnlineLDAOptimizer case other => - throw new IllegalArgumentException(s"Only em is supported but got $other.") + throw new IllegalArgumentException(s"Only em, online are supported but got $other.") } this } @@ -245,8 +233,7 @@ class LDA private ( * @return Inferred LDA model */ def run(documents: RDD[(Long, Vector)]): LDAModel = { - val state = ldaOptimizer.initialState(documents, k, getDocConcentration, getTopicConcentration, - seed, checkpointInterval) + val state = ldaOptimizer.initialize(documents, this) var iter = 0 val iterationTimes = Array.fill[Double](maxIterations)(0) while (iter < maxIterations) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index ffd72a294c6c6..093aa0f315ab2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -19,13 +19,15 @@ package org.apache.spark.mllib.clustering import java.util.Random -import breeze.linalg.{DenseVector => BDV, normalize} +import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, sum, normalize, kron} +import breeze.numerics.{digamma, exp, abs} +import breeze.stats.distributions.{Gamma, RandBasis} import org.apache.spark.annotation.Experimental import org.apache.spark.graphx._ import org.apache.spark.graphx.impl.GraphImpl import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.linalg.{Matrices, SparseVector, DenseVector, Vector} import org.apache.spark.rdd.RDD /** @@ -35,7 +37,7 @@ import org.apache.spark.rdd.RDD * hold optimizer-specific parameters for users to set. */ @Experimental -trait LDAOptimizer{ +trait LDAOptimizer { /* DEVELOPERS NOTE: @@ -49,13 +51,7 @@ trait LDAOptimizer{ * Initializer for the optimizer. LDA passes the common parameters to the optimizer and * the internal structure can be initialized properly. */ - private[clustering] def initialState( - docs: RDD[(Long, Vector)], - k: Int, - docConcentration: Double, - topicConcentration: Double, - randomSeed: Long, - checkpointInterval: Int): LDAOptimizer + private[clustering] def initialize(docs: RDD[(Long, Vector)], lda: LDA): LDAOptimizer private[clustering] def next(): LDAOptimizer @@ -80,12 +76,12 @@ trait LDAOptimizer{ * */ @Experimental -class EMLDAOptimizer extends LDAOptimizer{ +class EMLDAOptimizer extends LDAOptimizer { import LDA._ /** - * Following fields will only be initialized through initialState method + * The following fields will only be initialized through the initialize() method */ private[clustering] var graph: Graph[TopicCounts, TokenCount] = null private[clustering] var k: Int = 0 @@ -98,13 +94,23 @@ class EMLDAOptimizer extends LDAOptimizer{ /** * Compute bipartite term/doc graph. */ - private[clustering] override def initialState( - docs: RDD[(Long, Vector)], - k: Int, - docConcentration: Double, - topicConcentration: Double, - randomSeed: Long, - checkpointInterval: Int): LDAOptimizer = { + override private[clustering] def initialize(docs: RDD[(Long, Vector)], lda: LDA): LDAOptimizer = { + + val docConcentration = lda.getDocConcentration + val topicConcentration = lda.getTopicConcentration + val k = lda.getK + + // Note: The restriction > 1.0 may be relaxed in the future (allowing sparse solutions), + // but values in (0,1) are not yet supported. + require(docConcentration > 1.0 || docConcentration == -1.0, s"LDA docConcentration must be" + + s" > 1.0 (or -1 for auto) for EM Optimizer, but was set to $docConcentration") + require(topicConcentration > 1.0 || topicConcentration == -1.0, s"LDA topicConcentration " + + s"must be > 1.0 (or -1 for auto) for EM Optimizer, but was set to $topicConcentration") + + this.docConcentration = if (docConcentration == -1) (50.0 / k) + 1.0 else docConcentration + this.topicConcentration = if (topicConcentration == -1) 1.1 else topicConcentration + val randomSeed = lda.getSeed + // For each document, create an edge (Document -> Term) for each unique term in the document. val edges: RDD[Edge[TokenCount]] = docs.flatMap { case (docID: Long, termCounts: Vector) => // Add edges for terms with non-zero counts. @@ -113,11 +119,9 @@ class EMLDAOptimizer extends LDAOptimizer{ } } - val vocabSize = docs.take(1).head._2.size - // Create vertices. // Initially, we use random soft assignments of tokens to topics (random gamma). - def createVertices(): RDD[(VertexId, TopicCounts)] = { + val docTermVertices: RDD[(VertexId, TopicCounts)] = { val verticesTMP: RDD[(VertexId, TopicCounts)] = edges.mapPartitionsWithIndex { case (partIndex, partEdges) => val random = new Random(partIndex + randomSeed) @@ -130,22 +134,18 @@ class EMLDAOptimizer extends LDAOptimizer{ verticesTMP.reduceByKey(_ + _) } - val docTermVertices = createVertices() - // Partition such that edges are grouped by document this.graph = Graph(docTermVertices, edges).partitionBy(PartitionStrategy.EdgePartition1D) this.k = k - this.vocabSize = vocabSize - this.docConcentration = docConcentration - this.topicConcentration = topicConcentration - this.checkpointInterval = checkpointInterval + this.vocabSize = docs.take(1).head._2.size + this.checkpointInterval = lda.getCheckpointInterval this.graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount](graph, checkpointInterval) this.globalTopicTotals = computeGlobalTopicTotals() this } - private[clustering] override def next(): EMLDAOptimizer = { + override private[clustering] def next(): EMLDAOptimizer = { require(graph != null, "graph is null, EMLDAOptimizer not initialized.") val eta = topicConcentration @@ -202,9 +202,269 @@ class EMLDAOptimizer extends LDAOptimizer{ graph.vertices.filter(isTermVertex).values.fold(BDV.zeros[Double](numTopics))(_ += _) } - private[clustering] override def getLDAModel(iterationTimes: Array[Double]): LDAModel = { + override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = { require(graph != null, "graph is null, EMLDAOptimizer not initialized.") this.graphCheckpointer.deleteAllCheckpoints() new DistributedLDAModel(this, iterationTimes) } } + + +/** + * :: Experimental :: + * + * An online optimizer for LDA. The Optimizer implements the Online variational Bayes LDA + * algorithm, which processes a subset of the corpus on each iteration, and updates the term-topic + * distribution adaptively. + * + * Original Online LDA paper: + * Hoffman, Blei and Bach, "Online Learning for Latent Dirichlet Allocation." NIPS, 2010. + */ +@Experimental +class OnlineLDAOptimizer extends LDAOptimizer { + + // LDA common parameters + private var k: Int = 0 + private var corpusSize: Long = 0 + private var vocabSize: Int = 0 + + /** alias for docConcentration */ + private var alpha: Double = 0 + + /** (private[clustering] for debugging) Get docConcentration */ + private[clustering] def getAlpha: Double = alpha + + /** alias for topicConcentration */ + private var eta: Double = 0 + + /** (private[clustering] for debugging) Get topicConcentration */ + private[clustering] def getEta: Double = eta + + private var randomGenerator: java.util.Random = null + + // Online LDA specific parameters + // Learning rate is: (tau_0 + t)^{-kappa} + private var tau_0: Double = 1024 + private var kappa: Double = 0.51 + private var miniBatchFraction: Double = 0.05 + + // internal data structure + private var docs: RDD[(Long, Vector)] = null + + /** Dirichlet parameter for the posterior over topics */ + private var lambda: BDM[Double] = null + + /** (private[clustering] for debugging) Get parameter for topics */ + private[clustering] def getLambda: BDM[Double] = lambda + + /** Current iteration (count of invocations of [[next()]]) */ + private var iteration: Int = 0 + private var gammaShape: Double = 100 + + /** + * A (positive) learning parameter that downweights early iterations. Larger values make early + * iterations count less. + */ + def getTau_0: Double = this.tau_0 + + /** + * A (positive) learning parameter that downweights early iterations. Larger values make early + * iterations count less. + * Default: 1024, following the original Online LDA paper. + */ + def setTau_0(tau_0: Double): this.type = { + require(tau_0 > 0, s"LDA tau_0 must be positive, but was set to $tau_0") + this.tau_0 = tau_0 + this + } + + /** + * Learning rate: exponential decay rate + */ + def getKappa: Double = this.kappa + + /** + * Learning rate: exponential decay rate---should be between + * (0.5, 1.0] to guarantee asymptotic convergence. + * Default: 0.51, based on the original Online LDA paper. + */ + def setKappa(kappa: Double): this.type = { + require(kappa >= 0, s"Online LDA kappa must be nonnegative, but was set to $kappa") + this.kappa = kappa + this + } + + /** + * Mini-batch fraction, which sets the fraction of document sampled and used in each iteration + */ + def getMiniBatchFraction: Double = this.miniBatchFraction + + /** + * Mini-batch fraction in (0, 1], which sets the fraction of document sampled and used in + * each iteration. + * + * Note that this should be adjusted in synch with [[LDA.setMaxIterations()]] + * so the entire corpus is used. Specifically, set both so that + * maxIterations * miniBatchFraction >= 1. + * + * Default: 0.05, i.e., 5% of total documents. + */ + def setMiniBatchFraction(miniBatchFraction: Double): this.type = { + require(miniBatchFraction > 0.0 && miniBatchFraction <= 1.0, + s"Online LDA miniBatchFraction must be in range (0,1], but was set to $miniBatchFraction") + this.miniBatchFraction = miniBatchFraction + this + } + + /** + * (private[clustering]) + * Set the Dirichlet parameter for the posterior over topics. + * This is only used for testing now. In the future, it can help support training stop/resume. + */ + private[clustering] def setLambda(lambda: BDM[Double]): this.type = { + this.lambda = lambda + this + } + + /** + * (private[clustering]) + * Used for random initialization of the variational parameters. + * Larger value produces values closer to 1.0. + * This is only used for testing currently. + */ + private[clustering] def setGammaShape(shape: Double): this.type = { + this.gammaShape = shape + this + } + + override private[clustering] def initialize( + docs: RDD[(Long, Vector)], + lda: LDA): OnlineLDAOptimizer = { + this.k = lda.getK + this.corpusSize = docs.count() + this.vocabSize = docs.first()._2.size + this.alpha = if (lda.getDocConcentration == -1) 1.0 / k else lda.getDocConcentration + this.eta = if (lda.getTopicConcentration == -1) 1.0 / k else lda.getTopicConcentration + this.randomGenerator = new Random(lda.getSeed) + + this.docs = docs + + // Initialize the variational distribution q(beta|lambda) + this.lambda = getGammaMatrix(k, vocabSize) + this.iteration = 0 + this + } + + override private[clustering] def next(): OnlineLDAOptimizer = { + val batch = docs.sample(withReplacement = true, miniBatchFraction, randomGenerator.nextLong()) + if (batch.isEmpty()) return this + submitMiniBatch(batch) + } + + /** + * Submit a subset (like 1%, decide by the miniBatchFraction) of the corpus to the Online LDA + * model, and it will update the topic distribution adaptively for the terms appearing in the + * subset. + */ + private[clustering] def submitMiniBatch(batch: RDD[(Long, Vector)]): OnlineLDAOptimizer = { + iteration += 1 + val k = this.k + val vocabSize = this.vocabSize + val Elogbeta = dirichletExpectation(lambda) + val expElogbeta = exp(Elogbeta) + val alpha = this.alpha + val gammaShape = this.gammaShape + + val stats: RDD[BDM[Double]] = batch.mapPartitions { docs => + val stat = BDM.zeros[Double](k, vocabSize) + docs.foreach { doc => + val termCounts = doc._2 + val (ids: List[Int], cts: Array[Double]) = termCounts match { + case v: DenseVector => ((0 until v.size).toList, v.values) + case v: SparseVector => (v.indices.toList, v.values) + case v => throw new IllegalArgumentException("Online LDA does not support vector type " + + v.getClass) + } + + // Initialize the variational distribution q(theta|gamma) for the mini-batch + var gammad = new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k).t // 1 * K + var Elogthetad = digamma(gammad) - digamma(sum(gammad)) // 1 * K + var expElogthetad = exp(Elogthetad) // 1 * K + val expElogbetad = expElogbeta(::, ids).toDenseMatrix // K * ids + + var phinorm = expElogthetad * expElogbetad + 1e-100 // 1 * ids + var meanchange = 1D + val ctsVector = new BDV[Double](cts).t // 1 * ids + + // Iterate between gamma and phi until convergence + while (meanchange > 1e-3) { + val lastgamma = gammad + // 1*K 1 * ids ids * k + gammad = (expElogthetad :* ((ctsVector / phinorm) * expElogbetad.t)) + alpha + Elogthetad = digamma(gammad) - digamma(sum(gammad)) + expElogthetad = exp(Elogthetad) + phinorm = expElogthetad * expElogbetad + 1e-100 + meanchange = sum(abs(gammad - lastgamma)) / k + } + + val m1 = expElogthetad.t + val m2 = (ctsVector / phinorm).t.toDenseVector + var i = 0 + while (i < ids.size) { + stat(::, ids(i)) := stat(::, ids(i)) + m1 * m2(i) + i += 1 + } + } + Iterator(stat) + } + + val statsSum: BDM[Double] = stats.reduce(_ += _) + val batchResult = statsSum :* expElogbeta + + // Note that this is an optimization to avoid batch.count + update(batchResult, iteration, (miniBatchFraction * corpusSize).ceil.toInt) + this + } + + override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = { + new LocalLDAModel(Matrices.fromBreeze(lambda).transpose) + } + + /** + * Update lambda based on the batch submitted. batchSize can be different for each iteration. + */ + private[clustering] def update(stat: BDM[Double], iter: Int, batchSize: Int): Unit = { + val tau_0 = this.getTau_0 + val kappa = this.getKappa + + // weight of the mini-batch. + val weight = math.pow(tau_0 + iter, -kappa) + + // Update lambda based on documents. + lambda = lambda * (1 - weight) + + (stat * (corpusSize.toDouble / batchSize.toDouble) + eta) * weight + } + + /** + * Get a random matrix to initialize lambda + */ + private def getGammaMatrix(row: Int, col: Int): BDM[Double] = { + val randBasis = new RandBasis(new org.apache.commons.math3.random.MersenneTwister( + randomGenerator.nextLong())) + val gammaRandomGenerator = new Gamma(gammaShape, 1.0 / gammaShape)(randBasis) + val temp = gammaRandomGenerator.sample(row * col).toArray + new BDM[Double](col, row, temp).t + } + + /** + * For theta ~ Dir(alpha), computes E[log(theta)] given alpha. Currently the implementation + * uses digamma which is accurate but expensive. + */ + private def dirichletExpectation(alpha: BDM[Double]): BDM[Double] = { + val rowSum = sum(alpha(breeze.linalg.*, ::)) + val digAlpha = digamma(alpha) + val digRowSum = digamma(rowSum) + val result = digAlpha(::, breeze.linalg.*) - digRowSum + result + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java index fbe171b4b1ab1..f394d903966de 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java @@ -20,7 +20,6 @@ import java.io.Serializable; import java.util.ArrayList; -import org.apache.spark.api.java.JavaRDD; import scala.Tuple2; import org.junit.After; @@ -30,6 +29,7 @@ import org.junit.Test; import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Matrix; import org.apache.spark.mllib.linalg.Vector; @@ -109,11 +109,45 @@ public void distributedLDAModel() { assert(model.logPrior() < 0.0); } + @Test + public void OnlineOptimizerCompatibility() { + int k = 3; + double topicSmoothing = 1.2; + double termSmoothing = 1.2; + + // Train a model + OnlineLDAOptimizer op = new OnlineLDAOptimizer() + .setTau_0(1024) + .setKappa(0.51) + .setGammaShape(1e40) + .setMiniBatchFraction(0.5); + + LDA lda = new LDA(); + lda.setK(k) + .setDocConcentration(topicSmoothing) + .setTopicConcentration(termSmoothing) + .setMaxIterations(5) + .setSeed(12345) + .setOptimizer(op); + + LDAModel model = lda.run(corpus); + + // Check: basic parameters + assertEquals(model.k(), k); + assertEquals(model.vocabSize(), tinyVocabSize); + + // Check: topic summaries + Tuple2[] roundedTopicSummary = model.describeTopics(); + assertEquals(roundedTopicSummary.length, k); + Tuple2[] roundedLocalTopicSummary = model.describeTopics(); + assertEquals(roundedLocalTopicSummary.length, k); + } + private static int tinyK = LDASuite$.MODULE$.tinyK(); private static int tinyVocabSize = LDASuite$.MODULE$.tinyVocabSize(); private static Matrix tinyTopics = LDASuite$.MODULE$.tinyTopics(); private static Tuple2[] tinyTopicDescription = LDASuite$.MODULE$.tinyTopicDescription(); - JavaPairRDD corpus; + private JavaPairRDD corpus; } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index 41ec794146c69..2dcc881f5abd2 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.clustering +import breeze.linalg.{DenseMatrix => BDM} + import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.{Vector, DenseMatrix, Matrix, Vectors} @@ -37,7 +39,7 @@ class LDASuite extends FunSuite with MLlibTestSparkContext { // Check: describeTopics() with all terms val fullTopicSummary = model.describeTopics() - assert(fullTopicSummary.size === tinyK) + assert(fullTopicSummary.length === tinyK) fullTopicSummary.zip(tinyTopicDescription).foreach { case ((algTerms, algTermWeights), (terms, termWeights)) => assert(algTerms === terms) @@ -54,7 +56,7 @@ class LDASuite extends FunSuite with MLlibTestSparkContext { } } - test("running and DistributedLDAModel") { + test("running and DistributedLDAModel with default Optimizer (EM)") { val k = 3 val topicSmoothing = 1.2 val termSmoothing = 1.2 @@ -99,7 +101,7 @@ class LDASuite extends FunSuite with MLlibTestSparkContext { // Check: per-doc topic distributions val topicDistributions = model.topicDistributions.collect() // Ensure all documents are covered. - assert(topicDistributions.size === tinyCorpus.size) + assert(topicDistributions.length === tinyCorpus.length) assert(tinyCorpus.map(_._1).toSet === topicDistributions.map(_._1).toSet) // Ensure we have proper distributions topicDistributions.foreach { case (docId, topicDistribution) => @@ -131,6 +133,87 @@ class LDASuite extends FunSuite with MLlibTestSparkContext { assert(lda.getBeta === 3.0) assert(lda.getTopicConcentration === 3.0) } + + test("OnlineLDAOptimizer initialization") { + val lda = new LDA().setK(2) + val corpus = sc.parallelize(tinyCorpus, 2) + val op = new OnlineLDAOptimizer().initialize(corpus, lda) + op.setKappa(0.9876).setMiniBatchFraction(0.123).setTau_0(567) + assert(op.getAlpha == 0.5) // default 1.0 / k + assert(op.getEta == 0.5) // default 1.0 / k + assert(op.getKappa == 0.9876) + assert(op.getMiniBatchFraction == 0.123) + assert(op.getTau_0 == 567) + } + + test("OnlineLDAOptimizer one iteration") { + // run OnlineLDAOptimizer for 1 iteration to verify it's consistency with Blei-lab, + // [[https://github.com/Blei-Lab/onlineldavb]] + val k = 2 + val vocabSize = 6 + + def docs: Array[(Long, Vector)] = Array( + Vectors.sparse(vocabSize, Array(0, 1, 2), Array(1, 1, 1)), // apple, orange, banana + Vectors.sparse(vocabSize, Array(3, 4, 5), Array(1, 1, 1)) // tiger, cat, dog + ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) } + val corpus = sc.parallelize(docs, 2) + + // Set GammaShape large to avoid the stochastic impact. + val op = new OnlineLDAOptimizer().setTau_0(1024).setKappa(0.51).setGammaShape(1e40) + .setMiniBatchFraction(1) + val lda = new LDA().setK(k).setMaxIterations(1).setOptimizer(op).setSeed(12345) + + val state = op.initialize(corpus, lda) + // override lambda to simulate an intermediate state + // [[ 1.1 1.2 1.3 0.9 0.8 0.7] + // [ 0.9 0.8 0.7 1.1 1.2 1.3]] + op.setLambda(new BDM[Double](k, vocabSize, + Array(1.1, 0.9, 1.2, 0.8, 1.3, 0.7, 0.9, 1.1, 0.8, 1.2, 0.7, 1.3))) + + // run for one iteration + state.submitMiniBatch(corpus) + + // verify the result, Note this generate the identical result as + // [[https://github.com/Blei-Lab/onlineldavb]] + val topic1 = op.getLambda(0, ::).inner.toArray.map("%.4f".format(_)).mkString(", ") + val topic2 = op.getLambda(1, ::).inner.toArray.map("%.4f".format(_)).mkString(", ") + assert("1.1101, 1.2076, 1.3050, 0.8899, 0.7924, 0.6950" == topic1) + assert("0.8899, 0.7924, 0.6950, 1.1101, 1.2076, 1.3050" == topic2) + } + + test("OnlineLDAOptimizer with toy data") { + def toydata: Array[(Long, Vector)] = Array( + Vectors.sparse(6, Array(0, 1), Array(1, 1)), + Vectors.sparse(6, Array(1, 2), Array(1, 1)), + Vectors.sparse(6, Array(0, 2), Array(1, 1)), + Vectors.sparse(6, Array(3, 4), Array(1, 1)), + Vectors.sparse(6, Array(3, 5), Array(1, 1)), + Vectors.sparse(6, Array(4, 5), Array(1, 1)) + ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) } + + val docs = sc.parallelize(toydata) + val op = new OnlineLDAOptimizer().setMiniBatchFraction(1).setTau_0(1024).setKappa(0.51) + .setGammaShape(1e10) + val lda = new LDA().setK(2) + .setDocConcentration(0.01) + .setTopicConcentration(0.01) + .setMaxIterations(100) + .setOptimizer(op) + .setSeed(12345) + + val ldaModel = lda.run(docs) + val topicIndices = ldaModel.describeTopics(maxTermsPerTopic = 10) + val topics = topicIndices.map { case (terms, termWeights) => + terms.zip(termWeights) + } + + // check distribution for each topic, typical distribution is (0.3, 0.3, 0.3, 0.02, 0.02, 0.02) + topics.foreach { topic => + val smalls = topic.filter(t => t._2 < 0.1).map(_._2) + assert(smalls.length == 3 && smalls.sum < 0.2) + } + } + } private[clustering] object LDASuite { From 343d3bfafd449a0371feb6a88f78e07302fa7143 Mon Sep 17 00:00:00 2001 From: tianyi Date: Mon, 4 May 2015 16:59:34 +0800 Subject: [PATCH 41/91] [SPARK-5100] [SQL] add webui for thriftserver This PR is a rebased version of #3946 , and mainly focused on creating an independent tab for the thrift server in spark web UI. Features: 1. Session related statistics ( username and IP are only supported in hive-0.13.1 ) 2. List all the SQL executing or executed on this server 3. Provide links to the job generated by SQL 4. Provide link to show all SQL executing or executed in a specified session Prototype snapshots: This is the main page for thrift server ![image](https://cloud.githubusercontent.com/assets/1411869/7361379/df7dcc64-ed89-11e4-9964-4df0b32f475e.png) Author: tianyi Closes #5730 from tianyi/SPARK-5100 and squashes the following commits: cfd14c7 [tianyi] style fix 0efe3d5 [tianyi] revert part of pom change c0f2fa0 [tianyi] extends HiveThriftJdbcTest to start/stop thriftserver for UI test aa20408 [tianyi] fix style problem c9df6f9 [tianyi] add testsuite for thriftserver ui and fix some style issue 9830199 [tianyi] add webui for thriftserver --- .../scala/org/apache/spark/sql/SQLConf.scala | 2 + sql/hive-thriftserver/pom.xml | 12 ++ .../hive/thriftserver/HiveThriftServer2.scala | 161 +++++++++++++- .../thriftserver/ui/ThriftServerPage.scala | 190 +++++++++++++++++ .../ui/ThriftServerSessionPage.scala | 197 ++++++++++++++++++ .../thriftserver/ui/ThriftServerTab.scala | 50 +++++ .../HiveThriftServer2Suites.scala | 12 +- .../hive/thriftserver/UISeleniumSuite.scala | 105 ++++++++++ .../spark/sql/hive/thriftserver/Shim12.scala | 18 +- .../spark/sql/hive/thriftserver/Shim13.scala | 26 ++- 10 files changed, 751 insertions(+), 22 deletions(-) create mode 100644 sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala create mode 100644 sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala create mode 100644 sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala create mode 100644 sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 2fa602a6082dc..99db959a8741c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -52,6 +52,8 @@ private[spark] object SQLConf { // This is only used for the thriftserver val THRIFTSERVER_POOL = "spark.sql.thriftserver.scheduler.pool" + val THRIFTSERVER_UI_STATEMENT_LIMIT = "spark.sql.thriftserver.ui.retainedStatements" + val THRIFTSERVER_UI_SESSION_LIMIT = "spark.sql.thriftserver.ui.retainedSessions" // This is used to set the default data source val DEFAULT_DATA_SOURCE_NAME = "spark.sql.sources.default" diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index f38c796241df1..437f697d25bf3 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -57,6 +57,18 @@ ${hive.group} hive-beeline + + + org.seleniumhq.selenium + selenium-java + test + + + io.netty + netty + + + target/scala-${scala.binary.version}/classes diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index 832596fc8bee5..0be5a92c2546c 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -22,20 +22,27 @@ import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hive.service.cli.thrift.{ThriftBinaryCLIService, ThriftHttpCLIService} import org.apache.hive.service.server.{HiveServer2, ServerOptionsProcessor} +import org.apache.spark.sql.SQLConf -import org.apache.spark.Logging +import org.apache.spark.{SparkContext, SparkConf, Logging} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ -import org.apache.spark.scheduler.{SparkListenerApplicationEnd, SparkListener} +import org.apache.spark.scheduler.{SparkListenerJobStart, SparkListenerApplicationEnd, SparkListener} +import org.apache.spark.sql.hive.thriftserver.ui.ThriftServerTab import org.apache.spark.util.Utils +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + /** * The main entry point for the Spark SQL port of HiveServer2. Starts up a `SparkSQLContext` and a * `HiveThriftServer2` thrift server. */ object HiveThriftServer2 extends Logging { var LOG = LogFactory.getLog(classOf[HiveServer2]) + var uiTab: Option[ThriftServerTab] = _ + var listener: HiveThriftServer2Listener = _ /** * :: DeveloperApi :: @@ -46,7 +53,13 @@ object HiveThriftServer2 extends Logging { val server = new HiveThriftServer2(sqlContext) server.init(sqlContext.hiveconf) server.start() - sqlContext.sparkContext.addSparkListener(new HiveThriftServer2Listener(server)) + listener = new HiveThriftServer2Listener(server, sqlContext.conf) + sqlContext.sparkContext.addSparkListener(listener) + uiTab = if (sqlContext.sparkContext.getConf.getBoolean("spark.ui.enabled", true)) { + Some(new ThriftServerTab(sqlContext.sparkContext)) + } else { + None + } } def main(args: Array[String]) { @@ -58,14 +71,23 @@ object HiveThriftServer2 extends Logging { logInfo("Starting SparkContext") SparkSQLEnv.init() - Utils.addShutdownHook { () => SparkSQLEnv.stop() } + Utils.addShutdownHook { () => + SparkSQLEnv.stop() + uiTab.foreach(_.detach()) + } try { val server = new HiveThriftServer2(SparkSQLEnv.hiveContext) server.init(SparkSQLEnv.hiveContext.hiveconf) server.start() logInfo("HiveThriftServer2 started") - SparkSQLEnv.sparkContext.addSparkListener(new HiveThriftServer2Listener(server)) + listener = new HiveThriftServer2Listener(server, SparkSQLEnv.hiveContext.conf) + SparkSQLEnv.sparkContext.addSparkListener(listener) + uiTab = if (SparkSQLEnv.sparkContext.getConf.getBoolean("spark.ui.enabled", true)) { + Some(new ThriftServerTab(SparkSQLEnv.sparkContext)) + } else { + None + } } catch { case e: Exception => logError("Error starting HiveThriftServer2", e) @@ -73,15 +95,140 @@ object HiveThriftServer2 extends Logging { } } + private[thriftserver] class SessionInfo( + val sessionId: String, + val startTimestamp: Long, + val ip: String, + val userName: String) { + var finishTimestamp: Long = 0L + var totalExecution: Int = 0 + def totalTime: Long = { + if (finishTimestamp == 0L) { + System.currentTimeMillis - startTimestamp + } else { + finishTimestamp - startTimestamp + } + } + } + + private[thriftserver] object ExecutionState extends Enumeration { + val STARTED, COMPILED, FAILED, FINISHED = Value + type ExecutionState = Value + } + + private[thriftserver] class ExecutionInfo( + val statement: String, + val sessionId: String, + val startTimestamp: Long, + val userName: String) { + var finishTimestamp: Long = 0L + var executePlan: String = "" + var detail: String = "" + var state: ExecutionState.Value = ExecutionState.STARTED + val jobId: ArrayBuffer[String] = ArrayBuffer[String]() + var groupId: String = "" + def totalTime: Long = { + if (finishTimestamp == 0L) { + System.currentTimeMillis - startTimestamp + } else { + finishTimestamp - startTimestamp + } + } + } + + /** * A inner sparkListener called in sc.stop to clean up the HiveThriftServer2 */ - class HiveThriftServer2Listener(val server: HiveServer2) extends SparkListener { + private[thriftserver] class HiveThriftServer2Listener( + val server: HiveServer2, + val conf: SQLConf) extends SparkListener { + override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = { server.stop() } - } + val sessionList = new mutable.LinkedHashMap[String, SessionInfo] + val executionList = new mutable.LinkedHashMap[String, ExecutionInfo] + val retainedStatements = + conf.getConf(SQLConf.THRIFTSERVER_UI_STATEMENT_LIMIT, "200").toInt + val retainedSessions = + conf.getConf(SQLConf.THRIFTSERVER_UI_SESSION_LIMIT, "200").toInt + var totalRunning = 0 + + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + for { + props <- Option(jobStart.properties) + groupId <- Option(props.getProperty(SparkContext.SPARK_JOB_GROUP_ID)) + (_, info) <- executionList if info.groupId == groupId + } { + info.jobId += jobStart.jobId.toString + info.groupId = groupId + } + } + + def onSessionCreated(ip: String, sessionId: String, userName: String = "UNKNOWN"): Unit = { + val info = new SessionInfo(sessionId, System.currentTimeMillis, ip, userName) + sessionList.put(sessionId, info) + trimSessionIfNecessary() + } + + def onSessionClosed(sessionId: String): Unit = { + sessionList(sessionId).finishTimestamp = System.currentTimeMillis + } + + def onStatementStart( + id: String, + sessionId: String, + statement: String, + groupId: String, + userName: String = "UNKNOWN"): Unit = { + val info = new ExecutionInfo(statement, sessionId, System.currentTimeMillis, userName) + info.state = ExecutionState.STARTED + executionList.put(id, info) + trimExecutionIfNecessary() + sessionList(sessionId).totalExecution += 1 + executionList(id).groupId = groupId + totalRunning += 1 + } + + def onStatementParsed(id: String, executionPlan: String): Unit = { + executionList(id).executePlan = executionPlan + executionList(id).state = ExecutionState.COMPILED + } + + def onStatementError(id: String, errorMessage: String, errorTrace: String): Unit = { + executionList(id).finishTimestamp = System.currentTimeMillis + executionList(id).detail = errorMessage + executionList(id).state = ExecutionState.FAILED + totalRunning -= 1 + } + + def onStatementFinish(id: String): Unit = { + executionList(id).finishTimestamp = System.currentTimeMillis + executionList(id).state = ExecutionState.FINISHED + totalRunning -= 1 + } + + private def trimExecutionIfNecessary() = synchronized { + if (executionList.size > retainedStatements) { + val toRemove = math.max(retainedStatements / 10, 1) + executionList.take(toRemove).foreach { s => + executionList.remove(s._1) + } + } + } + + private def trimSessionIfNecessary() = synchronized { + if (sessionList.size > retainedSessions) { + val toRemove = math.max(retainedSessions / 10, 1) + sessionList.take(toRemove).foreach { s => + sessionList.remove(s._1) + } + } + + } + } } private[hive] class HiveThriftServer2(hiveContext: HiveContext) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala new file mode 100644 index 0000000000000..71b16b6bebffb --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver.ui + +import java.util.Calendar +import javax.servlet.http.HttpServletRequest + +import scala.xml.Node + +import org.apache.commons.lang3.StringEscapeUtils +import org.apache.spark.Logging +import org.apache.spark.sql.hive.thriftserver.HiveThriftServer2.{SessionInfo, ExecutionState, ExecutionInfo} +import org.apache.spark.ui.UIUtils._ +import org.apache.spark.ui._ + + +/** Page for Spark Web UI that shows statistics of a streaming job */ +private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("") with Logging { + + private val listener = parent.listener + private val startTime = Calendar.getInstance().getTime() + private val emptyCell = "-" + + /** Render the page */ + def render(request: HttpServletRequest): Seq[Node] = { + val content = + generateBasicStats() ++ +
++ +

+ {listener.sessionList.size} session(s) are online, + running {listener.totalRunning} SQL statement(s) +

++ + generateSessionStatsTable() ++ + generateSQLStatsTable() + UIUtils.headerSparkPage("ThriftServer", content, parent, Some(5000)) + } + + /** Generate basic stats of the streaming program */ + private def generateBasicStats(): Seq[Node] = { + val timeSinceStart = System.currentTimeMillis() - startTime.getTime +
    +
  • + Started at: {startTime.toString} +
  • +
  • + Time since start: {formatDurationVerbose(timeSinceStart)} +
  • +
+ } + + /** Generate stats of batch statements of the thrift server program */ + private def generateSQLStatsTable(): Seq[Node] = { + val numStatement = listener.executionList.size + val table = if (numStatement > 0) { + val headerRow = Seq("User", "JobID", "GroupID", "Start Time", "Finish Time", "Duration", + "Statement", "State", "Detail") + val dataRows = listener.executionList.values + + def generateDataRow(info: ExecutionInfo): Seq[Node] = { + val jobLink = info.jobId.map { id: String => + + [{id}] + + } + val detail = if(info.state == ExecutionState.FAILED) info.detail else info.executePlan + + {info.userName} + + {jobLink} + + {info.groupId} + {formatDate(info.startTimestamp)} + {if(info.finishTimestamp > 0) formatDate(info.finishTimestamp)} + {formatDurationOption(Some(info.totalTime))} + {info.statement} + {info.state} + {errorMessageCell(detail)} + + } + + Some(UIUtils.listingTable(headerRow, generateDataRow, + dataRows, false, None, Seq(null), false)) + } else { + None + } + + val content = +
SQL Statistics
++ +
+
    + {table.getOrElse("No statistics have been generated yet.")} +
+
+ + content + } + + private def errorMessageCell(errorMessage: String): Seq[Node] = { + val isMultiline = errorMessage.indexOf('\n') >= 0 + val errorSummary = StringEscapeUtils.escapeHtml4( + if (isMultiline) { + errorMessage.substring(0, errorMessage.indexOf('\n')) + } else { + errorMessage + }) + val details = if (isMultiline) { + // scalastyle:off + + + details + ++ + + // scalastyle:on + } else { + "" + } + {errorSummary}{details} + } + + /** Generate stats of batch sessions of the thrift server program */ + private def generateSessionStatsTable(): Seq[Node] = { + val numBatches = listener.sessionList.size + val table = if (numBatches > 0) { + val dataRows = + listener.sessionList.values + val headerRow = Seq("User", "IP", "Session ID", "Start Time", "Finish Time", "Duration", + "Total Execute") + def generateDataRow(session: SessionInfo): Seq[Node] = { + val sessionLink = "%s/ThriftServer/session?id=%s" + .format(UIUtils.prependBaseUri(parent.basePath), session.sessionId) + + {session.userName} + {session.ip} + {session.sessionId} , + {formatDate(session.startTimestamp)} + {if(session.finishTimestamp > 0) formatDate(session.finishTimestamp)} + {formatDurationOption(Some(session.totalTime))} + {session.totalExecution.toString} + + } + Some(UIUtils.listingTable(headerRow, generateDataRow, dataRows, true, None, Seq(null), false)) + } else { + None + } + + val content = +
Session Statistics
++ +
+
    + {table.getOrElse("No statistics have been generated yet.")} +
+
+ + content + } + + + /** + * Returns a human-readable string representing a duration such as "5 second 35 ms" + */ + private def formatDurationOption(msOption: Option[Long]): String = { + msOption.map(formatDurationVerbose).getOrElse(emptyCell) + } + + /** Generate HTML table from string data */ + private def listingTable(headers: Seq[String], data: Seq[Seq[String]]) = { + def generateDataRow(data: Seq[String]): Seq[Node] = { + {data.map(d => {d})} + } + UIUtils.listingTable(headers, generateDataRow, data, fixedWidth = true) + } +} + diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala new file mode 100644 index 0000000000000..33ba038ecce73 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver.ui + +import java.util.Calendar +import javax.servlet.http.HttpServletRequest + +import scala.xml.Node + +import org.apache.commons.lang3.StringEscapeUtils +import org.apache.spark.Logging +import org.apache.spark.sql.hive.thriftserver.HiveThriftServer2.{ExecutionInfo, ExecutionState} +import org.apache.spark.ui.UIUtils._ +import org.apache.spark.ui._ + +/** Page for Spark Web UI that shows statistics of a streaming job */ +private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) + extends WebUIPage("session") with Logging { + + private val listener = parent.listener + private val startTime = Calendar.getInstance().getTime() + private val emptyCell = "-" + + /** Render the page */ + def render(request: HttpServletRequest): Seq[Node] = { + val parameterId = request.getParameter("id") + require(parameterId != null && parameterId.nonEmpty, "Missing id parameter") + val sessionStat = listener.sessionList.find(stat => { + stat._1 == parameterId + }).getOrElse(null) + require(sessionStat != null, "Invalid sessionID[" + parameterId + "]") + + val content = + generateBasicStats() ++ +
++ +

+ User {sessionStat._2.userName}, + IP {sessionStat._2.ip}, + Session created at {formatDate(sessionStat._2.startTimestamp)}, + Total run {sessionStat._2.totalExecution} SQL +

++ + generateSQLStatsTable(sessionStat._2.sessionId) + UIUtils.headerSparkPage("ThriftServer", content, parent, Some(5000)) + } + + /** Generate basic stats of the streaming program */ + private def generateBasicStats(): Seq[Node] = { + val timeSinceStart = System.currentTimeMillis() - startTime.getTime +
    +
  • + Started at: {startTime.toString} +
  • +
  • + Time since start: {formatDurationVerbose(timeSinceStart)} +
  • +
+ } + + /** Generate stats of batch statements of the thrift server program */ + private def generateSQLStatsTable(sessionID: String): Seq[Node] = { + val executionList = listener.executionList + .filter(_._2.sessionId == sessionID) + val numStatement = executionList.size + val table = if (numStatement > 0) { + val headerRow = Seq("User", "JobID", "GroupID", "Start Time", "Finish Time", "Duration", + "Statement", "State", "Detail") + val dataRows = executionList.values.toSeq.sortBy(_.startTimestamp).reverse + + def generateDataRow(info: ExecutionInfo): Seq[Node] = { + val jobLink = info.jobId.map { id: String => + + [{id}] + + } + val detail = if(info.state == ExecutionState.FAILED) info.detail else info.executePlan + + {info.userName} + + {jobLink} + + {info.groupId} + {formatDate(info.startTimestamp)} + {formatDate(info.finishTimestamp)} + {formatDurationOption(Some(info.totalTime))} + {info.statement} + {info.state} + {errorMessageCell(detail)} + + } + + Some(UIUtils.listingTable(headerRow, generateDataRow, + dataRows, false, None, Seq(null), false)) + } else { + None + } + + val content = +
SQL Statistics
++ +
+
    + {table.getOrElse("No statistics have been generated yet.")} +
+
+ + content + } + + private def errorMessageCell(errorMessage: String): Seq[Node] = { + val isMultiline = errorMessage.indexOf('\n') >= 0 + val errorSummary = StringEscapeUtils.escapeHtml4( + if (isMultiline) { + errorMessage.substring(0, errorMessage.indexOf('\n')) + } else { + errorMessage + }) + val details = if (isMultiline) { + // scalastyle:off + + + details + ++ + + // scalastyle:on + } else { + "" + } + {errorSummary}{details} + } + + /** Generate stats of batch sessions of the thrift server program */ + private def generateSessionStatsTable(): Seq[Node] = { + val numBatches = listener.sessionList.size + val table = if (numBatches > 0) { + val dataRows = + listener.sessionList.values.toSeq.sortBy(_.startTimestamp).reverse.map ( session => + Seq( + session.userName, + session.ip, + session.sessionId, + formatDate(session.startTimestamp), + formatDate(session.finishTimestamp), + formatDurationOption(Some(session.totalTime)), + session.totalExecution.toString + ) + ).toSeq + val headerRow = Seq("User", "IP", "Session ID", "Start Time", "Finish Time", "Duration", + "Total Execute") + Some(listingTable(headerRow, dataRows)) + } else { + None + } + + val content = +
Session Statistics
++ +
+
    + {table.getOrElse("No statistics have been generated yet.")} +
+
+ + content + } + + + /** + * Returns a human-readable string representing a duration such as "5 second 35 ms" + */ + private def formatDurationOption(msOption: Option[Long]): String = { + msOption.map(formatDurationVerbose).getOrElse(emptyCell) + } + + /** Generate HTML table from string data */ + private def listingTable(headers: Seq[String], data: Seq[Seq[String]]) = { + def generateDataRow(data: Seq[String]): Seq[Node] = { + {data.map(d => {d})} + } + UIUtils.listingTable(headers, generateDataRow, data, fixedWidth = true) + } +} + diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala new file mode 100644 index 0000000000000..343031f10c75c --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver.ui + +import org.apache.spark.sql.hive.thriftserver.{HiveThriftServer2, SparkSQLEnv} +import org.apache.spark.sql.hive.thriftserver.ui.ThriftServerTab._ +import org.apache.spark.ui.{SparkUI, SparkUITab} +import org.apache.spark.{SparkContext, Logging, SparkException} + +/** + * Spark Web UI tab that shows statistics of a streaming job. + * This assumes the given SparkContext has enabled its SparkUI. + */ +private[thriftserver] class ThriftServerTab(sparkContext: SparkContext) + extends SparkUITab(getSparkUI(sparkContext), "ThriftServer") with Logging { + + val parent = getSparkUI(sparkContext) + val listener = HiveThriftServer2.listener + + attachPage(new ThriftServerPage(this)) + attachPage(new ThriftServerSessionPage(this)) + parent.attachTab(this) + + def detach() { + getSparkUI(sparkContext).detachTab(this) + } +} + +private[thriftserver] object ThriftServerTab { + def getSparkUI(sparkContext: SparkContext): SparkUI = { + sparkContext.ui.getOrElse { + throw new SparkException("Parent SparkUI to attach this tab to not found!") + } + } +} diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 4cf95e7bdfb2b..1fadea97fd07f 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -409,24 +409,24 @@ abstract class HiveThriftServer2Test extends FunSuite with BeforeAndAfterAll wit private val CLASS_NAME = HiveThriftServer2.getClass.getCanonicalName.stripSuffix("$") private val LOG_FILE_MARK = s"starting $CLASS_NAME, logging to " - private val startScript = "../../sbin/start-thriftserver.sh".split("/").mkString(File.separator) - private val stopScript = "../../sbin/stop-thriftserver.sh".split("/").mkString(File.separator) + protected val startScript = "../../sbin/start-thriftserver.sh".split("/").mkString(File.separator) + protected val stopScript = "../../sbin/stop-thriftserver.sh".split("/").mkString(File.separator) private var listeningPort: Int = _ protected def serverPort: Int = listeningPort protected def user = System.getProperty("user.name") - private var warehousePath: File = _ - private var metastorePath: File = _ - private def metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true" + protected var warehousePath: File = _ + protected var metastorePath: File = _ + protected def metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true" private val pidDir: File = Utils.createTempDir("thriftserver-pid") private var logPath: File = _ private var logTailingProcess: Process = _ private var diagnosisBuffer: ArrayBuffer[String] = ArrayBuffer.empty[String] - private def serverStartCommand(port: Int) = { + protected def serverStartCommand(port: Int) = { val portConf = if (mode == ServerMode.binary) { ConfVars.HIVE_SERVER2_THRIFT_PORT } else { diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala new file mode 100644 index 0000000000000..47541015a3611 --- /dev/null +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + + + +import scala.util.Random + +import org.openqa.selenium.WebDriver +import org.openqa.selenium.htmlunit.HtmlUnitDriver +import org.scalatest.{Matchers, BeforeAndAfterAll} +import org.scalatest.concurrent.Eventually._ +import org.scalatest.selenium.WebBrowser +import org.scalatest.time.SpanSugar._ + +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.spark.sql.hive.HiveContext + + +class UISeleniumSuite + extends HiveThriftJdbcTest + with WebBrowser with Matchers with BeforeAndAfterAll { + + implicit var webDriver: WebDriver = _ + var server: HiveThriftServer2 = _ + var hc: HiveContext = _ + val uiPort = 20000 + Random.nextInt(10000) + override def mode: ServerMode.Value = ServerMode.binary + + override def beforeAll(): Unit = { + webDriver = new HtmlUnitDriver + super.beforeAll() + } + + override def afterAll(): Unit = { + if (webDriver != null) { + webDriver.quit() + } + super.afterAll() + } + + override protected def serverStartCommand(port: Int) = { + val portConf = if (mode == ServerMode.binary) { + ConfVars.HIVE_SERVER2_THRIFT_PORT + } else { + ConfVars.HIVE_SERVER2_THRIFT_HTTP_PORT + } + + s"""$startScript + | --master local + | --hiveconf hive.root.logger=INFO,console + | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri + | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath + | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost + | --hiveconf ${ConfVars.HIVE_SERVER2_TRANSPORT_MODE}=$mode + | --hiveconf $portConf=$port + | --driver-class-path ${sys.props("java.class.path")} + | --conf spark.ui.enabled=true + | --conf spark.ui.port=$uiPort + """.stripMargin.split("\\s+").toSeq + } + + test("thrift server ui test") { + withJdbcStatement(statement =>{ + val baseURL = s"http://localhost:${uiPort}" + + val queries = Seq( + "CREATE TABLE test_map(key INT, value STRING)", + s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_map") + + queries.foreach(statement.execute) + + eventually(timeout(10 seconds), interval(50 milliseconds)) { + go to (baseURL) + find(cssSelector("""ul li a[href*="ThriftServer"]""")) should not be(None) + } + + eventually(timeout(10 seconds), interval(50 milliseconds)) { + go to (baseURL + "/ThriftServer") + find(id("sessionstat")) should not be(None) + find(id("sqlstat")) should not be(None) + + // check whether statements exists + queries.foreach { line => + findAll(cssSelector("""ul table tbody tr td""")).map(_.text).toList should contain (line) + } + } + }) + } +} diff --git a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala index 95a6e86d0546d..b3a79ba1c7d6b 100644 --- a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala +++ b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.hive.thriftserver import java.sql.{Date, Timestamp} import java.util.concurrent.Executors -import java.util.{ArrayList => JArrayList, Map => JMap} +import java.util.{ArrayList => JArrayList, Map => JMap, UUID} import org.apache.commons.logging.Log import org.apache.hadoop.hive.conf.HiveConf @@ -190,9 +190,12 @@ private[hive] class SparkExecuteStatementOperation( } def run(): Unit = { + val statementId = UUID.randomUUID().toString logInfo(s"Running query '$statement'") setState(OperationState.RUNNING) - hiveContext.sparkContext.setJobDescription(statement) + HiveThriftServer2.listener.onStatementStart( + statementId, parentSession.getSessionHandle.getSessionId.toString, statement, statementId) + hiveContext.sparkContext.setJobGroup(statementId, statement) sessionToActivePool.get(parentSession.getSessionHandle).foreach { pool => hiveContext.sparkContext.setLocalProperty("spark.scheduler.pool", pool) } @@ -205,6 +208,7 @@ private[hive] class SparkExecuteStatementOperation( logInfo(s"Setting spark.scheduler.pool=$value for future statements in this session.") case _ => } + HiveThriftServer2.listener.onStatementParsed(statementId, result.queryExecution.toString()) iter = { val useIncrementalCollect = hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean @@ -221,10 +225,13 @@ private[hive] class SparkExecuteStatementOperation( // HiveServer will silently swallow them. case e: Throwable => setState(OperationState.ERROR) + HiveThriftServer2.listener.onStatementError( + statementId, e.getMessage, e.getStackTraceString) logError("Error executing query:",e) throw new HiveSQLException(e.toString) } setState(OperationState.FINISHED) + HiveThriftServer2.listener.onStatementFinish(statementId) } } @@ -255,11 +262,14 @@ private[hive] class SparkSQLSessionManager(hiveContext: HiveContext) withImpersonation: Boolean, delegationToken: String): SessionHandle = { hiveContext.openSession() - - super.openSession(username, passwd, sessionConf, withImpersonation, delegationToken) + val sessionHandle = super.openSession( + username, passwd, sessionConf, withImpersonation, delegationToken) + HiveThriftServer2.listener.onSessionCreated("UNKNOWN", sessionHandle.getSessionId.toString) + sessionHandle } override def closeSession(sessionHandle: SessionHandle) { + HiveThriftServer2.listener.onSessionClosed(sessionHandle.getSessionId.toString) super.closeSession(sessionHandle) sparkSqlOperationManager.sessionToActivePool -= sessionHandle diff --git a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala index 178eb1af7cdcd..b9d4f1c58c982 100644 --- a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala +++ b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.hive.thriftserver import java.sql.{Date, Timestamp} import java.util.concurrent.Executors -import java.util.{ArrayList => JArrayList, List => JList, Map => JMap} +import java.util.{ArrayList => JArrayList, List => JList, Map => JMap, UUID} import org.apache.commons.logging.Log import org.apache.hadoop.hive.conf.HiveConf @@ -36,7 +36,7 @@ import org.apache.hive.service.cli._ import org.apache.hive.service.cli.operation.ExecuteStatementOperation import org.apache.hive.service.cli.session.{SessionManager, HiveSession} -import org.apache.spark.Logging +import org.apache.spark.{SparkContext, Logging} import org.apache.spark.sql.{DataFrame, Row => SparkRow, SQLConf} import org.apache.spark.sql.execution.SetCommand import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ @@ -161,9 +161,16 @@ private[hive] class SparkExecuteStatementOperation( } def run(): Unit = { + val statementId = UUID.randomUUID().toString logInfo(s"Running query '$statement'") setState(OperationState.RUNNING) - hiveContext.sparkContext.setJobDescription(statement) + HiveThriftServer2.listener.onStatementStart( + statementId, + parentSession.getSessionHandle.getSessionId.toString, + statement, + statementId, + parentSession.getUsername) + hiveContext.sparkContext.setJobGroup(statementId, statement) sessionToActivePool.get(parentSession.getSessionHandle).foreach { pool => hiveContext.sparkContext.setLocalProperty("spark.scheduler.pool", pool) } @@ -176,6 +183,7 @@ private[hive] class SparkExecuteStatementOperation( logInfo(s"Setting spark.scheduler.pool=$value for future statements in this session.") case _ => } + HiveThriftServer2.listener.onStatementParsed(statementId, result.queryExecution.toString()) iter = { val useIncrementalCollect = hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean @@ -192,10 +200,13 @@ private[hive] class SparkExecuteStatementOperation( // HiveServer will silently swallow them. case e: Throwable => setState(OperationState.ERROR) + HiveThriftServer2.listener.onStatementError( + statementId, e.getMessage, e.getStackTraceString) logError("Error executing query:", e) throw new HiveSQLException(e.toString) } setState(OperationState.FINISHED) + HiveThriftServer2.listener.onStatementFinish(statementId) } } @@ -227,11 +238,16 @@ private[hive] class SparkSQLSessionManager(hiveContext: HiveContext) withImpersonation: Boolean, delegationToken: String): SessionHandle = { hiveContext.openSession() - - super.openSession(protocol, username, passwd, sessionConf, withImpersonation, delegationToken) + val sessionHandle = super.openSession( + protocol, username, passwd, sessionConf, withImpersonation, delegationToken) + val session = super.getSession(sessionHandle) + HiveThriftServer2.listener.onSessionCreated( + session.getIpAddress, sessionHandle.getSessionId.toString, session.getUsername) + sessionHandle } override def closeSession(sessionHandle: SessionHandle) { + HiveThriftServer2.listener.onSessionClosed(sessionHandle.getSessionId.toString) super.closeSession(sessionHandle) sparkSqlOperationManager.sessionToActivePool -= sessionHandle From 5a1a1075a607be683f008ef92fa227803370c45f Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Mon, 4 May 2015 17:17:55 +0100 Subject: [PATCH 42/91] [MINOR] Fix python test typo? I suspect haven't been using anaconda in tests in a while. I wonder if this change actually does anything but this line as it stands looks strictly less correct. Author: Andrew Or Closes #5883 from andrewor14/fix-run-tests-typo and squashes the following commits: a3ad720 [Andrew Or] Fix typo? --- dev/run-tests | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/run-tests b/dev/run-tests index 861d1671182c2..05c63bce4d40d 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -236,7 +236,7 @@ echo "=========================================================================" CURRENT_BLOCK=$BLOCK_PYSPARK_UNIT_TESTS # add path for python 3 in jenkins -export PATH="${PATH}:/home/anaonda/envs/py3k/bin" +export PATH="${PATH}:/home/anaconda/envs/py3k/bin" ./python/run-tests echo "" From e0833c5958bbd73ff27cfe6865648d7b6e5a99bc Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Mon, 4 May 2015 11:28:59 -0700 Subject: [PATCH 43/91] [SPARK-5956] [MLLIB] Pipeline components should be copyable. This PR added `copy(extra: ParamMap): Params` to `Params`, which makes a copy of the current instance with a randomly generated uid and some extra param values. With this change, we only need to implement `fit` and `transform` without extra param values given the default implementation of `fit(dataset, extra)`: ~~~scala def fit(dataset: DataFrame, extra: ParamMap): Model = { copy(extra).fit(dataset) } ~~~ Inside `fit` and `transform`, since only the embedded values are used, I added `$` as an alias for `getOrDefault` to make the code easier to read. For example, in `LinearRegression.fit` we have: ~~~scala val effectiveRegParam = $(regParam) / yStd val effectiveL1RegParam = $(elasticNetParam) * effectiveRegParam val effectiveL2RegParam = (1.0 - $(elasticNetParam)) * effectiveRegParam ~~~ Meta-algorithm like `Pipeline` implements its own `copy(extra)`. So the fitted pipeline model stored all copied stages (no matter whether it is a transformer or a model). Other changes: * `Params$.inheritValues` is moved to `Params!.copyValues` and returns the target instance. * `fittingParamMap` was removed because the `parent` carries this information. * `validate` was renamed to `validateParams` to be more precise. TODOs: * [x] add tests for newly added methods * [ ] update documentation jkbradley dbtsai Author: Xiangrui Meng Closes #5820 from mengxr/SPARK-5956 and squashes the following commits: 7bef88d [Xiangrui Meng] address comments 05229c3 [Xiangrui Meng] assert -> assertEquals b2927b1 [Xiangrui Meng] organize imports f14456b [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5956 93e7924 [Xiangrui Meng] add tests for hasParam & copy 463ecae [Xiangrui Meng] merge master 2b954c3 [Xiangrui Meng] update Binarizer 465dd12 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5956 282a1a8 [Xiangrui Meng] fix test 819dd2d [Xiangrui Meng] merge master b642872 [Xiangrui Meng] example code runs 5a67779 [Xiangrui Meng] examples compile c76b4d1 [Xiangrui Meng] fix all unit tests 0f4fd64 [Xiangrui Meng] fix some tests 9286a22 [Xiangrui Meng] copyValues to trained models 53e0973 [Xiangrui Meng] move inheritValues to Params and rename it to copyValues 9ee004e [Xiangrui Meng] merge copy and copyWith; rename validate to validateParams d882afc [Xiangrui Meng] test compile f082a31 [Xiangrui Meng] make Params copyable and simply handling of extra params in all spark.ml components --- .../examples/ml/JavaDeveloperApiExample.java | 24 ++-- .../examples/ml/JavaSimpleParamsExample.java | 4 +- .../examples/ml/DecisionTreeExample.scala | 6 +- .../examples/ml/DeveloperApiExample.scala | 22 ++-- .../apache/spark/examples/ml/GBTExample.scala | 4 +- .../examples/ml/RandomForestExample.scala | 6 +- .../examples/ml/SimpleParamsExample.scala | 4 +- .../scala/org/apache/spark/ml/Estimator.scala | 26 ++++- .../scala/org/apache/spark/ml/Evaluator.scala | 20 +++- .../scala/org/apache/spark/ml/Model.scala | 9 +- .../scala/org/apache/spark/ml/Pipeline.scala | 106 ++++++++---------- .../org/apache/spark/ml/Transformer.scala | 46 +++++--- .../spark/ml/classification/Classifier.scala | 49 +++----- .../DecisionTreeClassifier.scala | 29 ++--- .../ml/classification/GBTClassifier.scala | 33 +++--- .../classification/LogisticRegression.scala | 58 +++++----- .../ProbabilisticClassifier.scala | 33 ++---- .../RandomForestClassifier.scala | 31 ++--- .../BinaryClassificationEvaluator.scala | 17 ++- .../apache/spark/ml/feature/Binarizer.scala | 20 ++-- .../apache/spark/ml/feature/HashingTF.scala | 10 +- .../org/apache/spark/ml/feature/IDF.scala | 38 +++---- .../apache/spark/ml/feature/Normalizer.scala | 10 +- .../ml/feature/PolynomialExpansion.scala | 9 +- .../spark/ml/feature/StandardScaler.scala | 49 ++++---- .../spark/ml/feature/StringIndexer.scala | 34 +++--- .../apache/spark/ml/feature/Tokenizer.scala | 18 +-- .../spark/ml/feature/VectorAssembler.scala | 15 +-- .../spark/ml/feature/VectorIndexer.scala | 74 ++++++------ .../apache/spark/ml/feature/Word2Vec.scala | 62 +++++----- .../spark/ml/impl/estimator/Predictor.scala | 72 ++++-------- .../spark/ml/impl/tree/treeParams.scala | 35 +++--- .../org/apache/spark/ml/param/params.scala | 75 ++++++++----- .../ml/param/shared/SharedParamsCodeGen.scala | 5 +- .../spark/ml/param/shared/sharedParams.scala | 35 +++--- .../apache/spark/ml/recommendation/ALS.scala | 73 ++++++------ .../ml/regression/DecisionTreeRegressor.scala | 23 ++-- .../spark/ml/regression/GBTRegressor.scala | 30 ++--- .../ml/regression/LinearRegression.scala | 41 ++++--- .../ml/regression/RandomForestRegressor.scala | 27 ++--- .../spark/ml/regression/Regressor.scala | 2 +- .../spark/ml/tuning/CrossValidator.scala | 51 ++++----- .../JavaLogisticRegressionSuite.java | 14 ++- .../regression/JavaLinearRegressionSuite.java | 21 ++-- .../ml/tuning/JavaCrossValidatorSuite.java | 6 +- .../org/apache/spark/ml/PipelineSuite.scala | 26 +++-- .../DecisionTreeClassifierSuite.scala | 4 +- .../classification/GBTClassifierSuite.scala | 4 +- .../LogisticRegressionSuite.scala | 18 +-- .../RandomForestClassifierSuite.scala | 4 +- .../apache/spark/ml/param/ParamsSuite.scala | 13 ++- .../apache/spark/ml/param/TestParams.scala | 14 ++- .../DecisionTreeRegressorSuite.scala | 4 +- .../ml/regression/GBTRegressorSuite.scala | 3 +- .../RandomForestRegressorSuite.scala | 4 +- .../spark/ml/tuning/CrossValidatorSuite.scala | 6 +- 56 files changed, 671 insertions(+), 805 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java index 46377a99c4857..eac4f898a475d 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java @@ -28,7 +28,6 @@ import org.apache.spark.ml.classification.ClassificationModel; import org.apache.spark.ml.param.IntParam; import org.apache.spark.ml.param.ParamMap; -import org.apache.spark.ml.param.Params$; import org.apache.spark.mllib.linalg.BLAS; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; @@ -129,16 +128,16 @@ MyJavaLogisticRegression setMaxIter(int value) { // This method is used by fit(). // In Java, we have to make it public since Java does not understand Scala's protected modifier. - public MyJavaLogisticRegressionModel train(DataFrame dataset, ParamMap paramMap) { + public MyJavaLogisticRegressionModel train(DataFrame dataset) { // Extract columns from data using helper method. - JavaRDD oldDataset = extractLabeledPoints(dataset, paramMap).toJavaRDD(); + JavaRDD oldDataset = extractLabeledPoints(dataset).toJavaRDD(); // Do learning to estimate the weight vector. int numFeatures = oldDataset.take(1).get(0).features().size(); Vector weights = Vectors.zeros(numFeatures); // Learning would happen here. // Create a model, and return it. - return new MyJavaLogisticRegressionModel(this, paramMap, weights); + return new MyJavaLogisticRegressionModel(this, weights); } } @@ -155,18 +154,11 @@ class MyJavaLogisticRegressionModel private MyJavaLogisticRegression parent_; public MyJavaLogisticRegression parent() { return parent_; } - private ParamMap fittingParamMap_; - public ParamMap fittingParamMap() { return fittingParamMap_; } - private Vector weights_; public Vector weights() { return weights_; } - public MyJavaLogisticRegressionModel( - MyJavaLogisticRegression parent_, - ParamMap fittingParamMap_, - Vector weights_) { + public MyJavaLogisticRegressionModel(MyJavaLogisticRegression parent_, Vector weights_) { this.parent_ = parent_; - this.fittingParamMap_ = fittingParamMap_; this.weights_ = weights_; } @@ -210,10 +202,8 @@ public Vector predictRaw(Vector features) { * In Java, we have to make this method public since Java does not understand Scala's protected * modifier. */ - public MyJavaLogisticRegressionModel copy() { - MyJavaLogisticRegressionModel m = - new MyJavaLogisticRegressionModel(parent_, fittingParamMap_, weights_); - Params$.MODULE$.inheritValues(this.extractParamMap(), this, m); - return m; + @Override + public MyJavaLogisticRegressionModel copy(ParamMap extra) { + return copyValues(new MyJavaLogisticRegressionModel(parent_, weights_), extra); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java index 4e02acce696e6..29158d5c85651 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java @@ -71,7 +71,7 @@ public static void main(String[] args) { // we can view the parameters it used during fit(). // This prints the parameter (name: value) pairs, where names are unique IDs for this // LogisticRegression instance. - System.out.println("Model 1 was fit using parameters: " + model1.fittingParamMap()); + System.out.println("Model 1 was fit using parameters: " + model1.parent().extractParamMap()); // We may alternatively specify parameters using a ParamMap. ParamMap paramMap = new ParamMap(); @@ -87,7 +87,7 @@ public static void main(String[] args) { // Now learn a new model using the paramMapCombined parameters. // paramMapCombined overrides all parameters set earlier via lr.set* methods. LogisticRegressionModel model2 = lr.fit(training, paramMapCombined); - System.out.println("Model 2 was fit using parameters: " + model2.fittingParamMap()); + System.out.println("Model 2 was fit using parameters: " + model2.parent().extractParamMap()); // Prepare test documents. List localTest = Lists.newArrayList( diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala index 9002e99d82ad3..8340d91101ab3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala @@ -276,16 +276,14 @@ object DecisionTreeExample { // Get the trained Decision Tree from the fitted PipelineModel algo match { case "classification" => - val treeModel = pipelineModel.getModel[DecisionTreeClassificationModel]( - dt.asInstanceOf[DecisionTreeClassifier]) + val treeModel = pipelineModel.stages.last.asInstanceOf[DecisionTreeClassificationModel] if (treeModel.numNodes < 20) { println(treeModel.toDebugString) // Print full model. } else { println(treeModel) // Print model summary. } case "regression" => - val treeModel = pipelineModel.getModel[DecisionTreeRegressionModel]( - dt.asInstanceOf[DecisionTreeRegressor]) + val treeModel = pipelineModel.stages.last.asInstanceOf[DecisionTreeRegressionModel] if (treeModel.numNodes < 20) { println(treeModel.toDebugString) // Print full model. } else { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala index 2245fa429fda3..2a2d0677272a0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -18,13 +18,12 @@ package org.apache.spark.examples.ml import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.ml.classification.{Classifier, ClassifierParams, ClassificationModel} -import org.apache.spark.ml.param.{Params, IntParam, ParamMap} +import org.apache.spark.ml.classification.{ClassificationModel, Classifier, ClassifierParams} +import org.apache.spark.ml.param.{IntParam, ParamMap} import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.sql.{DataFrame, Row, SQLContext} - /** * A simple example demonstrating how to write your own learning algorithm using Estimator, * Transformer, and other abstractions. @@ -99,7 +98,7 @@ private trait MyLogisticRegressionParams extends ClassifierParams { * class since the maxIter parameter is only used during training (not in the Model). */ val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations") - def getMaxIter: Int = getOrDefault(maxIter) + def getMaxIter: Int = $(maxIter) } /** @@ -117,18 +116,16 @@ private class MyLogisticRegression def setMaxIter(value: Int): this.type = set(maxIter, value) // This method is used by fit() - override protected def train( - dataset: DataFrame, - paramMap: ParamMap): MyLogisticRegressionModel = { + override protected def train(dataset: DataFrame): MyLogisticRegressionModel = { // Extract columns from data using helper method. - val oldDataset = extractLabeledPoints(dataset, paramMap) + val oldDataset = extractLabeledPoints(dataset) // Do learning to estimate the weight vector. val numFeatures = oldDataset.take(1)(0).features.size val weights = Vectors.zeros(numFeatures) // Learning would happen here. // Create a model, and return it. - new MyLogisticRegressionModel(this, paramMap, weights) + new MyLogisticRegressionModel(this, weights) } } @@ -139,7 +136,6 @@ private class MyLogisticRegression */ private class MyLogisticRegressionModel( override val parent: MyLogisticRegression, - override val fittingParamMap: ParamMap, val weights: Vector) extends ClassificationModel[Vector, MyLogisticRegressionModel] with MyLogisticRegressionParams { @@ -176,9 +172,7 @@ private class MyLogisticRegressionModel( * * This is used for the default implementation of [[transform()]]. */ - override protected def copy(): MyLogisticRegressionModel = { - val m = new MyLogisticRegressionModel(parent, fittingParamMap, weights) - Params.inheritValues(extractParamMap(), this, m) - m + override def copy(extra: ParamMap): MyLogisticRegressionModel = { + copyValues(new MyLogisticRegressionModel(parent, weights), extra) } } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala index 5fccb142d4c3d..c5899b6683c79 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala @@ -201,14 +201,14 @@ object GBTExample { // Get the trained GBT from the fitted PipelineModel algo match { case "classification" => - val rfModel = pipelineModel.getModel[GBTClassificationModel](dt.asInstanceOf[GBTClassifier]) + val rfModel = pipelineModel.stages.last.asInstanceOf[GBTClassificationModel] if (rfModel.totalNumNodes < 30) { println(rfModel.toDebugString) // Print full model. } else { println(rfModel) // Print model summary. } case "regression" => - val rfModel = pipelineModel.getModel[GBTRegressionModel](dt.asInstanceOf[GBTRegressor]) + val rfModel = pipelineModel.stages.last.asInstanceOf[GBTRegressionModel] if (rfModel.totalNumNodes < 30) { println(rfModel.toDebugString) // Print full model. } else { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala index 9b909324ec82a..7f88d2681bcaa 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala @@ -209,16 +209,14 @@ object RandomForestExample { // Get the trained Random Forest from the fitted PipelineModel algo match { case "classification" => - val rfModel = pipelineModel.getModel[RandomForestClassificationModel]( - dt.asInstanceOf[RandomForestClassifier]) + val rfModel = pipelineModel.stages.last.asInstanceOf[RandomForestClassificationModel] if (rfModel.totalNumNodes < 30) { println(rfModel.toDebugString) // Print full model. } else { println(rfModel) // Print model summary. } case "regression" => - val rfModel = pipelineModel.getModel[RandomForestRegressionModel]( - dt.asInstanceOf[RandomForestRegressor]) + val rfModel = pipelineModel.stages.last.asInstanceOf[RandomForestRegressionModel] if (rfModel.totalNumNodes < 30) { println(rfModel.toDebugString) // Print full model. } else { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala index bf805149d0af6..e8a991f50e338 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala @@ -63,7 +63,7 @@ object SimpleParamsExample { // we can view the parameters it used during fit(). // This prints the parameter (name: value) pairs, where names are unique IDs for this // LogisticRegression instance. - println("Model 1 was fit using parameters: " + model1.fittingParamMap) + println("Model 1 was fit using parameters: " + model1.parent.extractParamMap()) // We may alternatively specify parameters using a ParamMap, // which supports several methods for specifying parameters. @@ -78,7 +78,7 @@ object SimpleParamsExample { // Now learn a new model using the paramMapCombined parameters. // paramMapCombined overrides all parameters set earlier via lr.set* methods. val model2 = lr.fit(training.toDF(), paramMapCombined) - println("Model 2 was fit using parameters: " + model2.fittingParamMap) + println("Model 2 was fit using parameters: " + model2.parent.extractParamMap()) // Prepare test data. val test = sc.parallelize(Seq( diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala index d6b3503ebdd9a..7f3f3262a644f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala @@ -34,13 +34,16 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { * Fits a single model to the input data with optional parameters. * * @param dataset input dataset - * @param paramPairs Optional list of param pairs. - * These values override any specified in this Estimator's embedded ParamMap. + * @param firstParamPair the first param pair, overrides embedded params + * @param otherParamPairs other param pairs. These values override any specified in this + * Estimator's embedded ParamMap. * @return fitted model */ @varargs - def fit(dataset: DataFrame, paramPairs: ParamPair[_]*): M = { - val map = ParamMap(paramPairs: _*) + def fit(dataset: DataFrame, firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): M = { + val map = new ParamMap() + .put(firstParamPair) + .put(otherParamPairs: _*) fit(dataset, map) } @@ -52,12 +55,19 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { * These values override any specified in this Estimator's embedded ParamMap. * @return fitted model */ - def fit(dataset: DataFrame, paramMap: ParamMap): M + def fit(dataset: DataFrame, paramMap: ParamMap): M = { + copy(paramMap).fit(dataset) + } + + /** + * Fits a model to the input data. + */ + def fit(dataset: DataFrame): M /** * Fits multiple models to the input data with multiple sets of parameters. * The default implementation uses a for loop on each parameter map. - * Subclasses could overwrite this to optimize multi-model training. + * Subclasses could override this to optimize multi-model training. * * @param dataset input dataset * @param paramMaps An array of parameter maps. @@ -67,4 +77,8 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { def fit(dataset: DataFrame, paramMaps: Array[ParamMap]): Seq[M] = { paramMaps.map(fit(dataset, _)) } + + override def copy(extra: ParamMap): Estimator[M] = { + super.copy(extra).asInstanceOf[Estimator[M]] + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala index 8b4b5fd8af986..5f2f8c94e9ff7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala @@ -18,8 +18,7 @@ package org.apache.spark.ml import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.param.{ParamMap, Params} import org.apache.spark.sql.DataFrame /** @@ -27,7 +26,7 @@ import org.apache.spark.sql.DataFrame * Abstract class for evaluators that compute metrics from predictions. */ @AlphaComponent -abstract class Evaluator extends Identifiable { +abstract class Evaluator extends Params { /** * Evaluates the output. @@ -36,5 +35,18 @@ abstract class Evaluator extends Identifiable { * @param paramMap parameter map that specifies the input columns and output metrics * @return metric */ - def evaluate(dataset: DataFrame, paramMap: ParamMap): Double + def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = { + this.copy(paramMap).evaluate(dataset) + } + + /** + * Evaluates the output. + * @param dataset a dataset that contains labels/observations and predictions. + * @return metric + */ + def evaluate(dataset: DataFrame): Double + + override def copy(extra: ParamMap): Evaluator = { + super.copy(extra).asInstanceOf[Evaluator] + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Model.scala b/mllib/src/main/scala/org/apache/spark/ml/Model.scala index a491bc7ee8295..9974efe7b1d25 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Model.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Model.scala @@ -34,9 +34,8 @@ abstract class Model[M <: Model[M]] extends Transformer { */ val parent: Estimator[M] - /** - * Fitting parameters, such that parent.fit(..., fittingParamMap) could reproduce the model. - * Note: For ensembles' component Models, this value can be null. - */ - val fittingParamMap: ParamMap + override def copy(extra: ParamMap): M = { + // The default implementation of Params.copy doesn't work for models. + throw new NotImplementedError(s"${this.getClass} doesn't implement copy(extra: ParamMap)") + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 6bfeecd764d75..33d430f5671ee 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -21,7 +21,7 @@ import scala.collection.mutable.ListBuffer import org.apache.spark.Logging import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} -import org.apache.spark.ml.param.{Params, Param, ParamMap} +import org.apache.spark.ml.param.{Param, ParamMap, Params} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType @@ -30,40 +30,41 @@ import org.apache.spark.sql.types.StructType * A stage in a pipeline, either an [[Estimator]] or a [[Transformer]]. */ @AlphaComponent -abstract class PipelineStage extends Serializable with Logging { +abstract class PipelineStage extends Params with Logging { /** * :: DeveloperApi :: * - * Derives the output schema from the input schema and parameters. - * The schema describes the columns and types of the data. - * - * @param schema Input schema to this stage - * @param paramMap Parameters passed to this stage - * @return Output schema from this stage + * Derives the output schema from the input schema. */ @DeveloperApi - def transformSchema(schema: StructType, paramMap: ParamMap): StructType + def transformSchema(schema: StructType): StructType /** + * :: DeveloperApi :: + * * Derives the output schema from the input schema and parameters, optionally with logging. * * This should be optimistic. If it is unclear whether the schema will be valid, then it should * be assumed valid until proven otherwise. */ + @DeveloperApi protected def transformSchema( schema: StructType, - paramMap: ParamMap, logging: Boolean): StructType = { if (logging) { logDebug(s"Input schema: ${schema.json}") } - val outputSchema = transformSchema(schema, paramMap) + val outputSchema = transformSchema(schema) if (logging) { logDebug(s"Expected output schema: ${outputSchema.json}") } outputSchema } + + override def copy(extra: ParamMap): PipelineStage = { + super.copy(extra).asInstanceOf[PipelineStage] + } } /** @@ -81,15 +82,22 @@ abstract class PipelineStage extends Serializable with Logging { @AlphaComponent class Pipeline extends Estimator[PipelineModel] { - /** param for pipeline stages */ + /** + * param for pipeline stages + * @group param + */ val stages: Param[Array[PipelineStage]] = new Param(this, "stages", "stages of the pipeline") + + /** @group setParam */ def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this } - def getStages: Array[PipelineStage] = getOrDefault(stages) - override def validate(paramMap: ParamMap): Unit = { + /** @group getParam */ + def getStages: Array[PipelineStage] = $(stages).clone() + + override def validateParams(paramMap: ParamMap): Unit = { val map = extractParamMap(paramMap) getStages.foreach { - case pStage: Params => pStage.validate(map) + case pStage: Params => pStage.validateParams(map) case _ => } } @@ -104,13 +112,11 @@ class Pipeline extends Estimator[PipelineModel] { * pipeline stages. If there are no stages, the output model acts as an identity transformer. * * @param dataset input dataset - * @param paramMap parameter map * @return fitted pipeline */ - override def fit(dataset: DataFrame, paramMap: ParamMap): PipelineModel = { - transformSchema(dataset.schema, paramMap, logging = true) - val map = extractParamMap(paramMap) - val theStages = map(stages) + override def fit(dataset: DataFrame): PipelineModel = { + transformSchema(dataset.schema, logging = true) + val theStages = $(stages) // Search for the last estimator. var indexOfLastEstimator = -1 theStages.view.zipWithIndex.foreach { case (stage, index) => @@ -126,7 +132,7 @@ class Pipeline extends Estimator[PipelineModel] { if (index <= indexOfLastEstimator) { val transformer = stage match { case estimator: Estimator[_] => - estimator.fit(curDataset, paramMap) + estimator.fit(curDataset) case t: Transformer => t case _ => @@ -134,7 +140,7 @@ class Pipeline extends Estimator[PipelineModel] { s"Do not support stage $stage of type ${stage.getClass}") } if (index < indexOfLastEstimator) { - curDataset = transformer.transform(curDataset, paramMap) + curDataset = transformer.transform(curDataset) } transformers += transformer } else { @@ -142,15 +148,20 @@ class Pipeline extends Estimator[PipelineModel] { } } - new PipelineModel(this, map, transformers.toArray) + new PipelineModel(this, transformers.toArray) } - override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = extractParamMap(paramMap) - val theStages = map(stages) + override def copy(extra: ParamMap): Pipeline = { + val map = extractParamMap(extra) + val newStages = map(stages).map(_.copy(extra)) + new Pipeline().setStages(newStages) + } + + override def transformSchema(schema: StructType): StructType = { + val theStages = $(stages) require(theStages.toSet.size == theStages.length, "Cannot have duplicate components in a pipeline.") - theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur, paramMap)) + theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur)) } } @@ -161,43 +172,24 @@ class Pipeline extends Estimator[PipelineModel] { @AlphaComponent class PipelineModel private[ml] ( override val parent: Pipeline, - override val fittingParamMap: ParamMap, - private[ml] val stages: Array[Transformer]) + val stages: Array[Transformer]) extends Model[PipelineModel] with Logging { - override def validate(paramMap: ParamMap): Unit = { - val map = fittingParamMap ++ extractParamMap(paramMap) - stages.foreach(_.validate(map)) + override def validateParams(): Unit = { + super.validateParams() + stages.foreach(_.validateParams()) } - /** - * Gets the model produced by the input estimator. Throws an NoSuchElementException is the input - * estimator does not exist in the pipeline. - */ - def getModel[M <: Model[M]](stage: Estimator[M]): M = { - val matched = stages.filter { - case m: Model[_] => m.parent.eq(stage) - case _ => false - } - if (matched.isEmpty) { - throw new NoSuchElementException(s"Cannot find stage $stage from the pipeline.") - } else if (matched.length > 1) { - throw new IllegalStateException(s"Cannot have duplicate estimators in the sample pipeline.") - } else { - matched.head.asInstanceOf[M] - } + override def transform(dataset: DataFrame): DataFrame = { + transformSchema(dataset.schema, logging = true) + stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur)) } - override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { - // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap - val map = fittingParamMap ++ extractParamMap(paramMap) - transformSchema(dataset.schema, map, logging = true) - stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, map)) + override def transformSchema(schema: StructType): StructType = { + stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur)) } - override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap - val map = fittingParamMap ++ extractParamMap(paramMap) - stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, map)) + override def copy(extra: ParamMap): PipelineModel = { + new PipelineModel(parent, stages) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala index 0acda71ec6045..d96b54e511e9c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -37,13 +37,18 @@ abstract class Transformer extends PipelineStage with Params { /** * Transforms the dataset with optional parameters * @param dataset input dataset - * @param paramPairs optional list of param pairs, overwrite embedded params + * @param firstParamPair the first param pair, overwrite embedded params + * @param otherParamPairs other param pairs, overwrite embedded params * @return transformed dataset */ @varargs - def transform(dataset: DataFrame, paramPairs: ParamPair[_]*): DataFrame = { + def transform( + dataset: DataFrame, + firstParamPair: ParamPair[_], + otherParamPairs: ParamPair[_]*): DataFrame = { val map = new ParamMap() - paramPairs.foreach(map.put(_)) + .put(firstParamPair) + .put(otherParamPairs: _*) transform(dataset, map) } @@ -53,7 +58,18 @@ abstract class Transformer extends PipelineStage with Params { * @param paramMap additional parameters, overwrite embedded params * @return transformed dataset */ - def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame + def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + this.copy(paramMap).transform(dataset) + } + + /** + * Transforms the input dataset. + */ + def transform(dataset: DataFrame): DataFrame + + override def copy(extra: ParamMap): Transformer = { + super.copy(extra).asInstanceOf[Transformer] + } } /** @@ -74,7 +90,7 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O * account of the embedded param map. So the param values should be determined solely by the input * param map. */ - protected def createTransformFunc(paramMap: ParamMap): IN => OUT + protected def createTransformFunc: IN => OUT /** * Returns the data type of the output column. @@ -86,22 +102,20 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O */ protected def validateInputType(inputType: DataType): Unit = {} - override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = extractParamMap(paramMap) - val inputType = schema(map(inputCol)).dataType + override def transformSchema(schema: StructType): StructType = { + val inputType = schema($(inputCol)).dataType validateInputType(inputType) - if (schema.fieldNames.contains(map(outputCol))) { - throw new IllegalArgumentException(s"Output column ${map(outputCol)} already exists.") + if (schema.fieldNames.contains($(outputCol))) { + throw new IllegalArgumentException(s"Output column ${$(outputCol)} already exists.") } val outputFields = schema.fields :+ - StructField(map(outputCol), outputDataType, nullable = false) + StructField($(outputCol), outputDataType, nullable = false) StructType(outputFields) } - override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { - transformSchema(dataset.schema, paramMap, logging = true) - val map = extractParamMap(paramMap) - dataset.withColumn(map(outputCol), - callUDF(this.createTransformFunc(map), outputDataType, dataset(map(inputCol)))) + override def transform(dataset: DataFrame): DataFrame = { + transformSchema(dataset.schema, logging = true) + dataset.withColumn($(outputCol), + callUDF(this.createTransformFunc, outputDataType, dataset($(inputCol)))) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index 29339c98f51cf..d3361e24705c8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -19,7 +19,6 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor, PredictorParams} -import org.apache.spark.ml.param.{ParamMap, Params} import org.apache.spark.ml.param.shared.HasRawPredictionCol import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.linalg.{Vector, VectorUDT} @@ -27,7 +26,6 @@ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, DoubleType, StructType} - /** * :: DeveloperApi :: * Params for classification. @@ -40,12 +38,10 @@ private[spark] trait ClassifierParams extends PredictorParams override protected def validateAndTransformSchema( schema: StructType, - paramMap: ParamMap, fitting: Boolean, featuresDataType: DataType): StructType = { - val parentSchema = super.validateAndTransformSchema(schema, paramMap, fitting, featuresDataType) - val map = extractParamMap(paramMap) - SchemaUtils.appendColumn(parentSchema, map(rawPredictionCol), new VectorUDT) + val parentSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType) + SchemaUtils.appendColumn(parentSchema, $(rawPredictionCol), new VectorUDT) } } @@ -102,27 +98,16 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur * - raw predictions (confidences) as [[rawPredictionCol]] of type [[Vector]]. * * @param dataset input dataset - * @param paramMap additional parameters, overwrite embedded params * @return transformed dataset */ - override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + override def transform(dataset: DataFrame): DataFrame = { // This default implementation should be overridden as needed. // Check schema - transformSchema(dataset.schema, paramMap, logging = true) - val map = extractParamMap(paramMap) - - // Prepare model - val tmpModel = if (paramMap.size != 0) { - val tmpModel = this.copy() - Params.inheritValues(paramMap, parent, tmpModel) - tmpModel - } else { - this - } + transformSchema(dataset.schema, logging = true) val (numColsOutput, outputData) = - ClassificationModel.transformColumnsImpl[FeaturesType](dataset, tmpModel, map) + ClassificationModel.transformColumnsImpl[FeaturesType](dataset, this) if (numColsOutput == 0) { logWarning(s"$uid: ClassificationModel.transform() was called as NOOP" + " since no output columns were set.") @@ -158,7 +143,6 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur */ @DeveloperApi protected def predictRaw(features: FeaturesType): Vector - } private[ml] object ClassificationModel { @@ -167,38 +151,35 @@ private[ml] object ClassificationModel { * Added prediction column(s). This is separated from [[ClassificationModel.transform()]] * since it is used by [[org.apache.spark.ml.classification.ProbabilisticClassificationModel]]. * @param dataset Input dataset - * @param map Parameter map. This will NOT be merged with the embedded paramMap; the merge - * should already be done. * @return (number of columns added, transformed dataset) */ def transformColumnsImpl[FeaturesType]( dataset: DataFrame, - model: ClassificationModel[FeaturesType, _], - map: ParamMap): (Int, DataFrame) = { + model: ClassificationModel[FeaturesType, _]): (Int, DataFrame) = { // Output selected columns only. // This is a bit complicated since it tries to avoid repeated computation. var tmpData = dataset var numColsOutput = 0 - if (map(model.rawPredictionCol) != "") { + if (model.getRawPredictionCol != "") { // output raw prediction val features2raw: FeaturesType => Vector = model.predictRaw - tmpData = tmpData.withColumn(map(model.rawPredictionCol), - callUDF(features2raw, new VectorUDT, col(map(model.featuresCol)))) + tmpData = tmpData.withColumn(model.getRawPredictionCol, + callUDF(features2raw, new VectorUDT, col(model.getFeaturesCol))) numColsOutput += 1 - if (map(model.predictionCol) != "") { + if (model.getPredictionCol != "") { val raw2pred: Vector => Double = (rawPred) => { rawPred.toArray.zipWithIndex.maxBy(_._1)._2 } - tmpData = tmpData.withColumn(map(model.predictionCol), - callUDF(raw2pred, DoubleType, col(map(model.rawPredictionCol)))) + tmpData = tmpData.withColumn(model.getPredictionCol, + callUDF(raw2pred, DoubleType, col(model.getRawPredictionCol))) numColsOutput += 1 } - } else if (map(model.predictionCol) != "") { + } else if (model.getPredictionCol != "") { // output prediction val features2pred: FeaturesType => Double = model.predict - tmpData = tmpData.withColumn(map(model.predictionCol), - callUDF(features2pred, DoubleType, col(map(model.featuresCol)))) + tmpData = tmpData.withColumn(model.getPredictionCol, + callUDF(features2pred, DoubleType, col(model.getFeaturesCol))) numColsOutput += 1 } (numColsOutput, tmpData) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index ee2a8dc6db171..419e5ba05d38a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -18,9 +18,9 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.impl.estimator.{Predictor, PredictionModel} +import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor} import org.apache.spark.ml.impl.tree._ -import org.apache.spark.ml.param.{Params, ParamMap} +import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree.{DecisionTreeModel, Node} import org.apache.spark.ml.util.MetadataUtils import org.apache.spark.mllib.linalg.Vector @@ -31,7 +31,6 @@ import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeMo import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame - /** * :: AlphaComponent :: * @@ -64,22 +63,20 @@ final class DecisionTreeClassifier override def setImpurity(value: String): this.type = super.setImpurity(value) - override protected def train( - dataset: DataFrame, - paramMap: ParamMap): DecisionTreeClassificationModel = { + override protected def train(dataset: DataFrame): DecisionTreeClassificationModel = { val categoricalFeatures: Map[Int, Int] = - MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol))) - val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema(paramMap(labelCol))) match { + MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) + val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match { case Some(n: Int) => n case None => throw new IllegalArgumentException("DecisionTreeClassifier was given input" + - s" with invalid label column ${paramMap(labelCol)}, without the number of classes" + + s" with invalid label column ${$(labelCol)}, without the number of classes" + " specified. See StringIndexer.") // TODO: Automatically index labels: SPARK-7126 } - val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap) + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val strategy = getOldStrategy(categoricalFeatures, numClasses) val oldModel = OldDecisionTree.train(oldDataset, strategy) - DecisionTreeClassificationModel.fromOld(oldModel, this, paramMap, categoricalFeatures) + DecisionTreeClassificationModel.fromOld(oldModel, this, categoricalFeatures) } /** (private[ml]) Create a Strategy instance to use with the old API. */ @@ -106,7 +103,6 @@ object DecisionTreeClassifier { @AlphaComponent final class DecisionTreeClassificationModel private[ml] ( override val parent: DecisionTreeClassifier, - override val fittingParamMap: ParamMap, override val rootNode: Node) extends PredictionModel[Vector, DecisionTreeClassificationModel] with DecisionTreeModel with Serializable { @@ -118,10 +114,8 @@ final class DecisionTreeClassificationModel private[ml] ( rootNode.predict(features) } - override protected def copy(): DecisionTreeClassificationModel = { - val m = new DecisionTreeClassificationModel(parent, fittingParamMap, rootNode) - Params.inheritValues(this.extractParamMap(), this, m) - m + override def copy(extra: ParamMap): DecisionTreeClassificationModel = { + copyValues(new DecisionTreeClassificationModel(parent, rootNode), extra) } override def toString: String = { @@ -140,12 +134,11 @@ private[ml] object DecisionTreeClassificationModel { def fromOld( oldModel: OldDecisionTreeModel, parent: DecisionTreeClassifier, - fittingParamMap: ParamMap, categoricalFeatures: Map[Int, Int]): DecisionTreeClassificationModel = { require(oldModel.algo == OldAlgo.Classification, s"Cannot convert non-classification DecisionTreeModel (old API) to" + s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}") val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures) - new DecisionTreeClassificationModel(parent, fittingParamMap, rootNode) + new DecisionTreeClassificationModel(parent, rootNode) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 3d849867d4c47..534ea95b1c538 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -23,7 +23,7 @@ import org.apache.spark.Logging import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor} import org.apache.spark.ml.impl.tree._ -import org.apache.spark.ml.param.{Param, Params, ParamMap} +import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel} import org.apache.spark.ml.util.MetadataUtils @@ -31,12 +31,11 @@ import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{GradientBoostedTrees => OldGBT} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} -import org.apache.spark.mllib.tree.loss.{Loss => OldLoss, LogLoss => OldLogLoss} +import org.apache.spark.mllib.tree.loss.{LogLoss => OldLogLoss, Loss => OldLoss} import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame - /** * :: AlphaComponent :: * @@ -112,7 +111,7 @@ final class GBTClassifier def setLossType(value: String): this.type = set(lossType, value) /** @group getParam */ - def getLossType: String = getOrDefault(lossType).toLowerCase + def getLossType: String = $(lossType).toLowerCase /** (private[ml]) Convert new loss to old loss. */ override private[ml] def getOldLossType: OldLoss = { @@ -124,25 +123,23 @@ final class GBTClassifier } } - override protected def train( - dataset: DataFrame, - paramMap: ParamMap): GBTClassificationModel = { + override protected def train(dataset: DataFrame): GBTClassificationModel = { val categoricalFeatures: Map[Int, Int] = - MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol))) - val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema(paramMap(labelCol))) match { + MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) + val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match { case Some(n: Int) => n case None => throw new IllegalArgumentException("GBTClassifier was given input" + - s" with invalid label column ${paramMap(labelCol)}, without the number of classes" + + s" with invalid label column ${$(labelCol)}, without the number of classes" + " specified. See StringIndexer.") // TODO: Automatically index labels: SPARK-7126 } require(numClasses == 2, s"GBTClassifier only supports binary classification but was given numClasses = $numClasses") - val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap) + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification) val oldGBT = new OldGBT(boostingStrategy) val oldModel = oldGBT.run(oldDataset) - GBTClassificationModel.fromOld(oldModel, this, paramMap, categoricalFeatures) + GBTClassificationModel.fromOld(oldModel, this, categoricalFeatures) } } @@ -165,7 +162,6 @@ object GBTClassifier { @AlphaComponent final class GBTClassificationModel( override val parent: GBTClassifier, - override val fittingParamMap: ParamMap, private val _trees: Array[DecisionTreeRegressionModel], private val _treeWeights: Array[Double]) extends PredictionModel[Vector, GBTClassificationModel] @@ -188,10 +184,8 @@ final class GBTClassificationModel( if (prediction > 0.0) 1.0 else 0.0 } - override protected def copy(): GBTClassificationModel = { - val m = new GBTClassificationModel(parent, fittingParamMap, _trees, _treeWeights) - Params.inheritValues(this.extractParamMap(), this, m) - m + override def copy(extra: ParamMap): GBTClassificationModel = { + copyValues(new GBTClassificationModel(parent, _trees, _treeWeights), extra) } override def toString: String = { @@ -210,14 +204,13 @@ private[ml] object GBTClassificationModel { def fromOld( oldModel: OldGBTModel, parent: GBTClassifier, - fittingParamMap: ParamMap, categoricalFeatures: Map[Int, Int]): GBTClassificationModel = { require(oldModel.algo == OldAlgo.Classification, "Cannot convert GradientBoostedTreesModel" + s" with algo=${oldModel.algo} (old API) to GBTClassificationModel (new API).") val newTrees = oldModel.trees.map { tree => // parent, fittingParamMap for each tree is null since there are no good ways to set these. - DecisionTreeRegressionModel.fromOld(tree, null, null, categoricalFeatures) + DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) } - new GBTClassificationModel(parent, fittingParamMap, newTrees, oldModel.treeWeights) + new GBTClassificationModel(parent, newTrees, oldModel.treeWeights) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index cc8b0721cf2b6..b73be035e29b5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -21,12 +21,11 @@ import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS -import org.apache.spark.mllib.linalg.{VectorUDT, BLAS, Vector, Vectors} +import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT, Vectors} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.storage.StorageLevel - /** * Params for logistic regression. */ @@ -59,9 +58,9 @@ class LogisticRegression /** @group setParam */ def setThreshold(value: Double): this.type = set(threshold, value) - override protected def train(dataset: DataFrame, paramMap: ParamMap): LogisticRegressionModel = { + override protected def train(dataset: DataFrame): LogisticRegressionModel = { // Extract columns from data. If dataset is persisted, do not persist oldDataset. - val oldDataset = extractLabeledPoints(dataset, paramMap) + val oldDataset = extractLabeledPoints(dataset) val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE if (handlePersistence) { oldDataset.persist(StorageLevel.MEMORY_AND_DISK) @@ -69,17 +68,17 @@ class LogisticRegression // Train model val lr = new LogisticRegressionWithLBFGS() - .setIntercept(paramMap(fitIntercept)) + .setIntercept($(fitIntercept)) lr.optimizer - .setRegParam(paramMap(regParam)) - .setNumIterations(paramMap(maxIter)) + .setRegParam($(regParam)) + .setNumIterations($(maxIter)) val oldModel = lr.run(oldDataset) - val lrm = new LogisticRegressionModel(this, paramMap, oldModel.weights, oldModel.intercept) + val lrm = new LogisticRegressionModel(this, oldModel.weights, oldModel.intercept) if (handlePersistence) { oldDataset.unpersist() } - lrm + copyValues(lrm) } } @@ -92,7 +91,6 @@ class LogisticRegression @AlphaComponent class LogisticRegressionModel private[ml] ( override val parent: LogisticRegression, - override val fittingParamMap: ParamMap, val weights: Vector, val intercept: Double) extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel] @@ -110,16 +108,14 @@ class LogisticRegressionModel private[ml] ( 1.0 / (1.0 + math.exp(-m)) } - override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + override def transform(dataset: DataFrame): DataFrame = { // This is overridden (a) to be more efficient (avoiding re-computing values when creating // multiple output columns) and (b) to handle threshold, which the abstractions do not use. // TODO: We should abstract away the steps defined by UDFs below so that the abstractions // can call whichever UDFs are needed to create the output columns. // Check schema - transformSchema(dataset.schema, paramMap, logging = true) - - val map = extractParamMap(paramMap) + transformSchema(dataset.schema, logging = true) // Output selected columns only. // This is a bit complicated since it tries to avoid repeated computation. @@ -128,41 +124,41 @@ class LogisticRegressionModel private[ml] ( // prediction (max margin) var tmpData = dataset var numColsOutput = 0 - if (map(rawPredictionCol) != "") { + if ($(rawPredictionCol) != "") { val features2raw: Vector => Vector = (features) => predictRaw(features) - tmpData = tmpData.withColumn(map(rawPredictionCol), - callUDF(features2raw, new VectorUDT, col(map(featuresCol)))) + tmpData = tmpData.withColumn($(rawPredictionCol), + callUDF(features2raw, new VectorUDT, col($(featuresCol)))) numColsOutput += 1 } - if (map(probabilityCol) != "") { - if (map(rawPredictionCol) != "") { + if ($(probabilityCol) != "") { + if ($(rawPredictionCol) != "") { val raw2prob = udf { (rawPreds: Vector) => val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1))) Vectors.dense(1.0 - prob1, prob1): Vector } - tmpData = tmpData.withColumn(map(probabilityCol), raw2prob(col(map(rawPredictionCol)))) + tmpData = tmpData.withColumn($(probabilityCol), raw2prob(col($(rawPredictionCol)))) } else { val features2prob = udf { (features: Vector) => predictProbabilities(features) : Vector } - tmpData = tmpData.withColumn(map(probabilityCol), features2prob(col(map(featuresCol)))) + tmpData = tmpData.withColumn($(probabilityCol), features2prob(col($(featuresCol)))) } numColsOutput += 1 } - if (map(predictionCol) != "") { - val t = map(threshold) - if (map(probabilityCol) != "") { + if ($(predictionCol) != "") { + val t = $(threshold) + if ($(probabilityCol) != "") { val predict = udf { probs: Vector => if (probs(1) > t) 1.0 else 0.0 } - tmpData = tmpData.withColumn(map(predictionCol), predict(col(map(probabilityCol)))) - } else if (map(rawPredictionCol) != "") { + tmpData = tmpData.withColumn($(predictionCol), predict(col($(probabilityCol)))) + } else if ($(rawPredictionCol) != "") { val predict = udf { rawPreds: Vector => val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1))) if (prob1 > t) 1.0 else 0.0 } - tmpData = tmpData.withColumn(map(predictionCol), predict(col(map(rawPredictionCol)))) + tmpData = tmpData.withColumn($(predictionCol), predict(col($(rawPredictionCol)))) } else { val predict = udf { features: Vector => this.predict(features) } - tmpData = tmpData.withColumn(map(predictionCol), predict(col(map(featuresCol)))) + tmpData = tmpData.withColumn($(predictionCol), predict(col($(featuresCol)))) } numColsOutput += 1 } @@ -193,9 +189,7 @@ class LogisticRegressionModel private[ml] ( Vectors.dense(0.0, m) } - override protected def copy(): LogisticRegressionModel = { - val m = new LogisticRegressionModel(parent, fittingParamMap, weights, intercept) - Params.inheritValues(this.extractParamMap(), this, m) - m + override def copy(extra: ParamMap): LogisticRegressionModel = { + copyValues(new LogisticRegressionModel(parent, weights, intercept), extra) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index 10404548ccfde..8519841c5c26c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -18,7 +18,6 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} -import org.apache.spark.ml.param.{ParamMap, Params} import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.linalg.{Vector, VectorUDT} @@ -34,12 +33,10 @@ private[classification] trait ProbabilisticClassifierParams override protected def validateAndTransformSchema( schema: StructType, - paramMap: ParamMap, fitting: Boolean, featuresDataType: DataType): StructType = { - val parentSchema = super.validateAndTransformSchema(schema, paramMap, fitting, featuresDataType) - val map = extractParamMap(paramMap) - SchemaUtils.appendColumn(parentSchema, map(probabilityCol), new VectorUDT) + val parentSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType) + SchemaUtils.appendColumn(parentSchema, $(probabilityCol), new VectorUDT) } } @@ -95,36 +92,22 @@ private[spark] abstract class ProbabilisticClassificationModel[ * - probability of each class as [[probabilityCol]] of type [[Vector]]. * * @param dataset input dataset - * @param paramMap additional parameters, overwrite embedded params * @return transformed dataset */ - override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + override def transform(dataset: DataFrame): DataFrame = { // This default implementation should be overridden as needed. // Check schema - transformSchema(dataset.schema, paramMap, logging = true) - val map = extractParamMap(paramMap) - - // Prepare model - val tmpModel = if (paramMap.size != 0) { - val tmpModel = this.copy() - Params.inheritValues(paramMap, parent, tmpModel) - tmpModel - } else { - this - } + transformSchema(dataset.schema, logging = true) val (numColsOutput, outputData) = - ClassificationModel.transformColumnsImpl[FeaturesType](dataset, tmpModel, map) + ClassificationModel.transformColumnsImpl[FeaturesType](dataset, this) // Output selected columns only. - if (map(probabilityCol) != "") { + if ($(probabilityCol) != "") { // output probabilities - val features2probs: FeaturesType => Vector = (features) => { - tmpModel.predictProbabilities(features) - } - outputData.withColumn(map(probabilityCol), - callUDF(features2probs, new VectorUDT, col(map(featuresCol)))) + outputData.withColumn($(probabilityCol), + callUDF(predictProbabilities _, new VectorUDT, col($(featuresCol)))) } else { if (numColsOutput == 0) { this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" + diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index cfd6508fce890..17f59bb42e129 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -22,18 +22,17 @@ import scala.collection.mutable import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor} import org.apache.spark.ml.impl.tree._ -import org.apache.spark.ml.param.{Params, ParamMap} +import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel} import org.apache.spark.ml.util.MetadataUtils import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{RandomForest => OldRandomForest} -import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame - /** * :: AlphaComponent :: * @@ -81,24 +80,22 @@ final class RandomForestClassifier override def setFeatureSubsetStrategy(value: String): this.type = super.setFeatureSubsetStrategy(value) - override protected def train( - dataset: DataFrame, - paramMap: ParamMap): RandomForestClassificationModel = { + override protected def train(dataset: DataFrame): RandomForestClassificationModel = { val categoricalFeatures: Map[Int, Int] = - MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol))) - val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema(paramMap(labelCol))) match { + MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) + val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match { case Some(n: Int) => n case None => throw new IllegalArgumentException("RandomForestClassifier was given input" + - s" with invalid label column ${paramMap(labelCol)}, without the number of classes" + + s" with invalid label column ${$(labelCol)}, without the number of classes" + " specified. See StringIndexer.") // TODO: Automatically index labels: SPARK-7126 } - val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap) + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val strategy = super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity) val oldModel = OldRandomForest.trainClassifier( oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt) - RandomForestClassificationModel.fromOld(oldModel, this, paramMap, categoricalFeatures) + RandomForestClassificationModel.fromOld(oldModel, this, categoricalFeatures) } } @@ -123,7 +120,6 @@ object RandomForestClassifier { @AlphaComponent final class RandomForestClassificationModel private[ml] ( override val parent: RandomForestClassifier, - override val fittingParamMap: ParamMap, private val _trees: Array[DecisionTreeClassificationModel]) extends PredictionModel[Vector, RandomForestClassificationModel] with TreeEnsembleModel with Serializable { @@ -150,10 +146,8 @@ final class RandomForestClassificationModel private[ml] ( votes.maxBy(_._2)._1 } - override protected def copy(): RandomForestClassificationModel = { - val m = new RandomForestClassificationModel(parent, fittingParamMap, _trees) - Params.inheritValues(this.extractParamMap(), this, m) - m + override def copy(extra: ParamMap): RandomForestClassificationModel = { + copyValues(new RandomForestClassificationModel(parent, _trees), extra) } override def toString: String = { @@ -172,14 +166,13 @@ private[ml] object RandomForestClassificationModel { def fromOld( oldModel: OldRandomForestModel, parent: RandomForestClassifier, - fittingParamMap: ParamMap, categoricalFeatures: Map[Int, Int]): RandomForestClassificationModel = { require(oldModel.algo == OldAlgo.Classification, "Cannot convert RandomForestModel" + s" with algo=${oldModel.algo} (old API) to RandomForestClassificationModel (new API).") val newTrees = oldModel.trees.map { tree => // parent, fittingParamMap for each tree is null since there are no good ways to set these. - DecisionTreeClassificationModel.fromOld(tree, null, null, categoricalFeatures) + DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures) } - new RandomForestClassificationModel(parent, fittingParamMap, newTrees) + new RandomForestClassificationModel(parent, newTrees) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index c865eb9fe092d..e5a73c6087a11 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -33,8 +33,7 @@ import org.apache.spark.sql.types.DoubleType * Evaluator for binary classification, which expects two input columns: score and label. */ @AlphaComponent -class BinaryClassificationEvaluator extends Evaluator with Params - with HasRawPredictionCol with HasLabelCol { +class BinaryClassificationEvaluator extends Evaluator with HasRawPredictionCol with HasLabelCol { /** * param for metric name in evaluation @@ -44,7 +43,7 @@ class BinaryClassificationEvaluator extends Evaluator with Params "metric name in evaluation (areaUnderROC|areaUnderPR)") /** @group getParam */ - def getMetricName: String = getOrDefault(metricName) + def getMetricName: String = $(metricName) /** @group setParam */ def setMetricName(value: String): this.type = set(metricName, value) @@ -57,20 +56,18 @@ class BinaryClassificationEvaluator extends Evaluator with Params setDefault(metricName -> "areaUnderROC") - override def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = { - val map = extractParamMap(paramMap) - + override def evaluate(dataset: DataFrame): Double = { val schema = dataset.schema - SchemaUtils.checkColumnType(schema, map(rawPredictionCol), new VectorUDT) - SchemaUtils.checkColumnType(schema, map(labelCol), DoubleType) + SchemaUtils.checkColumnType(schema, $(rawPredictionCol), new VectorUDT) + SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) // TODO: When dataset metadata has been implemented, check rawPredictionCol vector length = 2. - val scoreAndLabels = dataset.select(map(rawPredictionCol), map(labelCol)) + val scoreAndLabels = dataset.select($(rawPredictionCol), $(labelCol)) .map { case Row(rawPrediction: Vector, label: Double) => (rawPrediction(1), label) } val metrics = new BinaryClassificationMetrics(scoreAndLabels) - val metric = map(metricName) match { + val metric = $(metricName) match { case "areaUnderROC" => metrics.areaUnderROC() case "areaUnderPR" => diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala index f3ce6dfca2c1c..6eb1db6971111 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -44,7 +44,7 @@ final class Binarizer extends Transformer with HasInputCol with HasOutputCol { new DoubleParam(this, "threshold", "threshold used to binarize continuous features") /** @group getParam */ - def getThreshold: Double = getOrDefault(threshold) + def getThreshold: Double = $(threshold) /** @group setParam */ def setThreshold(value: Double): this.type = set(threshold, value) @@ -57,23 +57,21 @@ final class Binarizer extends Transformer with HasInputCol with HasOutputCol { /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { - transformSchema(dataset.schema, paramMap, logging = true) - val map = extractParamMap(paramMap) - val td = map(threshold) + override def transform(dataset: DataFrame): DataFrame = { + transformSchema(dataset.schema, logging = true) + val td = $(threshold) val binarizer = udf { in: Double => if (in > td) 1.0 else 0.0 } - val outputColName = map(outputCol) + val outputColName = $(outputCol) val metadata = BinaryAttribute.defaultAttr.withName(outputColName).toMetadata() dataset.select(col("*"), - binarizer(col(map(inputCol))).as(outputColName, metadata)) + binarizer(col($(inputCol))).as(outputColName, metadata)) } - override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = extractParamMap(paramMap) - SchemaUtils.checkColumnType(schema, map(inputCol), DoubleType) + override def transformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType) val inputFields = schema.fields - val outputColName = map(outputCol) + val outputColName = $(outputCol) require(inputFields.forall(_.name != outputColName), s"Output column $outputColName already exists.") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index 0b3128f9ee8cd..c305a819a8966 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -19,9 +19,9 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.UnaryTransformer -import org.apache.spark.ml.param.{ParamValidators, IntParam, ParamMap} +import org.apache.spark.ml.param.{IntParam, ParamValidators} import org.apache.spark.mllib.feature -import org.apache.spark.mllib.linalg.{VectorUDT, Vector} +import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.types.DataType /** @@ -42,13 +42,13 @@ class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] { setDefault(numFeatures -> (1 << 18)) /** @group getParam */ - def getNumFeatures: Int = getOrDefault(numFeatures) + def getNumFeatures: Int = $(numFeatures) /** @group setParam */ def setNumFeatures(value: Int): this.type = set(numFeatures, value) - override protected def createTransformFunc(paramMap: ParamMap): Iterable[_] => Vector = { - val hashingTF = new feature.HashingTF(paramMap(numFeatures)) + override protected def createTransformFunc: Iterable[_] => Vector = { + val hashingTF = new feature.HashingTF($(numFeatures)) hashingTF.transform } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index e6a62d998bb97..d901a20aed002 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -43,7 +43,7 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol setDefault(minDocFreq -> 0) /** @group getParam */ - def getMinDocFreq: Int = getOrDefault(minDocFreq) + def getMinDocFreq: Int = $(minDocFreq) /** @group setParam */ def setMinDocFreq(value: Int): this.type = set(minDocFreq, value) @@ -51,10 +51,9 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol /** * Validate and transform the input schema. */ - protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = extractParamMap(paramMap) - SchemaUtils.checkColumnType(schema, map(inputCol), new VectorUDT) - SchemaUtils.appendColumn(schema, map(outputCol), new VectorUDT) + protected def validateAndTransformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT) + SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) } } @@ -71,18 +70,15 @@ final class IDF extends Estimator[IDFModel] with IDFBase { /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def fit(dataset: DataFrame, paramMap: ParamMap): IDFModel = { - transformSchema(dataset.schema, paramMap, logging = true) - val map = extractParamMap(paramMap) - val input = dataset.select(map(inputCol)).map { case Row(v: Vector) => v } - val idf = new feature.IDF(map(minDocFreq)).fit(input) - val model = new IDFModel(this, map, idf) - Params.inheritValues(map, this, model) - model + override def fit(dataset: DataFrame): IDFModel = { + transformSchema(dataset.schema, logging = true) + val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v } + val idf = new feature.IDF($(minDocFreq)).fit(input) + copyValues(new IDFModel(this, idf)) } - override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - validateAndTransformSchema(schema, paramMap) + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) } } @@ -93,7 +89,6 @@ final class IDF extends Estimator[IDFModel] with IDFBase { @AlphaComponent class IDFModel private[ml] ( override val parent: IDF, - override val fittingParamMap: ParamMap, idfModel: feature.IDFModel) extends Model[IDFModel] with IDFBase { @@ -103,14 +98,13 @@ class IDFModel private[ml] ( /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { - transformSchema(dataset.schema, paramMap, logging = true) - val map = extractParamMap(paramMap) + override def transform(dataset: DataFrame): DataFrame = { + transformSchema(dataset.schema, logging = true) val idf = udf { vec: Vector => idfModel.transform(vec) } - dataset.withColumn(map(outputCol), idf(col(map(inputCol)))) + dataset.withColumn($(outputCol), idf(col($(inputCol)))) } - override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - validateAndTransformSchema(schema, paramMap) + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala index bd2b5f6067e2d..755b46a64c7f1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala @@ -19,9 +19,9 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.UnaryTransformer -import org.apache.spark.ml.param.{ParamValidators, DoubleParam, ParamMap} +import org.apache.spark.ml.param.{DoubleParam, ParamValidators} import org.apache.spark.mllib.feature -import org.apache.spark.mllib.linalg.{VectorUDT, Vector} +import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.types.DataType /** @@ -41,13 +41,13 @@ class Normalizer extends UnaryTransformer[Vector, Vector, Normalizer] { setDefault(p -> 2.0) /** @group getParam */ - def getP: Double = getOrDefault(p) + def getP: Double = $(p) /** @group setParam */ def setP(value: Double): this.type = set(p, value) - override protected def createTransformFunc(paramMap: ParamMap): Vector => Vector = { - val normalizer = new feature.Normalizer(paramMap(p)) + override protected def createTransformFunc: Vector => Vector = { + val normalizer = new feature.Normalizer($(p)) normalizer.transform } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala index 1b7c939c2dffe..63e190c8aae53 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala @@ -21,7 +21,7 @@ import scala.collection.mutable import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.UnaryTransformer -import org.apache.spark.ml.param.{ParamValidators, IntParam, ParamMap} +import org.apache.spark.ml.param.{IntParam, ParamValidators} import org.apache.spark.mllib.linalg._ import org.apache.spark.sql.types.DataType @@ -47,14 +47,13 @@ class PolynomialExpansion extends UnaryTransformer[Vector, Vector, PolynomialExp setDefault(degree -> 2) /** @group getParam */ - def getDegree: Int = getOrDefault(degree) + def getDegree: Int = $(degree) /** @group setParam */ def setDegree(value: Int): this.type = set(degree, value) - override protected def createTransformFunc(paramMap: ParamMap): Vector => Vector = { v => - val d = paramMap(degree) - PolynomialExpansion.expand(v, d) + override protected def createTransformFunc: Vector => Vector = { v => + PolynomialExpansion.expand(v, $(degree)) } override protected def outputDataType: DataType = new VectorUDT() diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index a0e9ed32e0e4c..7cad59ff3fa37 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -71,25 +71,21 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP /** @group setParam */ def setWithStd(value: Boolean): this.type = set(withStd, value) - override def fit(dataset: DataFrame, paramMap: ParamMap): StandardScalerModel = { - transformSchema(dataset.schema, paramMap, logging = true) - val map = extractParamMap(paramMap) - val input = dataset.select(map(inputCol)).map { case Row(v: Vector) => v } - val scaler = new feature.StandardScaler(withMean = map(withMean), withStd = map(withStd)) + override def fit(dataset: DataFrame): StandardScalerModel = { + transformSchema(dataset.schema, logging = true) + val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v } + val scaler = new feature.StandardScaler(withMean = $(withMean), withStd = $(withStd)) val scalerModel = scaler.fit(input) - val model = new StandardScalerModel(this, map, scalerModel) - Params.inheritValues(map, this, model) - model + copyValues(new StandardScalerModel(this, scalerModel)) } - override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = extractParamMap(paramMap) - val inputType = schema(map(inputCol)).dataType + override def transformSchema(schema: StructType): StructType = { + val inputType = schema($(inputCol)).dataType require(inputType.isInstanceOf[VectorUDT], - s"Input column ${map(inputCol)} must be a vector column") - require(!schema.fieldNames.contains(map(outputCol)), - s"Output column ${map(outputCol)} already exists.") - val outputFields = schema.fields :+ StructField(map(outputCol), new VectorUDT, false) + s"Input column ${$(inputCol)} must be a vector column") + require(!schema.fieldNames.contains($(outputCol)), + s"Output column ${$(outputCol)} already exists.") + val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false) StructType(outputFields) } } @@ -101,7 +97,6 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP @AlphaComponent class StandardScalerModel private[ml] ( override val parent: StandardScaler, - override val fittingParamMap: ParamMap, scaler: feature.StandardScalerModel) extends Model[StandardScalerModel] with StandardScalerParams { @@ -111,21 +106,19 @@ class StandardScalerModel private[ml] ( /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { - transformSchema(dataset.schema, paramMap, logging = true) - val map = extractParamMap(paramMap) - val scale = udf((v: Vector) => { scaler.transform(v) } : Vector) - dataset.withColumn(map(outputCol), scale(col(map(inputCol)))) + override def transform(dataset: DataFrame): DataFrame = { + transformSchema(dataset.schema, logging = true) + val scale = udf { scaler.transform _ } + dataset.withColumn($(outputCol), scale(col($(inputCol)))) } - override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = extractParamMap(paramMap) - val inputType = schema(map(inputCol)).dataType + override def transformSchema(schema: StructType): StructType = { + val inputType = schema($(inputCol)).dataType require(inputType.isInstanceOf[VectorUDT], - s"Input column ${map(inputCol)} must be a vector column") - require(!schema.fieldNames.contains(map(outputCol)), - s"Output column ${map(outputCol)} already exists.") - val outputFields = schema.fields :+ StructField(map(outputCol), new VectorUDT, false) + s"Input column ${$(inputCol)} must be a vector column") + require(!schema.fieldNames.contains($(outputCol)), + s"Output column ${$(outputCol)} already exists.") + val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false) StructType(outputFields) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 9db3b29e10d69..3d78537ad84cb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -34,18 +34,17 @@ import org.apache.spark.util.collection.OpenHashMap private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol { /** Validates and transforms the input schema. */ - protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = extractParamMap(paramMap) - val inputColName = map(inputCol) + protected def validateAndTransformSchema(schema: StructType): StructType = { + val inputColName = $(inputCol) val inputDataType = schema(inputColName).dataType require(inputDataType == StringType || inputDataType.isInstanceOf[NumericType], s"The input column $inputColName must be either string type or numeric type, " + s"but got $inputDataType.") val inputFields = schema.fields - val outputColName = map(outputCol) + val outputColName = $(outputCol) require(inputFields.forall(_.name != outputColName), s"Output column $outputColName already exists.") - val attr = NominalAttribute.defaultAttr.withName(map(outputCol)) + val attr = NominalAttribute.defaultAttr.withName($(outputCol)) val outputFields = inputFields :+ attr.toStructField() StructType(outputFields) } @@ -69,19 +68,16 @@ class StringIndexer extends Estimator[StringIndexerModel] with StringIndexerBase // TODO: handle unseen labels - override def fit(dataset: DataFrame, paramMap: ParamMap): StringIndexerModel = { - val map = extractParamMap(paramMap) - val counts = dataset.select(col(map(inputCol)).cast(StringType)) + override def fit(dataset: DataFrame): StringIndexerModel = { + val counts = dataset.select(col($(inputCol)).cast(StringType)) .map(_.getString(0)) .countByValue() val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray - val model = new StringIndexerModel(this, map, labels) - Params.inheritValues(map, this, model) - model + copyValues(new StringIndexerModel(this, labels)) } - override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - validateAndTransformSchema(schema, paramMap) + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) } } @@ -92,7 +88,6 @@ class StringIndexer extends Estimator[StringIndexerModel] with StringIndexerBase @AlphaComponent class StringIndexerModel private[ml] ( override val parent: StringIndexer, - override val fittingParamMap: ParamMap, labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase { private val labelToIndex: OpenHashMap[String, Double] = { @@ -112,8 +107,7 @@ class StringIndexerModel private[ml] ( /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { - val map = extractParamMap(paramMap) + override def transform(dataset: DataFrame): DataFrame = { val indexer = udf { label: String => if (labelToIndex.contains(label)) { labelToIndex(label) @@ -122,14 +116,14 @@ class StringIndexerModel private[ml] ( throw new SparkException(s"Unseen label: $label.") } } - val outputColName = map(outputCol) + val outputColName = $(outputCol) val metadata = NominalAttribute.defaultAttr .withName(outputColName).withValues(labels).toMetadata() dataset.select(col("*"), - indexer(dataset(map(inputCol)).cast(StringType)).as(outputColName, metadata)) + indexer(dataset($(inputCol)).cast(StringType)).as(outputColName, metadata)) } - override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - validateAndTransformSchema(schema, paramMap) + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index 01752ba482d0c..2863b7621526e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param._ -import org.apache.spark.sql.types.{DataType, StringType, ArrayType} +import org.apache.spark.sql.types.{ArrayType, DataType, StringType} /** * :: AlphaComponent :: @@ -29,7 +29,7 @@ import org.apache.spark.sql.types.{DataType, StringType, ArrayType} @AlphaComponent class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] { - override protected def createTransformFunc(paramMap: ParamMap): String => Seq[String] = { + override protected def createTransformFunc: String => Seq[String] = { _.toLowerCase.split("\\s") } @@ -62,7 +62,7 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenize def setMinTokenLength(value: Int): this.type = set(minTokenLength, value) /** @group getParam */ - def getMinTokenLength: Int = getOrDefault(minTokenLength) + def getMinTokenLength: Int = $(minTokenLength) /** * Indicates whether regex splits on gaps (true) or matching tokens (false). @@ -75,7 +75,7 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenize def setGaps(value: Boolean): this.type = set(gaps, value) /** @group getParam */ - def getGaps: Boolean = getOrDefault(gaps) + def getGaps: Boolean = $(gaps) /** * Regex pattern used by tokenizer. @@ -88,14 +88,14 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenize def setPattern(value: String): this.type = set(pattern, value) /** @group getParam */ - def getPattern: String = getOrDefault(pattern) + def getPattern: String = $(pattern) setDefault(minTokenLength -> 1, gaps -> false, pattern -> "\\p{L}+|[^\\p{L}\\s]+") - override protected def createTransformFunc(paramMap: ParamMap): String => Seq[String] = { str => - val re = paramMap(pattern).r - val tokens = if (paramMap(gaps)) re.split(str).toSeq else re.findAllIn(str).toSeq - val minLength = paramMap(minTokenLength) + override protected def createTransformFunc: String => Seq[String] = { str => + val re = $(pattern).r + val tokens = if ($(gaps)) re.split(str).toSeq else re.findAllIn(str).toSeq + val minLength = $(minTokenLength) tokens.filter(_.length >= minLength) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 5e781a326d98c..8f2e62a8e2081 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -22,7 +22,6 @@ import scala.collection.mutable.ArrayBuilder import org.apache.spark.SparkException import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.Transformer -import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} import org.apache.spark.sql.{DataFrame, Row} @@ -42,13 +41,12 @@ class VectorAssembler extends Transformer with HasInputCols with HasOutputCol { /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { - val map = extractParamMap(paramMap) + override def transform(dataset: DataFrame): DataFrame = { val assembleFunc = udf { r: Row => VectorAssembler.assemble(r.toSeq: _*) } val schema = dataset.schema - val inputColNames = map(inputCols) + val inputColNames = $(inputCols) val args = inputColNames.map { c => schema(c).dataType match { case DoubleType => dataset(c) @@ -56,13 +54,12 @@ class VectorAssembler extends Transformer with HasInputCols with HasOutputCol { case _: NumericType | BooleanType => dataset(c).cast(DoubleType).as(s"${c}_double_$uid") } } - dataset.select(col("*"), assembleFunc(struct(args : _*)).as(map(outputCol))) + dataset.select(col("*"), assembleFunc(struct(args : _*)).as($(outputCol))) } - override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = extractParamMap(paramMap) - val inputColNames = map(inputCols) - val outputColName = map(outputCol) + override def transformSchema(schema: StructType): StructType = { + val inputColNames = $(inputCols) + val outputColName = $(outputCol) val inputDataTypes = inputColNames.map(name => schema(name).dataType) inputDataTypes.foreach { case _: NumericType | BooleanType => diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index ed833c63c7ef1..07ea579d69893 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -18,19 +18,17 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.ml.attribute.{BinaryAttribute, NumericAttribute, NominalAttribute, - Attribute, AttributeGroup} -import org.apache.spark.ml.param.{ParamValidators, IntParam, ParamMap, Params} +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, BinaryAttribute, NominalAttribute, NumericAttribute} +import org.apache.spark.ml.param.{IntParam, ParamValidators, Params} import org.apache.spark.ml.param.shared._ -import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, VectorUDT} -import org.apache.spark.sql.{Row, DataFrame} +import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, VectorUDT} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions.callUDF import org.apache.spark.sql.types.{StructField, StructType} import org.apache.spark.util.collection.OpenHashSet - /** Private trait for params for VectorIndexer and VectorIndexerModel */ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOutputCol { @@ -49,7 +47,7 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu setDefault(maxCategories -> 20) /** @group getParam */ - def getMaxCategories: Int = getOrDefault(maxCategories) + def getMaxCategories: Int = $(maxCategories) } /** @@ -100,33 +98,29 @@ class VectorIndexer extends Estimator[VectorIndexerModel] with VectorIndexerPara /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def fit(dataset: DataFrame, paramMap: ParamMap): VectorIndexerModel = { - transformSchema(dataset.schema, paramMap, logging = true) - val map = extractParamMap(paramMap) - val firstRow = dataset.select(map(inputCol)).take(1) + override def fit(dataset: DataFrame): VectorIndexerModel = { + transformSchema(dataset.schema, logging = true) + val firstRow = dataset.select($(inputCol)).take(1) require(firstRow.length == 1, s"VectorIndexer cannot be fit on an empty dataset.") val numFeatures = firstRow(0).getAs[Vector](0).size - val vectorDataset = dataset.select(map(inputCol)).map { case Row(v: Vector) => v } - val maxCats = map(maxCategories) + val vectorDataset = dataset.select($(inputCol)).map { case Row(v: Vector) => v } + val maxCats = $(maxCategories) val categoryStats: VectorIndexer.CategoryStats = vectorDataset.mapPartitions { iter => val localCatStats = new VectorIndexer.CategoryStats(numFeatures, maxCats) iter.foreach(localCatStats.addVector) Iterator(localCatStats) }.reduce((stats1, stats2) => stats1.merge(stats2)) - val model = new VectorIndexerModel(this, map, numFeatures, categoryStats.getCategoryMaps) - Params.inheritValues(map, this, model) - model + copyValues(new VectorIndexerModel(this, numFeatures, categoryStats.getCategoryMaps)) } - override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + override def transformSchema(schema: StructType): StructType = { // We do not transfer feature metadata since we do not know what types of features we will // produce in transform(). - val map = extractParamMap(paramMap) val dataType = new VectorUDT - require(map.contains(inputCol), s"VectorIndexer requires input column parameter: $inputCol") - require(map.contains(outputCol), s"VectorIndexer requires output column parameter: $outputCol") - SchemaUtils.checkColumnType(schema, map(inputCol), dataType) - SchemaUtils.appendColumn(schema, map(outputCol), dataType) + require(isDefined(inputCol), s"VectorIndexer requires input column parameter: $inputCol") + require(isDefined(outputCol), s"VectorIndexer requires output column parameter: $outputCol") + SchemaUtils.checkColumnType(schema, $(inputCol), dataType) + SchemaUtils.appendColumn(schema, $(outputCol), dataType) } } @@ -243,7 +237,6 @@ private object VectorIndexer { @AlphaComponent class VectorIndexerModel private[ml] ( override val parent: VectorIndexer, - override val fittingParamMap: ParamMap, val numFeatures: Int, val categoryMaps: Map[Int, Map[Double, Int]]) extends Model[VectorIndexerModel] with VectorIndexerParams { @@ -326,35 +319,33 @@ class VectorIndexerModel private[ml] ( /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { - transformSchema(dataset.schema, paramMap, logging = true) - val map = extractParamMap(paramMap) - val newField = prepOutputField(dataset.schema, map) - val newCol = callUDF(transformFunc, new VectorUDT, dataset(map(inputCol))) - dataset.withColumn(map(outputCol), newCol.as(map(outputCol), newField.metadata)) + override def transform(dataset: DataFrame): DataFrame = { + transformSchema(dataset.schema, logging = true) + val newField = prepOutputField(dataset.schema) + val newCol = callUDF(transformFunc, new VectorUDT, dataset($(inputCol))) + dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata)) } - override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = extractParamMap(paramMap) + override def transformSchema(schema: StructType): StructType = { val dataType = new VectorUDT - require(map.contains(inputCol), + require(isDefined(inputCol), s"VectorIndexerModel requires input column parameter: $inputCol") - require(map.contains(outputCol), + require(isDefined(outputCol), s"VectorIndexerModel requires output column parameter: $outputCol") - SchemaUtils.checkColumnType(schema, map(inputCol), dataType) + SchemaUtils.checkColumnType(schema, $(inputCol), dataType) // If the input metadata specifies numFeatures, compare with expected numFeatures. - val origAttrGroup = AttributeGroup.fromStructField(schema(map(inputCol))) + val origAttrGroup = AttributeGroup.fromStructField(schema($(inputCol))) val origNumFeatures: Option[Int] = if (origAttrGroup.attributes.nonEmpty) { Some(origAttrGroup.attributes.get.length) } else { origAttrGroup.numAttributes } require(origNumFeatures.forall(_ == numFeatures), "VectorIndexerModel expected" + - s" $numFeatures features, but input column ${map(inputCol)} had metadata specifying" + + s" $numFeatures features, but input column ${$(inputCol)} had metadata specifying" + s" ${origAttrGroup.numAttributes.get} features.") - val newField = prepOutputField(schema, map) + val newField = prepOutputField(schema) val outputFields = schema.fields :+ newField StructType(outputFields) } @@ -362,11 +353,10 @@ class VectorIndexerModel private[ml] ( /** * Prepare the output column field, including per-feature metadata. * @param schema Input schema - * @param map Parameter map (with this class' embedded parameter map folded in) * @return Output column field. This field does not contain non-ML metadata. */ - private def prepOutputField(schema: StructType, map: ParamMap): StructField = { - val origAttrGroup = AttributeGroup.fromStructField(schema(map(inputCol))) + private def prepOutputField(schema: StructType): StructField = { + val origAttrGroup = AttributeGroup.fromStructField(schema($(inputCol))) val featureAttributes: Array[Attribute] = if (origAttrGroup.attributes.nonEmpty) { // Convert original attributes to modified attributes val origAttrs: Array[Attribute] = origAttrGroup.attributes.get @@ -389,7 +379,7 @@ class VectorIndexerModel private[ml] ( } else { partialFeatureAttributes } - val newAttributeGroup = new AttributeGroup(map(outputCol), featureAttributes) + val newAttributeGroup = new AttributeGroup($(outputCol), featureAttributes) newAttributeGroup.toStructField() } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 0163fa8bd8a5b..34ff92970129f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -18,16 +18,16 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.SchemaUtils -import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.mllib.feature -import org.apache.spark.mllib.linalg.BLAS._ import org.apache.spark.mllib.linalg.{VectorUDT, Vectors} +import org.apache.spark.mllib.linalg.BLAS._ +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.{DataFrame, Row} /** * Params for [[Word2Vec]] and [[Word2VecModel]]. @@ -43,7 +43,7 @@ private[feature] trait Word2VecBase extends Params setDefault(vectorSize -> 100) /** @group getParam */ - def getVectorSize: Int = getOrDefault(vectorSize) + def getVectorSize: Int = $(vectorSize) /** * Number of partitions for sentences of words. @@ -53,7 +53,7 @@ private[feature] trait Word2VecBase extends Params setDefault(numPartitions -> 1) /** @group getParam */ - def getNumPartitions: Int = getOrDefault(numPartitions) + def getNumPartitions: Int = $(numPartitions) /** * The minimum number of times a token must appear to be included in the word2vec model's @@ -64,7 +64,7 @@ private[feature] trait Word2VecBase extends Params setDefault(minCount -> 5) /** @group getParam */ - def getMinCount: Int = getOrDefault(minCount) + def getMinCount: Int = $(minCount) setDefault(stepSize -> 0.025) setDefault(maxIter -> 1) @@ -73,10 +73,9 @@ private[feature] trait Word2VecBase extends Params /** * Validate and transform the input schema. */ - protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = extractParamMap(paramMap) - SchemaUtils.checkColumnType(schema, map(inputCol), new ArrayType(StringType, true)) - SchemaUtils.appendColumn(schema, map(outputCol), new VectorUDT) + protected def validateAndTransformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true)) + SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) } } @@ -112,25 +111,22 @@ final class Word2Vec extends Estimator[Word2VecModel] with Word2VecBase { /** @group setParam */ def setMinCount(value: Int): this.type = set(minCount, value) - override def fit(dataset: DataFrame, paramMap: ParamMap): Word2VecModel = { - transformSchema(dataset.schema, paramMap, logging = true) - val map = extractParamMap(paramMap) - val input = dataset.select(map(inputCol)).map { case Row(v: Seq[String]) => v } + override def fit(dataset: DataFrame): Word2VecModel = { + transformSchema(dataset.schema, logging = true) + val input = dataset.select($(inputCol)).map(_.getAs[Seq[String]](0)) val wordVectors = new feature.Word2Vec() - .setLearningRate(map(stepSize)) - .setMinCount(map(minCount)) - .setNumIterations(map(maxIter)) - .setNumPartitions(map(numPartitions)) - .setSeed(map(seed)) - .setVectorSize(map(vectorSize)) + .setLearningRate($(stepSize)) + .setMinCount($(minCount)) + .setNumIterations($(maxIter)) + .setNumPartitions($(numPartitions)) + .setSeed($(seed)) + .setVectorSize($(vectorSize)) .fit(input) - val model = new Word2VecModel(this, map, wordVectors) - Params.inheritValues(map, this, model) - model + copyValues(new Word2VecModel(this, wordVectors)) } - override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - validateAndTransformSchema(schema, paramMap) + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) } } @@ -141,7 +137,6 @@ final class Word2Vec extends Estimator[Word2VecModel] with Word2VecBase { @AlphaComponent class Word2VecModel private[ml] ( override val parent: Word2Vec, - override val fittingParamMap: ParamMap, wordVectors: feature.Word2VecModel) extends Model[Word2VecModel] with Word2VecBase { @@ -155,15 +150,14 @@ class Word2VecModel private[ml] ( * Transform a sentence column to a vector column to represent the whole sentence. The transform * is performed by averaging all word vectors it contains. */ - override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { - transformSchema(dataset.schema, paramMap, logging = true) - val map = extractParamMap(paramMap) + override def transform(dataset: DataFrame): DataFrame = { + transformSchema(dataset.schema, logging = true) val bWordVectors = dataset.sqlContext.sparkContext.broadcast(wordVectors) val word2Vec = udf { sentence: Seq[String] => if (sentence.size == 0) { - Vectors.sparse(map(vectorSize), Array.empty[Int], Array.empty[Double]) + Vectors.sparse($(vectorSize), Array.empty[Int], Array.empty[Double]) } else { - val cum = Vectors.zeros(map(vectorSize)) + val cum = Vectors.zeros($(vectorSize)) val model = bWordVectors.value.getVectors for (word <- sentence) { if (model.contains(word)) { @@ -176,10 +170,10 @@ class Word2VecModel private[ml] ( cum } } - dataset.withColumn(map(outputCol), word2Vec(col(map(inputCol)))) + dataset.withColumn($(outputCol), word2Vec(col($(inputCol)))) } - override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - validateAndTransformSchema(schema, paramMap) + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala index 195333a5cc47f..e8b3628140e99 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala @@ -18,18 +18,17 @@ package org.apache.spark.ml.impl.estimator import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} -import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.mllib.linalg.{VectorUDT, Vector} +import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, DoubleType, StructType} - /** * :: DeveloperApi :: * @@ -44,7 +43,6 @@ private[spark] trait PredictorParams extends Params /** * Validates and transforms the input schema with the provided param map. * @param schema input schema - * @param paramMap additional parameters * @param fitting whether this is in fitting * @param featuresDataType SQL DataType for FeaturesType. * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features. @@ -52,17 +50,15 @@ private[spark] trait PredictorParams extends Params */ protected def validateAndTransformSchema( schema: StructType, - paramMap: ParamMap, fitting: Boolean, featuresDataType: DataType): StructType = { - val map = extractParamMap(paramMap) // TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector - SchemaUtils.checkColumnType(schema, map(featuresCol), featuresDataType) + SchemaUtils.checkColumnType(schema, $(featuresCol), featuresDataType) if (fitting) { // TODO: Allow other numeric types - SchemaUtils.checkColumnType(schema, map(labelCol), DoubleType) + SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) } - SchemaUtils.appendColumn(schema, map(predictionCol), DoubleType) + SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType) } } @@ -96,14 +92,15 @@ private[spark] abstract class Predictor[ /** @group setParam */ def setPredictionCol(value: String): Learner = set(predictionCol, value).asInstanceOf[Learner] - override def fit(dataset: DataFrame, paramMap: ParamMap): M = { + override def fit(dataset: DataFrame): M = { // This handles a few items such as schema validation. // Developers only need to implement train(). - transformSchema(dataset.schema, paramMap, logging = true) - val map = extractParamMap(paramMap) - val model = train(dataset, map) - Params.inheritValues(map, this, model) // copy params to model - model + transformSchema(dataset.schema, logging = true) + copyValues(train(dataset)) + } + + override def copy(extra: ParamMap): Learner = { + super.copy(extra).asInstanceOf[Learner] } /** @@ -114,12 +111,10 @@ private[spark] abstract class Predictor[ * and copying parameters into the model. * * @param dataset Training dataset - * @param paramMap Parameter map. Unlike [[fit()]]'s paramMap, this paramMap has already - * been combined with the embedded ParamMap. * @return Fitted model */ @DeveloperApi - protected def train(dataset: DataFrame, paramMap: ParamMap): M + protected def train(dataset: DataFrame): M /** * :: DeveloperApi :: @@ -134,17 +129,16 @@ private[spark] abstract class Predictor[ @DeveloperApi protected def featuresDataType: DataType = new VectorUDT - override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - validateAndTransformSchema(schema, paramMap, fitting = true, featuresDataType) + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema, fitting = true, featuresDataType) } /** * Extract [[labelCol]] and [[featuresCol]] from the given dataset, * and put it in an RDD with strong types. */ - protected def extractLabeledPoints(dataset: DataFrame, paramMap: ParamMap): RDD[LabeledPoint] = { - val map = extractParamMap(paramMap) - dataset.select(map(labelCol), map(featuresCol)) + protected def extractLabeledPoints(dataset: DataFrame): RDD[LabeledPoint] = { + dataset.select($(labelCol), $(featuresCol)) .map { case Row(label: Double, features: Vector) => LabeledPoint(label, features) } @@ -186,8 +180,8 @@ private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel @DeveloperApi protected def featuresDataType: DataType = new VectorUDT - override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - validateAndTransformSchema(schema, paramMap, fitting = false, featuresDataType) + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema, fitting = false, featuresDataType) } /** @@ -195,30 +189,16 @@ private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel * the predictions as a new column [[predictionCol]]. * * @param dataset input dataset - * @param paramMap additional parameters, overwrite embedded params * @return transformed dataset with [[predictionCol]] of type [[Double]] */ - override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + override def transform(dataset: DataFrame): DataFrame = { // This default implementation should be overridden as needed. // Check schema - transformSchema(dataset.schema, paramMap, logging = true) - val map = extractParamMap(paramMap) - - // Prepare model - val tmpModel = if (paramMap.size != 0) { - val tmpModel = this.copy() - Params.inheritValues(paramMap, parent, tmpModel) - tmpModel - } else { - this - } + transformSchema(dataset.schema, logging = true) - if (map(predictionCol) != "") { - val pred: FeaturesType => Double = (features) => { - tmpModel.predict(features) - } - dataset.withColumn(map(predictionCol), callUDF(pred, DoubleType, col(map(featuresCol)))) + if ($(predictionCol) != "") { + dataset.withColumn($(predictionCol), callUDF(predict _, DoubleType, col($(featuresCol)))) } else { this.logWarning(s"$uid: Predictor.transform() was called as NOOP" + " since no output columns were set.") @@ -234,10 +214,4 @@ private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel */ @DeveloperApi protected def predict(features: FeaturesType): Double - - /** - * Create a copy of the model. - * The copy is shallow, except for the embedded paramMap, which gets a deep copy. - */ - protected def copy(): M } diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala index fb770622e71f0..0e225627d4ee3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala @@ -20,14 +20,11 @@ package org.apache.spark.ml.impl.tree import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.impl.estimator.PredictorParams import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasSeed, HasMaxIter} -import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, - BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy} -import org.apache.spark.mllib.tree.impurity.{Gini => OldGini, Entropy => OldEntropy, - Impurity => OldImpurity, Variance => OldVariance} +import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy} +import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance} import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} - /** * :: DeveloperApi :: * Parameters for Decision Tree-based algorithms. @@ -123,43 +120,43 @@ private[ml] trait DecisionTreeParams extends PredictorParams { def setMaxDepth(value: Int): this.type = set(maxDepth, value) /** @group getParam */ - final def getMaxDepth: Int = getOrDefault(maxDepth) + final def getMaxDepth: Int = $(maxDepth) /** @group setParam */ def setMaxBins(value: Int): this.type = set(maxBins, value) /** @group getParam */ - final def getMaxBins: Int = getOrDefault(maxBins) + final def getMaxBins: Int = $(maxBins) /** @group setParam */ def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) /** @group getParam */ - final def getMinInstancesPerNode: Int = getOrDefault(minInstancesPerNode) + final def getMinInstancesPerNode: Int = $(minInstancesPerNode) /** @group setParam */ def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) /** @group getParam */ - final def getMinInfoGain: Double = getOrDefault(minInfoGain) + final def getMinInfoGain: Double = $(minInfoGain) /** @group expertSetParam */ def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) /** @group expertGetParam */ - final def getMaxMemoryInMB: Int = getOrDefault(maxMemoryInMB) + final def getMaxMemoryInMB: Int = $(maxMemoryInMB) /** @group expertSetParam */ def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) /** @group expertGetParam */ - final def getCacheNodeIds: Boolean = getOrDefault(cacheNodeIds) + final def getCacheNodeIds: Boolean = $(cacheNodeIds) /** @group expertSetParam */ def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) /** @group expertGetParam */ - final def getCheckpointInterval: Int = getOrDefault(checkpointInterval) + final def getCheckpointInterval: Int = $(checkpointInterval) /** (private[ml]) Create a Strategy instance to use with the old API. */ private[ml] def getOldStrategy( @@ -206,7 +203,7 @@ private[ml] trait TreeClassifierParams extends Params { def setImpurity(value: String): this.type = set(impurity, value) /** @group getParam */ - final def getImpurity: String = getOrDefault(impurity).toLowerCase + final def getImpurity: String = $(impurity).toLowerCase /** Convert new impurity to old impurity. */ private[ml] def getOldImpurity: OldImpurity = { @@ -248,7 +245,7 @@ private[ml] trait TreeRegressorParams extends Params { def setImpurity(value: String): this.type = set(impurity, value) /** @group getParam */ - final def getImpurity: String = getOrDefault(impurity).toLowerCase + final def getImpurity: String = $(impurity).toLowerCase /** Convert new impurity to old impurity. */ private[ml] def getOldImpurity: OldImpurity = { @@ -291,7 +288,7 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed { def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) /** @group getParam */ - final def getSubsamplingRate: Double = getOrDefault(subsamplingRate) + final def getSubsamplingRate: Double = $(subsamplingRate) /** @group setParam */ def setSeed(value: Long): this.type = set(seed, value) @@ -364,13 +361,13 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams { def setNumTrees(value: Int): this.type = set(numTrees, value) /** @group getParam */ - final def getNumTrees: Int = getOrDefault(numTrees) + final def getNumTrees: Int = $(numTrees) /** @group setParam */ def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) /** @group getParam */ - final def getFeatureSubsetStrategy: String = getOrDefault(featureSubsetStrategy).toLowerCase + final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase } private[ml] object RandomForestParams { @@ -418,7 +415,7 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter { def setStepSize(value: Double): this.type = set(stepSize, value) /** @group getParam */ - final def getStepSize: Double = getOrDefault(stepSize) + final def getStepSize: Double = $(stepSize) /** (private[ml]) Create a BoostingStrategy instance to use with the old API. */ private[ml] def getOldBoostingStrategy( diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index df6360dce6013..51ce19d29cd29 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -23,7 +23,7 @@ import java.util.NoSuchElementException import scala.annotation.varargs import scala.collection.mutable -import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} +import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.util.Identifiable /** @@ -49,7 +49,7 @@ class Param[T] (val parent: Params, val name: String, val doc: String, val isVal * Assert that the given value is valid for this parameter. * * Note: Parameter checks involving interactions between multiple parameters should be - * implemented in [[Params.validate()]]. Checks for input/output columns should be + * implemented in [[Params.validateParams()]]. Checks for input/output columns should be * implemented in [[org.apache.spark.ml.PipelineStage.transformSchema()]]. * * DEVELOPERS: This method is only called by [[ParamPair]], which means that all parameters @@ -258,7 +258,9 @@ trait Params extends Identifiable with Serializable { * [[Param.validate()]]. This method does not handle input/output column parameters; * those are checked during schema validation. */ - def validate(paramMap: ParamMap): Unit = { } + def validateParams(paramMap: ParamMap): Unit = { + copy(paramMap).validateParams() + } /** * Validates parameter values stored internally. @@ -269,7 +271,11 @@ trait Params extends Identifiable with Serializable { * [[Param.validate()]]. This method does not handle input/output column parameters; * those are checked during schema validation. */ - def validate(): Unit = validate(ParamMap.empty) + def validateParams(): Unit = { + params.filter(isDefined _).foreach { param => + param.asInstanceOf[Param[Any]].validate($(param)) + } + } /** * Returns the documentation of all params. @@ -288,6 +294,11 @@ trait Params extends Identifiable with Serializable { defaultParamMap.contains(param) || paramMap.contains(param) } + /** Tests whether this instance contains a param with a given name. */ + def hasParam(paramName: String): Boolean = { + params.exists(_.name == paramName) + } + /** Gets a param by its name. */ def getParam(paramName: String): Param[Any] = { params.find(_.name == paramName).getOrElse { @@ -337,6 +348,9 @@ trait Params extends Identifiable with Serializable { get(param).orElse(getDefault(param)).get } + /** An alias for [[getOrDefault()]]. */ + protected final def $[T](param: Param[T]): T = getOrDefault(param) + /** * Sets a default value for a param. * @param param param to set the default value. Make sure that this param is initialized before @@ -382,19 +396,31 @@ trait Params extends Identifiable with Serializable { defaultParamMap.contains(param) } + /** + * Creates a copy of this instance with a randomly generated uid and some extra params. + * The default implementation calls the default constructor to create a new instance, then + * copies the embedded and extra parameters over and returns the new instance. + * Subclasses should override this method if the default approach is not sufficient. + */ + def copy(extra: ParamMap): Params = { + val that = this.getClass.newInstance() + copyValues(that, extra) + that + } + /** * Extracts the embedded default param values and user-supplied values, and then merges them with * extra values from input into a flat param map, where the latter value is used if there exist * conflicts, i.e., with ordering: default param values < user-supplied values < extraParamMap. */ - protected final def extractParamMap(extraParamMap: ParamMap): ParamMap = { + final def extractParamMap(extraParamMap: ParamMap): ParamMap = { defaultParamMap ++ paramMap ++ extraParamMap } /** * [[extractParamMap]] with no extra values. */ - protected final def extractParamMap(): ParamMap = { + final def extractParamMap(): ParamMap = { extractParamMap(ParamMap.empty) } @@ -408,34 +434,21 @@ trait Params extends Identifiable with Serializable { private def shouldOwn(param: Param[_]): Unit = { require(param.parent.eq(this), s"Param $param does not belong to $this.") } -} -/** - * :: DeveloperApi :: - * - * Helper functionality for developers. - * - * NOTE: This is currently private[spark] but will be made public later once it is stabilized. - */ -@DeveloperApi -private[spark] object Params { - - /** - * Copies parameter values from the parent estimator to the child model it produced. - * @param paramMap the param map that holds parameters of the parent - * @param parent the parent estimator - * @param child the child model - */ - def inheritValues[E <: Params, M <: E]( - paramMap: ParamMap, - parent: E, - child: M): Unit = { - val childParams = child.params.map(_.name).toSet - parent.params.foreach { param => - if (paramMap.contains(param) && childParams.contains(param.name)) { - child.set(child.getParam(param.name), paramMap(param)) + /** + * Copies param values from this instance to another instance for params shared by them. + * @param to the target instance + * @param extra extra params to be copied + * @return the target instance with param values copied + */ + protected def copyValues[T <: Params](to: T, extra: ParamMap = ParamMap.empty): T = { + val map = extractParamMap(extra) + params.foreach { param => + if (map.contains(param) && to.hasParam(param.name)) { + to.set(param.name, map(param)) } } + to } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 7da4bb4b4bf25..d379172e0bf53 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -21,8 +21,6 @@ import java.io.PrintWriter import scala.reflect.ClassTag -import org.apache.spark.ml.param.ParamValidators - /** * Code generator for shared params (sharedParams.scala). Run under the Spark folder with * {{{ @@ -142,7 +140,7 @@ private[shared] object SharedParamsCodeGen { | final val $name: $Param = new $Param(this, "$name", "$doc"$isValid) |$setDefault | /** @group getParam */ - | final def get$Name: $T = getOrDefault($name) + | final def get$Name: $T = $$($name) |} |""".stripMargin } @@ -169,7 +167,6 @@ private[shared] object SharedParamsCodeGen { | |package org.apache.spark.ml.param.shared | - |import org.apache.spark.annotation.DeveloperApi |import org.apache.spark.ml.param._ |import org.apache.spark.util.Utils | diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index e1549f46a68d4..fb1874ccfc8dc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -17,7 +17,6 @@ package org.apache.spark.ml.param.shared -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.param._ import org.apache.spark.util.Utils @@ -37,7 +36,7 @@ private[ml] trait HasRegParam extends Params { final val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter (>= 0)", ParamValidators.gtEq(0)) /** @group getParam */ - final def getRegParam: Double = getOrDefault(regParam) + final def getRegParam: Double = $(regParam) } /** @@ -52,7 +51,7 @@ private[ml] trait HasMaxIter extends Params { final val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations (>= 0)", ParamValidators.gtEq(0)) /** @group getParam */ - final def getMaxIter: Int = getOrDefault(maxIter) + final def getMaxIter: Int = $(maxIter) } /** @@ -69,7 +68,7 @@ private[ml] trait HasFeaturesCol extends Params { setDefault(featuresCol, "features") /** @group getParam */ - final def getFeaturesCol: String = getOrDefault(featuresCol) + final def getFeaturesCol: String = $(featuresCol) } /** @@ -86,7 +85,7 @@ private[ml] trait HasLabelCol extends Params { setDefault(labelCol, "label") /** @group getParam */ - final def getLabelCol: String = getOrDefault(labelCol) + final def getLabelCol: String = $(labelCol) } /** @@ -103,7 +102,7 @@ private[ml] trait HasPredictionCol extends Params { setDefault(predictionCol, "prediction") /** @group getParam */ - final def getPredictionCol: String = getOrDefault(predictionCol) + final def getPredictionCol: String = $(predictionCol) } /** @@ -120,7 +119,7 @@ private[ml] trait HasRawPredictionCol extends Params { setDefault(rawPredictionCol, "rawPrediction") /** @group getParam */ - final def getRawPredictionCol: String = getOrDefault(rawPredictionCol) + final def getRawPredictionCol: String = $(rawPredictionCol) } /** @@ -137,7 +136,7 @@ private[ml] trait HasProbabilityCol extends Params { setDefault(probabilityCol, "probability") /** @group getParam */ - final def getProbabilityCol: String = getOrDefault(probabilityCol) + final def getProbabilityCol: String = $(probabilityCol) } /** @@ -152,7 +151,7 @@ private[ml] trait HasThreshold extends Params { final val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in binary classification prediction, in range [0, 1]", ParamValidators.inRange(0, 1)) /** @group getParam */ - final def getThreshold: Double = getOrDefault(threshold) + final def getThreshold: Double = $(threshold) } /** @@ -167,7 +166,7 @@ private[ml] trait HasInputCol extends Params { final val inputCol: Param[String] = new Param[String](this, "inputCol", "input column name") /** @group getParam */ - final def getInputCol: String = getOrDefault(inputCol) + final def getInputCol: String = $(inputCol) } /** @@ -182,7 +181,7 @@ private[ml] trait HasInputCols extends Params { final val inputCols: Param[Array[String]] = new Param[Array[String]](this, "inputCols", "input column names") /** @group getParam */ - final def getInputCols: Array[String] = getOrDefault(inputCols) + final def getInputCols: Array[String] = $(inputCols) } /** @@ -197,7 +196,7 @@ private[ml] trait HasOutputCol extends Params { final val outputCol: Param[String] = new Param[String](this, "outputCol", "output column name") /** @group getParam */ - final def getOutputCol: String = getOrDefault(outputCol) + final def getOutputCol: String = $(outputCol) } /** @@ -212,7 +211,7 @@ private[ml] trait HasCheckpointInterval extends Params { final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval (>= 1)", ParamValidators.gtEq(1)) /** @group getParam */ - final def getCheckpointInterval: Int = getOrDefault(checkpointInterval) + final def getCheckpointInterval: Int = $(checkpointInterval) } /** @@ -229,7 +228,7 @@ private[ml] trait HasFitIntercept extends Params { setDefault(fitIntercept, true) /** @group getParam */ - final def getFitIntercept: Boolean = getOrDefault(fitIntercept) + final def getFitIntercept: Boolean = $(fitIntercept) } /** @@ -246,7 +245,7 @@ private[ml] trait HasSeed extends Params { setDefault(seed, Utils.random.nextLong()) /** @group getParam */ - final def getSeed: Long = getOrDefault(seed) + final def getSeed: Long = $(seed) } /** @@ -261,7 +260,7 @@ private[ml] trait HasElasticNetParam extends Params { final val elasticNetParam: DoubleParam = new DoubleParam(this, "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", ParamValidators.inRange(0, 1)) /** @group getParam */ - final def getElasticNetParam: Double = getOrDefault(elasticNetParam) + final def getElasticNetParam: Double = $(elasticNetParam) } /** @@ -276,7 +275,7 @@ private[ml] trait HasTol extends Params { final val tol: DoubleParam = new DoubleParam(this, "tol", "the convergence tolerance for iterative algorithms") /** @group getParam */ - final def getTol: Double = getOrDefault(tol) + final def getTol: Double = $(tol) } /** @@ -291,6 +290,6 @@ private[ml] trait HasStepSize extends Params { final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size to be used for each iteration of optimization.") /** @group getParam */ - final def getStepSize: Double = getOrDefault(stepSize) + final def getStepSize: Double = $(stepSize) } // scalastyle:on diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index f9f2b2764ddb1..6cf4b40075281 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -59,7 +59,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR val rank = new IntParam(this, "rank", "rank of the factorization", ParamValidators.gtEq(1)) /** @group getParam */ - def getRank: Int = getOrDefault(rank) + def getRank: Int = $(rank) /** * Param for number of user blocks (>= 1). @@ -70,7 +70,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR ParamValidators.gtEq(1)) /** @group getParam */ - def getNumUserBlocks: Int = getOrDefault(numUserBlocks) + def getNumUserBlocks: Int = $(numUserBlocks) /** * Param for number of item blocks (>= 1). @@ -81,7 +81,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR ParamValidators.gtEq(1)) /** @group getParam */ - def getNumItemBlocks: Int = getOrDefault(numItemBlocks) + def getNumItemBlocks: Int = $(numItemBlocks) /** * Param to decide whether to use implicit preference. @@ -91,7 +91,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR val implicitPrefs = new BooleanParam(this, "implicitPrefs", "whether to use implicit preference") /** @group getParam */ - def getImplicitPrefs: Boolean = getOrDefault(implicitPrefs) + def getImplicitPrefs: Boolean = $(implicitPrefs) /** * Param for the alpha parameter in the implicit preference formulation (>= 0). @@ -102,7 +102,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR ParamValidators.gtEq(0)) /** @group getParam */ - def getAlpha: Double = getOrDefault(alpha) + def getAlpha: Double = $(alpha) /** * Param for the column name for user ids. @@ -112,7 +112,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR val userCol = new Param[String](this, "userCol", "column name for user ids") /** @group getParam */ - def getUserCol: String = getOrDefault(userCol) + def getUserCol: String = $(userCol) /** * Param for the column name for item ids. @@ -122,7 +122,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR val itemCol = new Param[String](this, "itemCol", "column name for item ids") /** @group getParam */ - def getItemCol: String = getOrDefault(itemCol) + def getItemCol: String = $(itemCol) /** * Param for the column name for ratings. @@ -132,7 +132,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR val ratingCol = new Param[String](this, "ratingCol", "column name for ratings") /** @group getParam */ - def getRatingCol: String = getOrDefault(ratingCol) + def getRatingCol: String = $(ratingCol) /** * Param for whether to apply nonnegativity constraints. @@ -143,7 +143,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR this, "nonnegative", "whether to use nonnegative constraint for least squares") /** @group getParam */ - def getNonnegative: Boolean = getOrDefault(nonnegative) + def getNonnegative: Boolean = $(nonnegative) setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10, numItemBlocks -> 10, implicitPrefs -> false, alpha -> 1.0, userCol -> "user", itemCol -> "item", @@ -152,19 +152,17 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR /** * Validates and transforms the input schema. * @param schema input schema - * @param paramMap extra params * @return output schema */ - protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = extractParamMap(paramMap) - assert(schema(map(userCol)).dataType == IntegerType) - assert(schema(map(itemCol)).dataType== IntegerType) - val ratingType = schema(map(ratingCol)).dataType - assert(ratingType == FloatType || ratingType == DoubleType) - val predictionColName = map(predictionCol) - assert(!schema.fieldNames.contains(predictionColName), + protected def validateAndTransformSchema(schema: StructType): StructType = { + require(schema($(userCol)).dataType == IntegerType) + require(schema($(itemCol)).dataType== IntegerType) + val ratingType = schema($(ratingCol)).dataType + require(ratingType == FloatType || ratingType == DoubleType) + val predictionColName = $(predictionCol) + require(!schema.fieldNames.contains(predictionColName), s"Prediction column $predictionColName already exists.") - val newFields = schema.fields :+ StructField(map(predictionCol), FloatType, nullable = false) + val newFields = schema.fields :+ StructField($(predictionCol), FloatType, nullable = false) StructType(newFields) } } @@ -174,7 +172,6 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR */ class ALSModel private[ml] ( override val parent: ALS, - override val fittingParamMap: ParamMap, k: Int, userFactors: RDD[(Int, Array[Float])], itemFactors: RDD[(Int, Array[Float])]) @@ -183,9 +180,8 @@ class ALSModel private[ml] ( /** @group setParam */ def setPredictionCol(value: String): this.type = set(predictionCol, value) - override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + override def transform(dataset: DataFrame): DataFrame = { import dataset.sqlContext.implicits._ - val map = extractParamMap(paramMap) val users = userFactors.toDF("id", "features") val items = itemFactors.toDF("id", "features") @@ -199,13 +195,13 @@ class ALSModel private[ml] ( } } dataset - .join(users, dataset(map(userCol)) === users("id"), "left") - .join(items, dataset(map(itemCol)) === items("id"), "left") - .select(dataset("*"), predict(users("features"), items("features")).as(map(predictionCol))) + .join(users, dataset($(userCol)) === users("id"), "left") + .join(items, dataset($(itemCol)) === items("id"), "left") + .select(dataset("*"), predict(users("features"), items("features")).as($(predictionCol))) } - override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - validateAndTransformSchema(schema, paramMap) + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) } } @@ -292,25 +288,22 @@ class ALS extends Estimator[ALSModel] with ALSParams { this } - override def fit(dataset: DataFrame, paramMap: ParamMap): ALSModel = { - val map = extractParamMap(paramMap) + override def fit(dataset: DataFrame): ALSModel = { val ratings = dataset - .select(col(map(userCol)), col(map(itemCol)), col(map(ratingCol)).cast(FloatType)) + .select(col($(userCol)), col($(itemCol)), col($(ratingCol)).cast(FloatType)) .map { row => Rating(row.getInt(0), row.getInt(1), row.getFloat(2)) } - val (userFactors, itemFactors) = ALS.train(ratings, rank = map(rank), - numUserBlocks = map(numUserBlocks), numItemBlocks = map(numItemBlocks), - maxIter = map(maxIter), regParam = map(regParam), implicitPrefs = map(implicitPrefs), - alpha = map(alpha), nonnegative = map(nonnegative), - checkpointInterval = map(checkpointInterval)) - val model = new ALSModel(this, map, map(rank), userFactors, itemFactors) - Params.inheritValues(map, this, model) - model + val (userFactors, itemFactors) = ALS.train(ratings, rank = $(rank), + numUserBlocks = $(numUserBlocks), numItemBlocks = $(numItemBlocks), + maxIter = $(maxIter), regParam = $(regParam), implicitPrefs = $(implicitPrefs), + alpha = $(alpha), nonnegative = $(nonnegative), + checkpointInterval = $(checkpointInterval)) + copyValues(new ALSModel(this, $(rank), userFactors, itemFactors)) } - override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - validateAndTransformSchema(schema, paramMap) + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 756725a64b0f3..b07c26fe79b36 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.regression import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor} import org.apache.spark.ml.impl.tree._ -import org.apache.spark.ml.param.{Params, ParamMap} +import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree.{DecisionTreeModel, Node} import org.apache.spark.ml.util.MetadataUtils import org.apache.spark.mllib.linalg.Vector @@ -31,7 +31,6 @@ import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeMo import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame - /** * :: AlphaComponent :: * @@ -63,15 +62,13 @@ final class DecisionTreeRegressor override def setImpurity(value: String): this.type = super.setImpurity(value) - override protected def train( - dataset: DataFrame, - paramMap: ParamMap): DecisionTreeRegressionModel = { + override protected def train(dataset: DataFrame): DecisionTreeRegressionModel = { val categoricalFeatures: Map[Int, Int] = - MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol))) - val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap) + MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val strategy = getOldStrategy(categoricalFeatures) val oldModel = OldDecisionTree.train(oldDataset, strategy) - DecisionTreeRegressionModel.fromOld(oldModel, this, paramMap, categoricalFeatures) + DecisionTreeRegressionModel.fromOld(oldModel, this, categoricalFeatures) } /** (private[ml]) Create a Strategy instance to use with the old API. */ @@ -96,7 +93,6 @@ object DecisionTreeRegressor { @AlphaComponent final class DecisionTreeRegressionModel private[ml] ( override val parent: DecisionTreeRegressor, - override val fittingParamMap: ParamMap, override val rootNode: Node) extends PredictionModel[Vector, DecisionTreeRegressionModel] with DecisionTreeModel with Serializable { @@ -108,10 +104,8 @@ final class DecisionTreeRegressionModel private[ml] ( rootNode.predict(features) } - override protected def copy(): DecisionTreeRegressionModel = { - val m = new DecisionTreeRegressionModel(parent, fittingParamMap, rootNode) - Params.inheritValues(this.extractParamMap(), this, m) - m + override def copy(extra: ParamMap): DecisionTreeRegressionModel = { + copyValues(new DecisionTreeRegressionModel(parent, rootNode), extra) } override def toString: String = { @@ -130,12 +124,11 @@ private[ml] object DecisionTreeRegressionModel { def fromOld( oldModel: OldDecisionTreeModel, parent: DecisionTreeRegressor, - fittingParamMap: ParamMap, categoricalFeatures: Map[Int, Int]): DecisionTreeRegressionModel = { require(oldModel.algo == OldAlgo.Regression, s"Cannot convert non-regression DecisionTreeModel (old API) to" + s" DecisionTreeRegressionModel (new API). Algo is: ${oldModel.algo}") val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures) - new DecisionTreeRegressionModel(parent, fittingParamMap, rootNode) + new DecisionTreeRegressionModel(parent, rootNode) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 76c98376930c5..bc796958e4545 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -23,20 +23,18 @@ import org.apache.spark.Logging import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor} import org.apache.spark.ml.impl.tree._ -import org.apache.spark.ml.param.{Params, ParamMap, Param} +import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel} import org.apache.spark.ml.util.MetadataUtils import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{GradientBoostedTrees => OldGBT} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} -import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, Loss => OldLoss, - SquaredError => OldSquaredError} +import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, Loss => OldLoss, SquaredError => OldSquaredError} import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame - /** * :: AlphaComponent :: * @@ -111,7 +109,7 @@ final class GBTRegressor def setLossType(value: String): this.type = set(lossType, value) /** @group getParam */ - def getLossType: String = getOrDefault(lossType).toLowerCase + def getLossType: String = $(lossType).toLowerCase /** (private[ml]) Convert new loss to old loss. */ override private[ml] def getOldLossType: OldLoss = { @@ -124,16 +122,14 @@ final class GBTRegressor } } - override protected def train( - dataset: DataFrame, - paramMap: ParamMap): GBTRegressionModel = { + override protected def train(dataset: DataFrame): GBTRegressionModel = { val categoricalFeatures: Map[Int, Int] = - MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol))) - val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap) + MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression) val oldGBT = new OldGBT(boostingStrategy) val oldModel = oldGBT.run(oldDataset) - GBTRegressionModel.fromOld(oldModel, this, paramMap, categoricalFeatures) + GBTRegressionModel.fromOld(oldModel, this, categoricalFeatures) } } @@ -155,7 +151,6 @@ object GBTRegressor { @AlphaComponent final class GBTRegressionModel( override val parent: GBTRegressor, - override val fittingParamMap: ParamMap, private val _trees: Array[DecisionTreeRegressionModel], private val _treeWeights: Array[Double]) extends PredictionModel[Vector, GBTRegressionModel] @@ -178,10 +173,8 @@ final class GBTRegressionModel( if (prediction > 0.0) 1.0 else 0.0 } - override protected def copy(): GBTRegressionModel = { - val m = new GBTRegressionModel(parent, fittingParamMap, _trees, _treeWeights) - Params.inheritValues(this.extractParamMap(), this, m) - m + override def copy(extra: ParamMap): GBTRegressionModel = { + copyValues(new GBTRegressionModel(parent, _trees, _treeWeights), extra) } override def toString: String = { @@ -200,14 +193,13 @@ private[ml] object GBTRegressionModel { def fromOld( oldModel: OldGBTModel, parent: GBTRegressor, - fittingParamMap: ParamMap, categoricalFeatures: Map[Int, Int]): GBTRegressionModel = { require(oldModel.algo == OldAlgo.Regression, "Cannot convert GradientBoostedTreesModel" + s" with algo=${oldModel.algo} (old API) to GBTRegressionModel (new API).") val newTrees = oldModel.trees.map { tree => // parent, fittingParamMap for each tree is null since there are no good ways to set these. - DecisionTreeRegressionModel.fromOld(tree, null, null, categoricalFeatures) + DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) } - new GBTRegressionModel(parent, fittingParamMap, newTrees, oldModel.treeWeights) + new GBTRegressionModel(parent, newTrees, oldModel.treeWeights) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 0b81c48466be9..66c475f2d9840 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -19,22 +19,22 @@ package org.apache.spark.ml.regression import scala.collection.mutable -import breeze.linalg.{norm => brzNorm, DenseVector => BDV} -import breeze.optimize.{LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} -import breeze.optimize.{CachedDiffFunction, DiffFunction} +import breeze.linalg.{DenseVector => BDV, norm => brzNorm} +import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, + OWLQN => BreezeOWLQN} +import org.apache.spark.Logging import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.param.{Params, ParamMap} -import org.apache.spark.ml.param.shared.{HasTol, HasElasticNetParam, HasMaxIter, HasRegParam} -import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.shared.{HasElasticNetParam, HasMaxIter, HasRegParam, HasTol} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS._ import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame import org.apache.spark.storage.StorageLevel import org.apache.spark.util.StatCounter -import org.apache.spark.Logging /** * Params for linear regression. @@ -96,9 +96,9 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress def setTol(value: Double): this.type = set(tol, value) setDefault(tol -> 1E-6) - override protected def train(dataset: DataFrame, paramMap: ParamMap): LinearRegressionModel = { + override protected def train(dataset: DataFrame): LinearRegressionModel = { // Extract columns from data. If dataset is persisted, do not persist instances. - val instances = extractLabeledPoints(dataset, paramMap).map { + val instances = extractLabeledPoints(dataset).map { case LabeledPoint(label: Double, features: Vector) => (label, features) } val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE @@ -125,7 +125,7 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress logWarning(s"The standard deviation of the label is zero, so the weights will be zeros " + s"and the intercept will be the mean of the label; as a result, training is not needed.") if (handlePersistence) instances.unpersist() - return new LinearRegressionModel(this, paramMap, Vectors.sparse(numFeatures, Seq()), yMean) + return new LinearRegressionModel(this, Vectors.sparse(numFeatures, Seq()), yMean) } val featuresMean = summarizer.mean.toArray @@ -133,17 +133,17 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress // Since we implicitly do the feature scaling when we compute the cost function // to improve the convergence, the effective regParam will be changed. - val effectiveRegParam = paramMap(regParam) / yStd - val effectiveL1RegParam = paramMap(elasticNetParam) * effectiveRegParam - val effectiveL2RegParam = (1.0 - paramMap(elasticNetParam)) * effectiveRegParam + val effectiveRegParam = $(regParam) / yStd + val effectiveL1RegParam = $(elasticNetParam) * effectiveRegParam + val effectiveL2RegParam = (1.0 - $(elasticNetParam)) * effectiveRegParam val costFun = new LeastSquaresCostFun(instances, yStd, yMean, featuresStd, featuresMean, effectiveL2RegParam) - val optimizer = if (paramMap(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) { - new BreezeLBFGS[BDV[Double]](paramMap(maxIter), 10, paramMap(tol)) + val optimizer = if ($(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) { + new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) } else { - new BreezeOWLQN[Int, BDV[Double]](paramMap(maxIter), 10, effectiveL1RegParam, paramMap(tol)) + new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, effectiveL1RegParam, $(tol)) } val initialWeights = Vectors.zeros(numFeatures) @@ -178,7 +178,7 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress if (handlePersistence) instances.unpersist() // TODO: Converts to sparse format based on the storage, but may base on the scoring speed. - new LinearRegressionModel(this, paramMap, weights.compressed, intercept) + new LinearRegressionModel(this, weights.compressed, intercept) } } @@ -190,7 +190,6 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress @AlphaComponent class LinearRegressionModel private[ml] ( override val parent: LinearRegression, - override val fittingParamMap: ParamMap, val weights: Vector, val intercept: Double) extends RegressionModel[Vector, LinearRegressionModel] @@ -200,10 +199,8 @@ class LinearRegressionModel private[ml] ( dot(features, weights) + intercept } - override protected def copy(): LinearRegressionModel = { - val m = new LinearRegressionModel(parent, fittingParamMap, weights, intercept) - Params.inheritValues(extractParamMap(), this, m) - m + override def copy(extra: ParamMap): LinearRegressionModel = { + copyValues(new LinearRegressionModel(parent, weights, intercept), extra) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 2171ef3d32c26..0468a1be1ba74 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -20,18 +20,17 @@ package org.apache.spark.ml.regression import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor} import org.apache.spark.ml.impl.tree.{RandomForestParams, TreeRegressorParams} -import org.apache.spark.ml.param.{Params, ParamMap} +import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel} import org.apache.spark.ml.util.MetadataUtils import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{RandomForest => OldRandomForest} -import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame - /** * :: AlphaComponent :: * @@ -77,17 +76,15 @@ final class RandomForestRegressor override def setFeatureSubsetStrategy(value: String): this.type = super.setFeatureSubsetStrategy(value) - override protected def train( - dataset: DataFrame, - paramMap: ParamMap): RandomForestRegressionModel = { + override protected def train(dataset: DataFrame): RandomForestRegressionModel = { val categoricalFeatures: Map[Int, Int] = - MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol))) - val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap) + MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity) val oldModel = OldRandomForest.trainRegressor( oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt) - RandomForestRegressionModel.fromOld(oldModel, this, paramMap, categoricalFeatures) + RandomForestRegressionModel.fromOld(oldModel, this, categoricalFeatures) } } @@ -110,7 +107,6 @@ object RandomForestRegressor { @AlphaComponent final class RandomForestRegressionModel private[ml] ( override val parent: RandomForestRegressor, - override val fittingParamMap: ParamMap, private val _trees: Array[DecisionTreeRegressionModel]) extends PredictionModel[Vector, RandomForestRegressionModel] with TreeEnsembleModel with Serializable { @@ -132,10 +128,8 @@ final class RandomForestRegressionModel private[ml] ( _trees.map(_.rootNode.predict(features)).sum / numTrees } - override protected def copy(): RandomForestRegressionModel = { - val m = new RandomForestRegressionModel(parent, fittingParamMap, _trees) - Params.inheritValues(this.extractParamMap(), this, m) - m + override def copy(extra: ParamMap): RandomForestRegressionModel = { + copyValues(new RandomForestRegressionModel(parent, _trees), extra) } override def toString: String = { @@ -154,14 +148,13 @@ private[ml] object RandomForestRegressionModel { def fromOld( oldModel: OldRandomForestModel, parent: RandomForestRegressor, - fittingParamMap: ParamMap, categoricalFeatures: Map[Int, Int]): RandomForestRegressionModel = { require(oldModel.algo == OldAlgo.Regression, "Cannot convert RandomForestModel" + s" with algo=${oldModel.algo} (old API) to RandomForestRegressionModel (new API).") val newTrees = oldModel.trees.map { tree => // parent, fittingParamMap for each tree is null since there are no good ways to set these. - DecisionTreeRegressionModel.fromOld(tree, null, null, categoricalFeatures) + DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) } - new RandomForestRegressionModel(parent, fittingParamMap, newTrees) + new RandomForestRegressionModel(parent, newTrees) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala index d679085eeafe1..c6b3327db6ad3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.regression -import org.apache.spark.annotation.{DeveloperApi, AlphaComponent} +import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor, PredictorParams} /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index d1ad0893cd044..cee2aa6e85523 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -39,7 +39,7 @@ private[ml] trait CrossValidatorParams extends Params { val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection") /** @group getParam */ - def getEstimator: Estimator[_] = getOrDefault(estimator) + def getEstimator: Estimator[_] = $(estimator) /** * param for estimator param maps @@ -49,7 +49,7 @@ private[ml] trait CrossValidatorParams extends Params { new Param(this, "estimatorParamMaps", "param maps for the estimator") /** @group getParam */ - def getEstimatorParamMaps: Array[ParamMap] = getOrDefault(estimatorParamMaps) + def getEstimatorParamMaps: Array[ParamMap] = $(estimatorParamMaps) /** * param for the evaluator for selection @@ -58,7 +58,7 @@ private[ml] trait CrossValidatorParams extends Params { val evaluator: Param[Evaluator] = new Param(this, "evaluator", "evaluator for selection") /** @group getParam */ - def getEvaluator: Evaluator = getOrDefault(evaluator) + def getEvaluator: Evaluator = $(evaluator) /** * Param for number of folds for cross validation. Must be >= 2. @@ -69,7 +69,7 @@ private[ml] trait CrossValidatorParams extends Params { "number of folds for cross validation (>= 2)", ParamValidators.gtEq(2)) /** @group getParam */ - def getNumFolds: Int = getOrDefault(numFolds) + def getNumFolds: Int = $(numFolds) setDefault(numFolds -> 3) } @@ -95,23 +95,22 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP /** @group setParam */ def setNumFolds(value: Int): this.type = set(numFolds, value) - override def validate(paramMap: ParamMap): Unit = { + override def validateParams(paramMap: ParamMap): Unit = { getEstimatorParamMaps.foreach { eMap => - getEstimator.validate(eMap ++ paramMap) + getEstimator.validateParams(eMap ++ paramMap) } } - override def fit(dataset: DataFrame, paramMap: ParamMap): CrossValidatorModel = { - val map = extractParamMap(paramMap) + override def fit(dataset: DataFrame): CrossValidatorModel = { val schema = dataset.schema - transformSchema(dataset.schema, paramMap, logging = true) + transformSchema(dataset.schema, logging = true) val sqlCtx = dataset.sqlContext - val est = map(estimator) - val eval = map(evaluator) - val epm = map(estimatorParamMaps) + val est = $(estimator) + val eval = $(evaluator) + val epm = $(estimatorParamMaps) val numModels = epm.length val metrics = new Array[Double](epm.length) - val splits = MLUtils.kFold(dataset.rdd, map(numFolds), 0) + val splits = MLUtils.kFold(dataset.rdd, $(numFolds), 0) splits.zipWithIndex.foreach { case ((training, validation), splitIndex) => val trainingDataset = sqlCtx.createDataFrame(training, schema).cache() val validationDataset = sqlCtx.createDataFrame(validation, schema).cache() @@ -121,27 +120,24 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP trainingDataset.unpersist() var i = 0 while (i < numModels) { - val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)), map) + val metric = eval.evaluate(models(i).transform(validationDataset, epm(i))) logDebug(s"Got metric $metric for model trained with ${epm(i)}.") metrics(i) += metric i += 1 } validationDataset.unpersist() } - f2jBLAS.dscal(numModels, 1.0 / map(numFolds), metrics, 1) + f2jBLAS.dscal(numModels, 1.0 / $(numFolds), metrics, 1) logInfo(s"Average cross-validation metrics: ${metrics.toSeq}") val (bestMetric, bestIndex) = metrics.zipWithIndex.maxBy(_._1) logInfo(s"Best set of parameters:\n${epm(bestIndex)}") logInfo(s"Best cross-validation metric: $bestMetric.") val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] - val cvModel = new CrossValidatorModel(this, map, bestModel) - Params.inheritValues(map, this, cvModel) - cvModel + copyValues(new CrossValidatorModel(this, bestModel)) } - override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = extractParamMap(paramMap) - map(estimator).transformSchema(schema, paramMap) + override def transformSchema(schema: StructType): StructType = { + $(estimator).transformSchema(schema) } } @@ -152,19 +148,18 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP @AlphaComponent class CrossValidatorModel private[ml] ( override val parent: CrossValidator, - override val fittingParamMap: ParamMap, val bestModel: Model[_]) extends Model[CrossValidatorModel] with CrossValidatorParams { - override def validate(paramMap: ParamMap): Unit = { - bestModel.validate(paramMap) + override def validateParams(paramMap: ParamMap): Unit = { + bestModel.validateParams(paramMap) } - override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { - bestModel.transform(dataset, paramMap) + override def transform(dataset: DataFrame): DataFrame = { + bestModel.transform(dataset) } - override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - bestModel.transformSchema(schema, paramMap) + override def transformSchema(schema: StructType): StructType = { + bestModel.transformSchema(schema) } } diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java index 3f8e59de0f05c..7e7189a2b1d53 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java @@ -84,9 +84,10 @@ public void logisticRegressionWithSetters() { .setThreshold(0.6) .setProbabilityCol("myProbability"); LogisticRegressionModel model = lr.fit(dataset); - assert(model.fittingParamMap().apply(lr.maxIter()).equals(10)); - assert(model.fittingParamMap().apply(lr.regParam()).equals(1.0)); - assert(model.fittingParamMap().apply(lr.threshold()).equals(0.6)); + LogisticRegression parent = model.parent(); + assert(parent.getMaxIter() == 10); + assert(parent.getRegParam() == 1.0); + assert(parent.getThreshold() == 0.6); assert(model.getThreshold() == 0.6); // Modify model params, and check that the params worked. @@ -109,9 +110,10 @@ public void logisticRegressionWithSetters() { // Call fit() with new params, and check as many params as we can. LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.threshold().w(0.4), lr.probabilityCol().w("theProb")); - assert(model2.fittingParamMap().apply(lr.maxIter()).equals(5)); - assert(model2.fittingParamMap().apply(lr.regParam()).equals(0.1)); - assert(model2.fittingParamMap().apply(lr.threshold()).equals(0.4)); + LogisticRegression parent2 = model2.parent(); + assert(parent2.getMaxIter() == 5); + assert(parent2.getRegParam() == 0.1); + assert(parent2.getThreshold() == 0.4); assert(model2.getThreshold() == 0.4); assert(model2.getProbabilityCol().equals("theProb")); } diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java index 0cc36c8d56d70..a82b86d560b6e 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java @@ -23,14 +23,15 @@ import org.junit.After; import org.junit.Before; import org.junit.Test; +import static org.junit.Assert.assertEquals; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import static org.apache.spark.mllib.classification.LogisticRegressionSuite - .generateLogisticInputAsList; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; +import static org.apache.spark.mllib.classification.LogisticRegressionSuite + .generateLogisticInputAsList; public class JavaLinearRegressionSuite implements Serializable { @@ -65,8 +66,8 @@ public void linearRegressionDefaultParams() { DataFrame predictions = jsql.sql("SELECT label, prediction FROM prediction"); predictions.collect(); // Check defaults - assert(model.getFeaturesCol().equals("features")); - assert(model.getPredictionCol().equals("prediction")); + assertEquals("features", model.getFeaturesCol()); + assertEquals("prediction", model.getPredictionCol()); } @Test @@ -76,14 +77,16 @@ public void linearRegressionWithSetters() { .setMaxIter(10) .setRegParam(1.0); LinearRegressionModel model = lr.fit(dataset); - assert(model.fittingParamMap().apply(lr.maxIter()).equals(10)); - assert(model.fittingParamMap().apply(lr.regParam()).equals(1.0)); + LinearRegression parent = model.parent(); + assertEquals(10, parent.getMaxIter()); + assertEquals(1.0, parent.getRegParam(), 0.0); // Call fit() with new params, and check as many params as we can. LinearRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.predictionCol().w("thePred")); - assert(model2.fittingParamMap().apply(lr.maxIter()).equals(5)); - assert(model2.fittingParamMap().apply(lr.regParam()).equals(0.1)); - assert(model2.getPredictionCol().equals("thePred")); + LinearRegression parent2 = model2.parent(); + assertEquals(5, parent2.getMaxIter()); + assertEquals(0.1, parent2.getRegParam(), 0.0); + assertEquals("thePred", model2.getPredictionCol()); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java index 0bb6b489f2757..08eeca53f0721 100644 --- a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java @@ -68,8 +68,8 @@ public void crossValidationWithLogisticRegression() { .setEvaluator(eval) .setNumFolds(3); CrossValidatorModel cvModel = cv.fit(dataset); - ParamMap bestParamMap = cvModel.bestModel().fittingParamMap(); - Assert.assertEquals(0.001, bestParamMap.apply(lr.regParam())); - Assert.assertEquals(10, bestParamMap.apply(lr.maxIter())); + LogisticRegression parent = (LogisticRegression) cvModel.bestModel().parent(); + Assert.assertEquals(0.001, parent.getRegParam(), 0.0); + Assert.assertEquals(10, parent.getMaxIter()); } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 2f175fb117941..2b04a3034782e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -42,30 +42,32 @@ class PipelineSuite extends FunSuite { val dataset3 = mock[DataFrame] val dataset4 = mock[DataFrame] - when(estimator0.fit(meq(dataset0), any[ParamMap]())).thenReturn(model0) - when(model0.transform(meq(dataset0), any[ParamMap]())).thenReturn(dataset1) + when(estimator0.copy(any[ParamMap])).thenReturn(estimator0) + when(model0.copy(any[ParamMap])).thenReturn(model0) + when(transformer1.copy(any[ParamMap])).thenReturn(transformer1) + when(estimator2.copy(any[ParamMap])).thenReturn(estimator2) + when(model2.copy(any[ParamMap])).thenReturn(model2) + when(transformer3.copy(any[ParamMap])).thenReturn(transformer3) + + when(estimator0.fit(meq(dataset0))).thenReturn(model0) + when(model0.transform(meq(dataset0))).thenReturn(dataset1) when(model0.parent).thenReturn(estimator0) - when(transformer1.transform(meq(dataset1), any[ParamMap])).thenReturn(dataset2) - when(estimator2.fit(meq(dataset2), any[ParamMap]())).thenReturn(model2) - when(model2.transform(meq(dataset2), any[ParamMap]())).thenReturn(dataset3) + when(transformer1.transform(meq(dataset1))).thenReturn(dataset2) + when(estimator2.fit(meq(dataset2))).thenReturn(model2) + when(model2.transform(meq(dataset2))).thenReturn(dataset3) when(model2.parent).thenReturn(estimator2) - when(transformer3.transform(meq(dataset3), any[ParamMap]())).thenReturn(dataset4) + when(transformer3.transform(meq(dataset3))).thenReturn(dataset4) val pipeline = new Pipeline() .setStages(Array(estimator0, transformer1, estimator2, transformer3)) val pipelineModel = pipeline.fit(dataset0) - assert(pipelineModel.stages.size === 4) + assert(pipelineModel.stages.length === 4) assert(pipelineModel.stages(0).eq(model0)) assert(pipelineModel.stages(1).eq(transformer1)) assert(pipelineModel.stages(2).eq(model2)) assert(pipelineModel.stages(3).eq(transformer3)) - assert(pipelineModel.getModel(estimator0).eq(model0)) - assert(pipelineModel.getModel(estimator2).eq(model2)) - intercept[NoSuchElementException] { - pipelineModel.getModel(mock[Estimator[MyModel]]) - } val output = pipelineModel.transform(dataset0) assert(output.eq(dataset4)) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 9b31adecdcb1c..03af4ecd7a7e0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -267,8 +267,8 @@ private[ml] object DecisionTreeClassifierSuite extends FunSuite { val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses) val newTree = dt.fit(newData) // Use parent, fittingParamMap from newTree since these are not checked anyways. - val oldTreeAsNew = DecisionTreeClassificationModel.fromOld(oldTree, newTree.parent, - newTree.fittingParamMap, categoricalFeatures) + val oldTreeAsNew = DecisionTreeClassificationModel.fromOld( + oldTree, newTree.parent, categoricalFeatures) TreeTests.checkEqual(oldTreeAsNew, newTree) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index e6ccc2c93cba8..16c758b82c7cd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -129,8 +129,8 @@ private object GBTClassifierSuite { val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2) val newModel = gbt.fit(newData) // Use parent, fittingParamMap from newTree since these are not checked anyways. - val oldModelAsNew = GBTClassificationModel.fromOld(oldModel, newModel.parent, - newModel.fittingParamMap, categoricalFeatures) + val oldModelAsNew = GBTClassificationModel.fromOld( + oldModel, newModel.parent, categoricalFeatures) TreeTests.checkEqual(oldModelAsNew, newModel) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 35d8c2e16c6cd..6dd1fdf05514e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -74,9 +74,10 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { .setThreshold(0.6) .setProbabilityCol("myProbability") val model = lr.fit(dataset) - assert(model.fittingParamMap.get(lr.maxIter) === Some(10)) - assert(model.fittingParamMap.get(lr.regParam) === Some(1.0)) - assert(model.fittingParamMap.get(lr.threshold) === Some(0.6)) + val parent = model.parent + assert(parent.getMaxIter === 10) + assert(parent.getRegParam === 1.0) + assert(parent.getThreshold === 0.6) assert(model.getThreshold === 0.6) // Modify model params, and check that the params worked. @@ -99,9 +100,10 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { // Call fit() with new params, and check as many params as we can. val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.threshold -> 0.4, lr.probabilityCol -> "theProb") - assert(model2.fittingParamMap.get(lr.maxIter).get === 5) - assert(model2.fittingParamMap.get(lr.regParam).get === 0.1) - assert(model2.fittingParamMap.get(lr.threshold).get === 0.4) + val parent2 = model2.parent + assert(parent2.getMaxIter === 5) + assert(parent2.getRegParam === 0.1) + assert(parent2.getThreshold === 0.4) assert(model2.getThreshold === 0.4) assert(model2.getProbabilityCol == "theProb") } @@ -117,7 +119,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { val results = model.transform(dataset) // Compare rawPrediction with probability - results.select("rawPrediction", "probability").collect().map { + results.select("rawPrediction", "probability").collect().foreach { case Row(raw: Vector, prob: Vector) => assert(raw.size === 2) assert(prob.size === 2) @@ -127,7 +129,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { } // Compare prediction with probability - results.select("prediction", "probability").collect().map { + results.select("prediction", "probability").collect().foreach { case Row(pred: Double, prob: Vector) => val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2 assert(pred == predFromProb) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index ed41a9664f94f..c41def9330504 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -159,8 +159,8 @@ private object RandomForestClassifierSuite { val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses) val newModel = rf.fit(newData) // Use parent, fittingParamMap from newTree since these are not checked anyways. - val oldModelAsNew = RandomForestClassificationModel.fromOld(oldModel, newModel.parent, - newModel.fittingParamMap, categoricalFeatures) + val oldModelAsNew = RandomForestClassificationModel.fromOld( + oldModel, newModel.parent, categoricalFeatures) TreeTests.checkEqual(oldModelAsNew, newModel) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index f8852606abbf2..6056e7d3f6ff8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -122,19 +122,21 @@ class ParamsSuite extends FunSuite { assert(solver.getParam("inputCol").eq(inputCol)) assert(solver.getParam("maxIter").eq(maxIter)) + assert(solver.hasParam("inputCol")) + assert(!solver.hasParam("abc")) intercept[NoSuchElementException] { solver.getParam("abc") } intercept[IllegalArgumentException] { - solver.validate() + solver.validateParams() } - solver.validate(ParamMap(inputCol -> "input")) + solver.validateParams(ParamMap(inputCol -> "input")) solver.setInputCol("input") assert(solver.isSet(inputCol)) assert(solver.isDefined(inputCol)) assert(solver.getInputCol === "input") - solver.validate() + solver.validateParams() intercept[IllegalArgumentException] { ParamMap(maxIter -> -10) } @@ -144,6 +146,11 @@ class ParamsSuite extends FunSuite { solver.clearMaxIter() assert(!solver.isSet(maxIter)) + + val copied = solver.copy(ParamMap(solver.maxIter -> 50)) + assert(copied.uid !== solver.uid) + assert(copied.getInputCol === solver.getInputCol) + assert(copied.getMaxIter === 50) } test("ParamValidate") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala index 6f9c9cb5166ae..dc16073640407 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala @@ -23,15 +23,19 @@ import org.apache.spark.ml.param.shared.{HasInputCol, HasMaxIter} class TestParams extends Params with HasMaxIter with HasInputCol { def setMaxIter(value: Int): this.type = { set(maxIter, value); this } + def setInputCol(value: String): this.type = { set(inputCol, value); this } setDefault(maxIter -> 10) - override def validate(paramMap: ParamMap): Unit = { - val m = extractParamMap(paramMap) - // Note: maxIter is validated when it is set. - require(m.contains(inputCol)) + def clearMaxIter(): this.type = clear(maxIter) + + override def validateParams(): Unit = { + super.validateParams() + require(isDefined(inputCol)) } - def clearMaxIter(): this.type = clear(maxIter) + override def copy(extra: ParamMap): TestParams = { + super.copy(extra).asInstanceOf[TestParams] + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index c87a171b4b229..5aa81b44ddaf9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -84,8 +84,8 @@ private[ml] object DecisionTreeRegressorSuite extends FunSuite { val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) val newTree = dt.fit(newData) // Use parent, fittingParamMap from newTree since these are not checked anyways. - val oldTreeAsNew = DecisionTreeRegressionModel.fromOld(oldTree, newTree.parent, - newTree.fittingParamMap, categoricalFeatures) + val oldTreeAsNew = DecisionTreeRegressionModel.fromOld( + oldTree, newTree.parent, categoricalFeatures) TreeTests.checkEqual(oldTreeAsNew, newTree) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 4aec36948ac92..25b36ab08b67c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -130,8 +130,7 @@ private object GBTRegressorSuite { val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) val newModel = gbt.fit(newData) // Use parent, fittingParamMap from newTree since these are not checked anyways. - val oldModelAsNew = GBTRegressionModel.fromOld(oldModel, newModel.parent, - newModel.fittingParamMap, categoricalFeatures) + val oldModelAsNew = GBTRegressionModel.fromOld(oldModel, newModel.parent, categoricalFeatures) TreeTests.checkEqual(oldModelAsNew, newModel) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala index c6dc1cc29b6ff..45f09f4fdab81 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -115,8 +115,8 @@ private object RandomForestRegressorSuite extends FunSuite { val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) val newModel = rf.fit(newData) // Use parent, fittingParamMap from newTree since these are not checked anyways. - val oldModelAsNew = RandomForestRegressionModel.fromOld(oldModel, newModel.parent, - newModel.fittingParamMap, categoricalFeatures) + val oldModelAsNew = RandomForestRegressionModel.fromOld( + oldModel, newModel.parent, categoricalFeatures) TreeTests.checkEqual(oldModelAsNew, newModel) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 761ea821ef7c6..05313d440fbf6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -49,8 +49,8 @@ class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext { .setEvaluator(eval) .setNumFolds(3) val cvModel = cv.fit(dataset) - val bestParamMap = cvModel.bestModel.fittingParamMap - assert(bestParamMap(lr.regParam) === 0.001) - assert(bestParamMap(lr.maxIter) === 10) + val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression] + assert(parent.getRegParam === 0.001) + assert(parent.getMaxIter === 10) } } From f32e69ecc333867fc966f65cd0aeaeddd43e0945 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BA=91=E5=B3=A4?= Date: Mon, 4 May 2015 12:08:38 -0700 Subject: [PATCH 44/91] [SPARK-7319][SQL] Improve the output from DataFrame.show() MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Author: 云峤 Closes #5865 from kaka1992/df.show and squashes the following commits: c79204b [云峤] Update a1338f6 [云峤] Update python dataFrame show test and add empty df unit test. 734369c [云峤] Update python dataFrame show test and add empty df unit test. 84aec3e [云峤] Update python dataFrame show test and add empty df unit test. 159b3d5 [云峤] update 03ef434 [云峤] update 7394fd5 [云峤] update test show ced487a [云峤] update pep8 b6e690b [云峤] Merge remote-tracking branch 'upstream/master' into df.show 30ac311 [云峤] [SPARK-7294] ADD BETWEEN 7d62368 [云峤] [SPARK-7294] ADD BETWEEN baf839b [云峤] [SPARK-7294] ADD BETWEEN d11d5b9 [云峤] [SPARK-7294] ADD BETWEEN --- R/pkg/R/DataFrame.R | 2 +- R/pkg/inst/tests/test_sparkSQL.R | 2 +- python/pyspark/sql/dataframe.py | 105 ++++++++++++------ .../org/apache/spark/sql/DataFrame.scala | 28 ++++- .../org/apache/spark/sql/DataFrameSuite.scala | 19 ++++ 5 files changed, 112 insertions(+), 44 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index b59b700af5dc9..841e77e55e0d8 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -167,7 +167,7 @@ setMethod("isLocal", setMethod("showDF", signature(x = "DataFrame"), function(x, numRows = 20) { - cat(callJMethod(x@sdf, "showString", numToInt(numRows)), "\n") + callJMethod(x@sdf, "showString", numToInt(numRows)) }) #' show diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index af7a6c582047a..f82e56fdd8278 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -641,7 +641,7 @@ test_that("toJSON() returns an RDD of the correct values", { test_that("showDF()", { df <- jsonFile(sqlCtx, jsonPath) - expect_output(showDF(df), "age name \nnull Michael\n30 Andy \n19 Justin ") + expect_output(showDF(df), "+----+-------+\n| age| name|\n+----+-------+\n|null|Michael|\n| 30| Andy|\n| 19| Justin|\n+----+-------+\n") }) test_that("isLocal()", { diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index aac5b8c4c5770..22762c5bbbcd0 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -275,9 +275,12 @@ def show(self, n=20): >>> df DataFrame[age: int, name: string] >>> df.show() - age name - 2 Alice - 5 Bob + +---+-----+ + |age| name| + +---+-----+ + | 2|Alice| + | 5| Bob| + +---+-----+ """ print(self._jdf.showString(n)) @@ -591,12 +594,15 @@ def describe(self, *cols): given, this function computes statistics for all numerical columns. >>> df.describe().show() - summary age - count 2 - mean 3.5 - stddev 1.5 - min 2 - max 5 + +-------+---+ + |summary|age| + +-------+---+ + | count| 2| + | mean|3.5| + | stddev|1.5| + | min| 2| + | max| 5| + +-------+---+ """ jdf = self._jdf.describe(self._jseq(cols)) return DataFrame(jdf, self.sql_ctx) @@ -801,12 +807,18 @@ def dropna(self, how='any', thresh=None, subset=None): :param subset: optional list of column names to consider. >>> df4.dropna().show() - age height name - 10 80 Alice + +---+------+-----+ + |age|height| name| + +---+------+-----+ + | 10| 80|Alice| + +---+------+-----+ >>> df4.na.drop().show() - age height name - 10 80 Alice + +---+------+-----+ + |age|height| name| + +---+------+-----+ + | 10| 80|Alice| + +---+------+-----+ """ if how is not None and how not in ['any', 'all']: raise ValueError("how ('" + how + "') should be 'any' or 'all'") @@ -837,25 +849,34 @@ def fillna(self, value, subset=None): then the non-string column is simply ignored. >>> df4.fillna(50).show() - age height name - 10 80 Alice - 5 50 Bob - 50 50 Tom - 50 50 null + +---+------+-----+ + |age|height| name| + +---+------+-----+ + | 10| 80|Alice| + | 5| 50| Bob| + | 50| 50| Tom| + | 50| 50| null| + +---+------+-----+ >>> df4.fillna({'age': 50, 'name': 'unknown'}).show() - age height name - 10 80 Alice - 5 null Bob - 50 null Tom - 50 null unknown + +---+------+-------+ + |age|height| name| + +---+------+-------+ + | 10| 80| Alice| + | 5| null| Bob| + | 50| null| Tom| + | 50| null|unknown| + +---+------+-------+ >>> df4.na.fill({'age': 50, 'name': 'unknown'}).show() - age height name - 10 80 Alice - 5 null Bob - 50 null Tom - 50 null unknown + +---+------+-------+ + |age|height| name| + +---+------+-------+ + | 10| 80| Alice| + | 5| null| Bob| + | 50| null| Tom| + | 50| null|unknown| + +---+------+-------+ """ if not isinstance(value, (float, int, long, basestring, dict)): raise ValueError("value should be a float, int, long, string, or dict") @@ -1241,11 +1262,17 @@ def getItem(self, key): >>> df = sc.parallelize([([1, 2], {"key": "value"})]).toDF(["l", "d"]) >>> df.select(df.l.getItem(0), df.d.getItem("key")).show() - l[0] d[key] - 1 value + +----+------+ + |l[0]|d[key]| + +----+------+ + | 1| value| + +----+------+ >>> df.select(df.l[0], df.d["key"]).show() - l[0] d[key] - 1 value + +----+------+ + |l[0]|d[key]| + +----+------+ + | 1| value| + +----+------+ """ return self[key] @@ -1255,11 +1282,17 @@ def getField(self, name): >>> from pyspark.sql import Row >>> df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF() >>> df.select(df.r.getField("b")).show() - r.b - b + +---+ + |r.b| + +---+ + | b| + +---+ >>> df.select(df.r.a).show() - r.a - 1 + +---+ + |r.a| + +---+ + | 1| + +---+ """ return Column(self._jc.getField(name)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index c421006c8fd2d..cf344710ff8b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import java.io.CharArrayWriter import java.sql.DriverManager + import scala.collection.JavaConversions._ import scala.language.implicitConversions import scala.reflect.ClassTag @@ -28,6 +29,7 @@ import scala.util.control.NonFatal import com.fasterxml.jackson.core.JsonFactory +import org.apache.commons.lang3.StringUtils import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.python.SerDeUtil @@ -175,6 +177,7 @@ class DataFrame private[sql]( * @param numRows Number of rows to show */ private[sql] def showString(numRows: Int): String = { + val sb = new StringBuilder val data = take(numRows) val numCols = schema.fieldNames.length @@ -194,12 +197,25 @@ class DataFrame private[sql]( } } - // Pad the cells - rows.map { row => - row.zipWithIndex.map { case (cell, i) => - String.format(s"%-${colWidths(i)}s", cell) - }.mkString(" ") - }.mkString("\n") + // Create SeparateLine + val sep: String = colWidths.map("-" * _).addString(sb, "+", "+", "+\n").toString() + + // column names + rows.head.zipWithIndex.map { case (cell, i) => + StringUtils.leftPad(cell.toString, colWidths(i)) + }.addString(sb, "|", "|", "|\n") + + sb.append(sep) + + // data + rows.tail.map { + _.zipWithIndex.map { case (cell, i) => + StringUtils.leftPad(cell.toString, colWidths(i)) + }.addString(sb, "|", "|", "|\n") + } + + sb.append(sep) + sb.toString() } override def toString: String = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index e286fef23caa4..ff31e15e2d472 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -598,6 +598,25 @@ class DataFrameSuite extends QueryTest { testData.select($"*").show(1000) } + test("SPARK-7319 showString") { + val expectedAnswer = """+---+-----+ + ||key|value| + |+---+-----+ + || 1| 1| + |+---+-----+ + |""".stripMargin + assert(testData.select($"*").showString(1) === expectedAnswer) + } + + test("SPARK-7327 show with empty dataFrame") { + val expectedAnswer = """+---+-----+ + ||key|value| + |+---+-----+ + |+---+-----+ + |""".stripMargin + assert(testData.select($"*").filter($"key" < 0).showString(1) === expectedAnswer) + } + test("createDataFrame(RDD[Row], StructType) should convert UDTs (SPARK-6672)") { val rowRDD = TestSQLContext.sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0)))) val schema = StructType(Array(StructField("point", new ExamplePointUDT(), false))) From fc8b58195afa67fbb75b4c8303e022f703cbf007 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Mon, 4 May 2015 16:21:36 -0700 Subject: [PATCH 45/91] [SPARK-6943] [SPARK-6944] DAG visualization on SparkUI This patch adds the functionality to display the RDD DAG on the SparkUI. This DAG describes the relationships between - an RDD and its dependencies, - an RDD and its operation scopes, and - an RDD's operation scopes and the stage / job hierarchy An operation scope here refers to the existing public APIs that created the RDDs (e.g. `textFile`, `treeAggregate`). In the future, we can expand this to include higher level operations like SQL queries. *Note: This blatantly stole a few lines of HTML and JavaScript from #5547 (thanks shroffpradyumn!)* Here's what the job page looks like: and the stage page: Author: Andrew Or Closes #5729 from andrewor14/viz2 and squashes the following commits: 666c03b [Andrew Or] Round corners of RDD boxes on stage page (minor) 01ba336 [Andrew Or] Change RDD cache color to red (minor) 6f9574a [Andrew Or] Add tests for RDDOperationScope 1c310e4 [Andrew Or] Wrap a few more RDD functions in an operation scope 3ffe566 [Andrew Or] Restore "null" as default for RDD name 5fdd89d [Andrew Or] children -> child (minor) 0d07a84 [Andrew Or] Fix python style afb98e2 [Andrew Or] Merge branch 'master' of github.com:apache/spark into viz2 0d7aa32 [Andrew Or] Fix python tests 3459ab2 [Andrew Or] Fix tests 832443c [Andrew Or] Merge branch 'master' of github.com:apache/spark into viz2 429e9e1 [Andrew Or] Display cached RDDs on the viz b1f0fd1 [Andrew Or] Rename OperatorScope -> RDDOperationScope 31aae06 [Andrew Or] Extract visualization logic from listener 83f9c58 [Andrew Or] Implement a programmatic representation of operator scopes 5a7faf4 [Andrew Or] Rename references to viz scopes to viz clusters ee33d52 [Andrew Or] Separate HTML generating code from listener f9830a2 [Andrew Or] Refactor + clean up + document JS visualization code b80cc52 [Andrew Or] Merge branch 'master' of github.com:apache/spark into viz2 0706992 [Andrew Or] Add link from jobs to stages deb48a0 [Andrew Or] Translate stage boxes taking into account the width 5c7ce16 [Andrew Or] Connect RDDs across stages + update style ab91416 [Andrew Or] Introduce visualization to the Job Page 5f07e9c [Andrew Or] Remove more return statements from scopes 5e388ea [Andrew Or] Fix line too long 43de96e [Andrew Or] Add parent IDs to StageInfo 6e2cfea [Andrew Or] Remove all return statements in `withScope` d19c4da [Andrew Or] Merge branch 'master' of github.com:apache/spark into viz2 7ef957c [Andrew Or] Fix scala style 4310271 [Andrew Or] Merge branch 'master' of github.com:apache/spark into viz2 aa868a9 [Andrew Or] Ensure that HadoopRDD is actually serializable c3bfcae [Andrew Or] Re-implement scopes using closures instead of annotations 52187fc [Andrew Or] Rat excludes 09d361e [Andrew Or] Add ID to node label (minor) 71281fa [Andrew Or] Embed the viz in the UI in a toggleable manner 8dd5af2 [Andrew Or] Fill in documentation + miscellaneous minor changes fe7816f [Andrew Or] Merge branch 'master' of github.com:apache/spark into viz 205f838 [Andrew Or] Reimplement rendering with dagre-d3 instead of viz.js 5e22946 [Andrew Or] Merge branch 'master' of github.com:apache/spark into viz 6a7cdca [Andrew Or] Move RDD scope util methods and logic to its own file 494d5c2 [Andrew Or] Revert a few unintended style changes 9fac6f3 [Andrew Or] Re-implement scopes through annotations instead f22f337 [Andrew Or] First working implementation of visualization with vis.js 2184348 [Andrew Or] Translate RDD information to dot file 5143523 [Andrew Or] Expose the necessary information in RDDInfo a9ed4f9 [Andrew Or] Add a few missing scopes to certain RDD methods 6b3403b [Andrew Or] Scope all RDD methods --- .rat-excludes | 3 + .../org/apache/spark/ui/static/d3.min.js | 5 + .../apache/spark/ui/static/dagre-d3.min.js | 29 ++ .../spark/ui/static/graphlib-dot.min.js | 4 + .../apache/spark/ui/static/spark-dag-viz.js | 392 ++++++++++++++++++ .../org/apache/spark/ui/static/webui.css | 2 +- .../scala/org/apache/spark/SparkContext.scala | 97 +++-- .../apache/spark/rdd/AsyncRDDActions.scala | 10 +- .../apache/spark/rdd/DoubleRDDFunctions.scala | 38 +- .../org/apache/spark/rdd/HadoopRDD.scala | 6 +- .../spark/rdd/OrderedRDDFunctions.scala | 6 +- .../apache/spark/rdd/PairRDDFunctions.scala | 167 +++++--- .../main/scala/org/apache/spark/rdd/RDD.scala | 341 +++++++++------ .../apache/spark/rdd/RDDOperationScope.scala | 137 ++++++ .../spark/rdd/SequenceFileRDDFunctions.scala | 4 +- .../apache/spark/scheduler/StageInfo.scala | 2 + .../org/apache/spark/storage/RDDInfo.scala | 11 +- .../scala/org/apache/spark/ui/SparkUI.scala | 10 +- .../scala/org/apache/spark/ui/UIUtils.scala | 55 ++- .../apache/spark/ui/jobs/AllJobsPage.scala | 2 +- .../apache/spark/ui/jobs/AllStagesPage.scala | 10 +- .../apache/spark/ui/jobs/ExecutorTable.scala | 2 +- .../org/apache/spark/ui/jobs/JobPage.scala | 14 +- .../spark/ui/jobs/JobProgressListener.scala | 17 +- .../org/apache/spark/ui/jobs/JobsTab.scala | 6 +- .../org/apache/spark/ui/jobs/PoolPage.scala | 4 +- .../org/apache/spark/ui/jobs/PoolTable.scala | 2 +- .../org/apache/spark/ui/jobs/StagePage.scala | 17 +- .../org/apache/spark/ui/jobs/StagesTab.scala | 7 +- .../spark/ui/scope/RDDOperationGraph.scala | 205 +++++++++ .../ui/scope/RDDOperationGraphListener.scala | 68 +++ .../org/apache/spark/util/JsonProtocol.scala | 28 +- .../ExecutorAllocationManagerSuite.scala | 2 +- .../spark/rdd/RDDOperationScopeSuite.scala | 133 ++++++ .../apache/spark/storage/StorageSuite.scala | 4 +- .../ui/jobs/JobProgressListenerSuite.scala | 6 +- .../spark/ui/storage/StorageTabSuite.scala | 30 +- .../apache/spark/util/JsonProtocolSuite.scala | 45 +- 38 files changed, 1584 insertions(+), 337 deletions(-) create mode 100644 core/src/main/resources/org/apache/spark/ui/static/d3.min.js create mode 100644 core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js create mode 100644 core/src/main/resources/org/apache/spark/ui/static/graphlib-dot.min.js create mode 100644 core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js create mode 100644 core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala create mode 100644 core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala create mode 100644 core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraphListener.scala create mode 100644 core/src/test/scala/org/apache/spark/rdd/RDDOperationScopeSuite.scala diff --git a/.rat-excludes b/.rat-excludes index 2238a5b68e359..dccf2db8055ce 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -30,6 +30,9 @@ spark-env.sh.template log4j-defaults.properties bootstrap-tooltip.js jquery-1.11.1.min.js +d3.min.js +dagre-d3.min.js +graphlib-dot.min.js sorttable.js vis.min.js vis.min.css diff --git a/core/src/main/resources/org/apache/spark/ui/static/d3.min.js b/core/src/main/resources/org/apache/spark/ui/static/d3.min.js new file mode 100644 index 0000000000000..30cd292198b91 --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/d3.min.js @@ -0,0 +1,5 @@ +/*v3.5.5*/!function(){function n(n){return n&&(n.ownerDocument||n.document||n).documentElement}function t(n){return n&&(n.ownerDocument&&n.ownerDocument.defaultView||n.document&&n||n.defaultView)}function e(n,t){return t>n?-1:n>t?1:n>=t?0:0/0}function r(n){return null===n?0/0:+n}function u(n){return!isNaN(n)}function i(n){return{left:function(t,e,r,u){for(arguments.length<3&&(r=0),arguments.length<4&&(u=t.length);u>r;){var i=r+u>>>1;n(t[i],e)<0?r=i+1:u=i}return r},right:function(t,e,r,u){for(arguments.length<3&&(r=0),arguments.length<4&&(u=t.length);u>r;){var i=r+u>>>1;n(t[i],e)>0?u=i:r=i+1}return r}}}function o(n){return n.length}function a(n){for(var t=1;n*t%1;)t*=10;return t}function c(n,t){for(var e in t)Object.defineProperty(n.prototype,e,{value:t[e],enumerable:!1})}function l(){this._=Object.create(null)}function s(n){return(n+="")===pa||n[0]===va?va+n:n}function f(n){return(n+="")[0]===va?n.slice(1):n}function h(n){return s(n)in this._}function g(n){return(n=s(n))in this._&&delete this._[n]}function p(){var n=[];for(var t in this._)n.push(f(t));return n}function v(){var n=0;for(var t in this._)++n;return n}function d(){for(var n in this._)return!1;return!0}function m(){this._=Object.create(null)}function y(n){return n}function M(n,t,e){return function(){var r=e.apply(t,arguments);return r===t?n:r}}function x(n,t){if(t in n)return t;t=t.charAt(0).toUpperCase()+t.slice(1);for(var e=0,r=da.length;r>e;++e){var u=da[e]+t;if(u in n)return u}}function b(){}function _(){}function w(n){function t(){for(var t,r=e,u=-1,i=r.length;++ue;e++)for(var u,i=n[e],o=0,a=i.length;a>o;o++)(u=i[o])&&t(u,o,e);return n}function Z(n){return ya(n,Sa),n}function V(n){var t,e;return function(r,u,i){var o,a=n[i].update,c=a.length;for(i!=e&&(e=i,t=0),u>=t&&(t=u+1);!(o=a[t])&&++t0&&(n=n.slice(0,a));var l=ka.get(n);return l&&(n=l,c=B),a?t?u:r:t?b:i}function $(n,t){return function(e){var r=ta.event;ta.event=e,t[0]=this.__data__;try{n.apply(this,t)}finally{ta.event=r}}}function B(n,t){var e=$(n,t);return function(n){var t=this,r=n.relatedTarget;r&&(r===t||8&r.compareDocumentPosition(t))||e.call(t,n)}}function W(e){var r=".dragsuppress-"+ ++Aa,u="click"+r,i=ta.select(t(e)).on("touchmove"+r,S).on("dragstart"+r,S).on("selectstart"+r,S);if(null==Ea&&(Ea="onselectstart"in e?!1:x(e.style,"userSelect")),Ea){var o=n(e).style,a=o[Ea];o[Ea]="none"}return function(n){if(i.on(r,null),Ea&&(o[Ea]=a),n){var t=function(){i.on(u,null)};i.on(u,function(){S(),t()},!0),setTimeout(t,0)}}}function J(n,e){e.changedTouches&&(e=e.changedTouches[0]);var r=n.ownerSVGElement||n;if(r.createSVGPoint){var u=r.createSVGPoint();if(0>Na){var i=t(n);if(i.scrollX||i.scrollY){r=ta.select("body").append("svg").style({position:"absolute",top:0,left:0,margin:0,padding:0,border:"none"},"important");var o=r[0][0].getScreenCTM();Na=!(o.f||o.e),r.remove()}}return Na?(u.x=e.pageX,u.y=e.pageY):(u.x=e.clientX,u.y=e.clientY),u=u.matrixTransform(n.getScreenCTM().inverse()),[u.x,u.y]}var a=n.getBoundingClientRect();return[e.clientX-a.left-n.clientLeft,e.clientY-a.top-n.clientTop]}function G(){return ta.event.changedTouches[0].identifier}function K(n){return n>0?1:0>n?-1:0}function Q(n,t,e){return(t[0]-n[0])*(e[1]-n[1])-(t[1]-n[1])*(e[0]-n[0])}function nt(n){return n>1?0:-1>n?qa:Math.acos(n)}function tt(n){return n>1?Ra:-1>n?-Ra:Math.asin(n)}function et(n){return((n=Math.exp(n))-1/n)/2}function rt(n){return((n=Math.exp(n))+1/n)/2}function ut(n){return((n=Math.exp(2*n))-1)/(n+1)}function it(n){return(n=Math.sin(n/2))*n}function ot(){}function at(n,t,e){return this instanceof at?(this.h=+n,this.s=+t,void(this.l=+e)):arguments.length<2?n instanceof at?new at(n.h,n.s,n.l):bt(""+n,_t,at):new at(n,t,e)}function ct(n,t,e){function r(n){return n>360?n-=360:0>n&&(n+=360),60>n?i+(o-i)*n/60:180>n?o:240>n?i+(o-i)*(240-n)/60:i}function u(n){return Math.round(255*r(n))}var i,o;return n=isNaN(n)?0:(n%=360)<0?n+360:n,t=isNaN(t)?0:0>t?0:t>1?1:t,e=0>e?0:e>1?1:e,o=.5>=e?e*(1+t):e+t-e*t,i=2*e-o,new mt(u(n+120),u(n),u(n-120))}function lt(n,t,e){return this instanceof lt?(this.h=+n,this.c=+t,void(this.l=+e)):arguments.length<2?n instanceof lt?new lt(n.h,n.c,n.l):n instanceof ft?gt(n.l,n.a,n.b):gt((n=wt((n=ta.rgb(n)).r,n.g,n.b)).l,n.a,n.b):new lt(n,t,e)}function st(n,t,e){return isNaN(n)&&(n=0),isNaN(t)&&(t=0),new ft(e,Math.cos(n*=Da)*t,Math.sin(n)*t)}function ft(n,t,e){return this instanceof ft?(this.l=+n,this.a=+t,void(this.b=+e)):arguments.length<2?n instanceof ft?new ft(n.l,n.a,n.b):n instanceof lt?st(n.h,n.c,n.l):wt((n=mt(n)).r,n.g,n.b):new ft(n,t,e)}function ht(n,t,e){var r=(n+16)/116,u=r+t/500,i=r-e/200;return u=pt(u)*Xa,r=pt(r)*$a,i=pt(i)*Ba,new mt(dt(3.2404542*u-1.5371385*r-.4985314*i),dt(-.969266*u+1.8760108*r+.041556*i),dt(.0556434*u-.2040259*r+1.0572252*i))}function gt(n,t,e){return n>0?new lt(Math.atan2(e,t)*Pa,Math.sqrt(t*t+e*e),n):new lt(0/0,0/0,n)}function pt(n){return n>.206893034?n*n*n:(n-4/29)/7.787037}function vt(n){return n>.008856?Math.pow(n,1/3):7.787037*n+4/29}function dt(n){return Math.round(255*(.00304>=n?12.92*n:1.055*Math.pow(n,1/2.4)-.055))}function mt(n,t,e){return this instanceof mt?(this.r=~~n,this.g=~~t,void(this.b=~~e)):arguments.length<2?n instanceof mt?new mt(n.r,n.g,n.b):bt(""+n,mt,ct):new mt(n,t,e)}function yt(n){return new mt(n>>16,n>>8&255,255&n)}function Mt(n){return yt(n)+""}function xt(n){return 16>n?"0"+Math.max(0,n).toString(16):Math.min(255,n).toString(16)}function bt(n,t,e){var r,u,i,o=0,a=0,c=0;if(r=/([a-z]+)\((.*)\)/i.exec(n))switch(u=r[2].split(","),r[1]){case"hsl":return e(parseFloat(u[0]),parseFloat(u[1])/100,parseFloat(u[2])/100);case"rgb":return t(kt(u[0]),kt(u[1]),kt(u[2]))}return(i=Ga.get(n.toLowerCase()))?t(i.r,i.g,i.b):(null==n||"#"!==n.charAt(0)||isNaN(i=parseInt(n.slice(1),16))||(4===n.length?(o=(3840&i)>>4,o=o>>4|o,a=240&i,a=a>>4|a,c=15&i,c=c<<4|c):7===n.length&&(o=(16711680&i)>>16,a=(65280&i)>>8,c=255&i)),t(o,a,c))}function _t(n,t,e){var r,u,i=Math.min(n/=255,t/=255,e/=255),o=Math.max(n,t,e),a=o-i,c=(o+i)/2;return a?(u=.5>c?a/(o+i):a/(2-o-i),r=n==o?(t-e)/a+(e>t?6:0):t==o?(e-n)/a+2:(n-t)/a+4,r*=60):(r=0/0,u=c>0&&1>c?0:r),new at(r,u,c)}function wt(n,t,e){n=St(n),t=St(t),e=St(e);var r=vt((.4124564*n+.3575761*t+.1804375*e)/Xa),u=vt((.2126729*n+.7151522*t+.072175*e)/$a),i=vt((.0193339*n+.119192*t+.9503041*e)/Ba);return ft(116*u-16,500*(r-u),200*(u-i))}function St(n){return(n/=255)<=.04045?n/12.92:Math.pow((n+.055)/1.055,2.4)}function kt(n){var t=parseFloat(n);return"%"===n.charAt(n.length-1)?Math.round(2.55*t):t}function Et(n){return"function"==typeof n?n:function(){return n}}function At(n){return function(t,e,r){return 2===arguments.length&&"function"==typeof e&&(r=e,e=null),Nt(t,e,n,r)}}function Nt(n,t,e,r){function u(){var n,t=c.status;if(!t&&zt(c)||t>=200&&300>t||304===t){try{n=e.call(i,c)}catch(r){return void o.error.call(i,r)}o.load.call(i,n)}else o.error.call(i,c)}var i={},o=ta.dispatch("beforesend","progress","load","error"),a={},c=new XMLHttpRequest,l=null;return!this.XDomainRequest||"withCredentials"in c||!/^(http(s)?:)?\/\//.test(n)||(c=new XDomainRequest),"onload"in c?c.onload=c.onerror=u:c.onreadystatechange=function(){c.readyState>3&&u()},c.onprogress=function(n){var t=ta.event;ta.event=n;try{o.progress.call(i,c)}finally{ta.event=t}},i.header=function(n,t){return n=(n+"").toLowerCase(),arguments.length<2?a[n]:(null==t?delete a[n]:a[n]=t+"",i)},i.mimeType=function(n){return arguments.length?(t=null==n?null:n+"",i):t},i.responseType=function(n){return arguments.length?(l=n,i):l},i.response=function(n){return e=n,i},["get","post"].forEach(function(n){i[n]=function(){return i.send.apply(i,[n].concat(ra(arguments)))}}),i.send=function(e,r,u){if(2===arguments.length&&"function"==typeof r&&(u=r,r=null),c.open(e,n,!0),null==t||"accept"in a||(a.accept=t+",*/*"),c.setRequestHeader)for(var s in a)c.setRequestHeader(s,a[s]);return null!=t&&c.overrideMimeType&&c.overrideMimeType(t),null!=l&&(c.responseType=l),null!=u&&i.on("error",u).on("load",function(n){u(null,n)}),o.beforesend.call(i,c),c.send(null==r?null:r),i},i.abort=function(){return c.abort(),i},ta.rebind(i,o,"on"),null==r?i:i.get(Ct(r))}function Ct(n){return 1===n.length?function(t,e){n(null==t?e:null)}:n}function zt(n){var t=n.responseType;return t&&"text"!==t?n.response:n.responseText}function qt(){var n=Lt(),t=Tt()-n;t>24?(isFinite(t)&&(clearTimeout(tc),tc=setTimeout(qt,t)),nc=0):(nc=1,rc(qt))}function Lt(){var n=Date.now();for(ec=Ka;ec;)n>=ec.t&&(ec.f=ec.c(n-ec.t)),ec=ec.n;return n}function Tt(){for(var n,t=Ka,e=1/0;t;)t.f?t=n?n.n=t.n:Ka=t.n:(t.t8?function(n){return n/e}:function(n){return n*e},symbol:n}}function Pt(n){var t=n.decimal,e=n.thousands,r=n.grouping,u=n.currency,i=r&&e?function(n,t){for(var u=n.length,i=[],o=0,a=r[0],c=0;u>0&&a>0&&(c+a+1>t&&(a=Math.max(1,t-c)),i.push(n.substring(u-=a,u+a)),!((c+=a+1)>t));)a=r[o=(o+1)%r.length];return i.reverse().join(e)}:y;return function(n){var e=ic.exec(n),r=e[1]||" ",o=e[2]||">",a=e[3]||"-",c=e[4]||"",l=e[5],s=+e[6],f=e[7],h=e[8],g=e[9],p=1,v="",d="",m=!1,y=!0;switch(h&&(h=+h.substring(1)),(l||"0"===r&&"="===o)&&(l=r="0",o="="),g){case"n":f=!0,g="g";break;case"%":p=100,d="%",g="f";break;case"p":p=100,d="%",g="r";break;case"b":case"o":case"x":case"X":"#"===c&&(v="0"+g.toLowerCase());case"c":y=!1;case"d":m=!0,h=0;break;case"s":p=-1,g="r"}"$"===c&&(v=u[0],d=u[1]),"r"!=g||h||(g="g"),null!=h&&("g"==g?h=Math.max(1,Math.min(21,h)):("e"==g||"f"==g)&&(h=Math.max(0,Math.min(20,h)))),g=oc.get(g)||Ut;var M=l&&f;return function(n){var e=d;if(m&&n%1)return"";var u=0>n||0===n&&0>1/n?(n=-n,"-"):"-"===a?"":a;if(0>p){var c=ta.formatPrefix(n,h);n=c.scale(n),e=c.symbol+d}else n*=p;n=g(n,h);var x,b,_=n.lastIndexOf(".");if(0>_){var w=y?n.lastIndexOf("e"):-1;0>w?(x=n,b=""):(x=n.substring(0,w),b=n.substring(w))}else x=n.substring(0,_),b=t+n.substring(_+1);!l&&f&&(x=i(x,1/0));var S=v.length+x.length+b.length+(M?0:u.length),k=s>S?new Array(S=s-S+1).join(r):"";return M&&(x=i(k+x,k.length?s-b.length:1/0)),u+=v,n=x+b,("<"===o?u+n+k:">"===o?k+u+n:"^"===o?k.substring(0,S>>=1)+u+n+k.substring(S):u+(M?n:k+n))+e}}}function Ut(n){return n+""}function jt(){this._=new Date(arguments.length>1?Date.UTC.apply(this,arguments):arguments[0])}function Ft(n,t,e){function r(t){var e=n(t),r=i(e,1);return r-t>t-e?e:r}function u(e){return t(e=n(new cc(e-1)),1),e}function i(n,e){return t(n=new cc(+n),e),n}function o(n,r,i){var o=u(n),a=[];if(i>1)for(;r>o;)e(o)%i||a.push(new Date(+o)),t(o,1);else for(;r>o;)a.push(new Date(+o)),t(o,1);return a}function a(n,t,e){try{cc=jt;var r=new jt;return r._=n,o(r,t,e)}finally{cc=Date}}n.floor=n,n.round=r,n.ceil=u,n.offset=i,n.range=o;var c=n.utc=Ht(n);return c.floor=c,c.round=Ht(r),c.ceil=Ht(u),c.offset=Ht(i),c.range=a,n}function Ht(n){return function(t,e){try{cc=jt;var r=new jt;return r._=t,n(r,e)._}finally{cc=Date}}}function Ot(n){function t(n){function t(t){for(var e,u,i,o=[],a=-1,c=0;++aa;){if(r>=l)return-1;if(u=t.charCodeAt(a++),37===u){if(o=t.charAt(a++),i=C[o in sc?t.charAt(a++):o],!i||(r=i(n,e,r))<0)return-1}else if(u!=e.charCodeAt(r++))return-1}return r}function r(n,t,e){_.lastIndex=0;var r=_.exec(t.slice(e));return r?(n.w=w.get(r[0].toLowerCase()),e+r[0].length):-1}function u(n,t,e){x.lastIndex=0;var r=x.exec(t.slice(e));return r?(n.w=b.get(r[0].toLowerCase()),e+r[0].length):-1}function i(n,t,e){E.lastIndex=0;var r=E.exec(t.slice(e));return r?(n.m=A.get(r[0].toLowerCase()),e+r[0].length):-1}function o(n,t,e){S.lastIndex=0;var r=S.exec(t.slice(e));return r?(n.m=k.get(r[0].toLowerCase()),e+r[0].length):-1}function a(n,t,r){return e(n,N.c.toString(),t,r)}function c(n,t,r){return e(n,N.x.toString(),t,r)}function l(n,t,r){return e(n,N.X.toString(),t,r)}function s(n,t,e){var r=M.get(t.slice(e,e+=2).toLowerCase());return null==r?-1:(n.p=r,e)}var f=n.dateTime,h=n.date,g=n.time,p=n.periods,v=n.days,d=n.shortDays,m=n.months,y=n.shortMonths;t.utc=function(n){function e(n){try{cc=jt;var t=new cc;return t._=n,r(t)}finally{cc=Date}}var r=t(n);return e.parse=function(n){try{cc=jt;var t=r.parse(n);return t&&t._}finally{cc=Date}},e.toString=r.toString,e},t.multi=t.utc.multi=ae;var M=ta.map(),x=Yt(v),b=Zt(v),_=Yt(d),w=Zt(d),S=Yt(m),k=Zt(m),E=Yt(y),A=Zt(y);p.forEach(function(n,t){M.set(n.toLowerCase(),t)});var N={a:function(n){return d[n.getDay()]},A:function(n){return v[n.getDay()]},b:function(n){return y[n.getMonth()]},B:function(n){return m[n.getMonth()]},c:t(f),d:function(n,t){return It(n.getDate(),t,2)},e:function(n,t){return It(n.getDate(),t,2)},H:function(n,t){return It(n.getHours(),t,2)},I:function(n,t){return It(n.getHours()%12||12,t,2)},j:function(n,t){return It(1+ac.dayOfYear(n),t,3)},L:function(n,t){return It(n.getMilliseconds(),t,3)},m:function(n,t){return It(n.getMonth()+1,t,2)},M:function(n,t){return It(n.getMinutes(),t,2)},p:function(n){return p[+(n.getHours()>=12)]},S:function(n,t){return It(n.getSeconds(),t,2)},U:function(n,t){return It(ac.sundayOfYear(n),t,2)},w:function(n){return n.getDay()},W:function(n,t){return It(ac.mondayOfYear(n),t,2)},x:t(h),X:t(g),y:function(n,t){return It(n.getFullYear()%100,t,2)},Y:function(n,t){return It(n.getFullYear()%1e4,t,4)},Z:ie,"%":function(){return"%"}},C={a:r,A:u,b:i,B:o,c:a,d:Qt,e:Qt,H:te,I:te,j:ne,L:ue,m:Kt,M:ee,p:s,S:re,U:Xt,w:Vt,W:$t,x:c,X:l,y:Wt,Y:Bt,Z:Jt,"%":oe};return t}function It(n,t,e){var r=0>n?"-":"",u=(r?-n:n)+"",i=u.length;return r+(e>i?new Array(e-i+1).join(t)+u:u)}function Yt(n){return new RegExp("^(?:"+n.map(ta.requote).join("|")+")","i")}function Zt(n){for(var t=new l,e=-1,r=n.length;++e68?1900:2e3)}function Kt(n,t,e){fc.lastIndex=0;var r=fc.exec(t.slice(e,e+2));return r?(n.m=r[0]-1,e+r[0].length):-1}function Qt(n,t,e){fc.lastIndex=0;var r=fc.exec(t.slice(e,e+2));return r?(n.d=+r[0],e+r[0].length):-1}function ne(n,t,e){fc.lastIndex=0;var r=fc.exec(t.slice(e,e+3));return r?(n.j=+r[0],e+r[0].length):-1}function te(n,t,e){fc.lastIndex=0;var r=fc.exec(t.slice(e,e+2));return r?(n.H=+r[0],e+r[0].length):-1}function ee(n,t,e){fc.lastIndex=0;var r=fc.exec(t.slice(e,e+2));return r?(n.M=+r[0],e+r[0].length):-1}function re(n,t,e){fc.lastIndex=0;var r=fc.exec(t.slice(e,e+2));return r?(n.S=+r[0],e+r[0].length):-1}function ue(n,t,e){fc.lastIndex=0;var r=fc.exec(t.slice(e,e+3));return r?(n.L=+r[0],e+r[0].length):-1}function ie(n){var t=n.getTimezoneOffset(),e=t>0?"-":"+",r=ga(t)/60|0,u=ga(t)%60;return e+It(r,"0",2)+It(u,"0",2)}function oe(n,t,e){hc.lastIndex=0;var r=hc.exec(t.slice(e,e+1));return r?e+r[0].length:-1}function ae(n){for(var t=n.length,e=-1;++e=0?1:-1,a=o*e,c=Math.cos(t),l=Math.sin(t),s=i*l,f=u*c+s*Math.cos(a),h=s*o*Math.sin(a);yc.add(Math.atan2(h,f)),r=n,u=c,i=l}var t,e,r,u,i;Mc.point=function(o,a){Mc.point=n,r=(t=o)*Da,u=Math.cos(a=(e=a)*Da/2+qa/4),i=Math.sin(a)},Mc.lineEnd=function(){n(t,e)}}function pe(n){var t=n[0],e=n[1],r=Math.cos(e);return[r*Math.cos(t),r*Math.sin(t),Math.sin(e)]}function ve(n,t){return n[0]*t[0]+n[1]*t[1]+n[2]*t[2]}function de(n,t){return[n[1]*t[2]-n[2]*t[1],n[2]*t[0]-n[0]*t[2],n[0]*t[1]-n[1]*t[0]]}function me(n,t){n[0]+=t[0],n[1]+=t[1],n[2]+=t[2]}function ye(n,t){return[n[0]*t,n[1]*t,n[2]*t]}function Me(n){var t=Math.sqrt(n[0]*n[0]+n[1]*n[1]+n[2]*n[2]);n[0]/=t,n[1]/=t,n[2]/=t}function xe(n){return[Math.atan2(n[1],n[0]),tt(n[2])]}function be(n,t){return ga(n[0]-t[0])a;++a)u.point((e=n[a])[0],e[1]);return void u.lineEnd()}var c=new qe(e,n,null,!0),l=new qe(e,null,c,!1);c.o=l,i.push(c),o.push(l),c=new qe(r,n,null,!1),l=new qe(r,null,c,!0),c.o=l,i.push(c),o.push(l)}}),o.sort(t),ze(i),ze(o),i.length){for(var a=0,c=e,l=o.length;l>a;++a)o[a].e=c=!c;for(var s,f,h=i[0];;){for(var g=h,p=!0;g.v;)if((g=g.n)===h)return;s=g.z,u.lineStart();do{if(g.v=g.o.v=!0,g.e){if(p)for(var a=0,l=s.length;l>a;++a)u.point((f=s[a])[0],f[1]);else r(g.x,g.n.x,1,u);g=g.n}else{if(p){s=g.p.z;for(var a=s.length-1;a>=0;--a)u.point((f=s[a])[0],f[1])}else r(g.x,g.p.x,-1,u);g=g.p}g=g.o,s=g.z,p=!p}while(!g.v);u.lineEnd()}}}function ze(n){if(t=n.length){for(var t,e,r=0,u=n[0];++r0){for(b||(i.polygonStart(),b=!0),i.lineStart();++o1&&2&t&&e.push(e.pop().concat(e.shift())),g.push(e.filter(Te))}var g,p,v,d=t(i),m=u.invert(r[0],r[1]),y={point:o,lineStart:c,lineEnd:l,polygonStart:function(){y.point=s,y.lineStart=f,y.lineEnd=h,g=[],p=[]},polygonEnd:function(){y.point=o,y.lineStart=c,y.lineEnd=l,g=ta.merge(g);var n=Fe(m,p);g.length?(b||(i.polygonStart(),b=!0),Ce(g,De,n,e,i)):n&&(b||(i.polygonStart(),b=!0),i.lineStart(),e(null,null,1,i),i.lineEnd()),b&&(i.polygonEnd(),b=!1),g=p=null},sphere:function(){i.polygonStart(),i.lineStart(),e(null,null,1,i),i.lineEnd(),i.polygonEnd()}},M=Re(),x=t(M),b=!1;return y}}function Te(n){return n.length>1}function Re(){var n,t=[];return{lineStart:function(){t.push(n=[])},point:function(t,e){n.push([t,e])},lineEnd:b,buffer:function(){var e=t;return t=[],n=null,e},rejoin:function(){t.length>1&&t.push(t.pop().concat(t.shift()))}}}function De(n,t){return((n=n.x)[0]<0?n[1]-Ra-Ca:Ra-n[1])-((t=t.x)[0]<0?t[1]-Ra-Ca:Ra-t[1])}function Pe(n){var t,e=0/0,r=0/0,u=0/0;return{lineStart:function(){n.lineStart(),t=1},point:function(i,o){var a=i>0?qa:-qa,c=ga(i-e);ga(c-qa)0?Ra:-Ra),n.point(u,r),n.lineEnd(),n.lineStart(),n.point(a,r),n.point(i,r),t=0):u!==a&&c>=qa&&(ga(e-u)Ca?Math.atan((Math.sin(t)*(i=Math.cos(r))*Math.sin(e)-Math.sin(r)*(u=Math.cos(t))*Math.sin(n))/(u*i*o)):(t+r)/2}function je(n,t,e,r){var u;if(null==n)u=e*Ra,r.point(-qa,u),r.point(0,u),r.point(qa,u),r.point(qa,0),r.point(qa,-u),r.point(0,-u),r.point(-qa,-u),r.point(-qa,0),r.point(-qa,u);else if(ga(n[0]-t[0])>Ca){var i=n[0]a;++a){var l=t[a],s=l.length;if(s)for(var f=l[0],h=f[0],g=f[1]/2+qa/4,p=Math.sin(g),v=Math.cos(g),d=1;;){d===s&&(d=0),n=l[d];var m=n[0],y=n[1]/2+qa/4,M=Math.sin(y),x=Math.cos(y),b=m-h,_=b>=0?1:-1,w=_*b,S=w>qa,k=p*M;if(yc.add(Math.atan2(k*_*Math.sin(w),v*x+k*Math.cos(w))),i+=S?b+_*La:b,S^h>=e^m>=e){var E=de(pe(f),pe(n));Me(E);var A=de(u,E);Me(A);var N=(S^b>=0?-1:1)*tt(A[2]);(r>N||r===N&&(E[0]||E[1]))&&(o+=S^b>=0?1:-1)}if(!d++)break;h=m,p=M,v=x,f=n}}return(-Ca>i||Ca>i&&0>yc)^1&o}function He(n){function t(n,t){return Math.cos(n)*Math.cos(t)>i}function e(n){var e,i,c,l,s;return{lineStart:function(){l=c=!1,s=1},point:function(f,h){var g,p=[f,h],v=t(f,h),d=o?v?0:u(f,h):v?u(f+(0>f?qa:-qa),h):0;if(!e&&(l=c=v)&&n.lineStart(),v!==c&&(g=r(e,p),(be(e,g)||be(p,g))&&(p[0]+=Ca,p[1]+=Ca,v=t(p[0],p[1]))),v!==c)s=0,v?(n.lineStart(),g=r(p,e),n.point(g[0],g[1])):(g=r(e,p),n.point(g[0],g[1]),n.lineEnd()),e=g;else if(a&&e&&o^v){var m;d&i||!(m=r(p,e,!0))||(s=0,o?(n.lineStart(),n.point(m[0][0],m[0][1]),n.point(m[1][0],m[1][1]),n.lineEnd()):(n.point(m[1][0],m[1][1]),n.lineEnd(),n.lineStart(),n.point(m[0][0],m[0][1])))}!v||e&&be(e,p)||n.point(p[0],p[1]),e=p,c=v,i=d},lineEnd:function(){c&&n.lineEnd(),e=null},clean:function(){return s|(l&&c)<<1}}}function r(n,t,e){var r=pe(n),u=pe(t),o=[1,0,0],a=de(r,u),c=ve(a,a),l=a[0],s=c-l*l;if(!s)return!e&&n;var f=i*c/s,h=-i*l/s,g=de(o,a),p=ye(o,f),v=ye(a,h);me(p,v);var d=g,m=ve(p,d),y=ve(d,d),M=m*m-y*(ve(p,p)-1);if(!(0>M)){var x=Math.sqrt(M),b=ye(d,(-m-x)/y);if(me(b,p),b=xe(b),!e)return b;var _,w=n[0],S=t[0],k=n[1],E=t[1];w>S&&(_=w,w=S,S=_);var A=S-w,N=ga(A-qa)A;if(!N&&k>E&&(_=k,k=E,E=_),C?N?k+E>0^b[1]<(ga(b[0]-w)qa^(w<=b[0]&&b[0]<=S)){var z=ye(d,(-m+x)/y);return me(z,p),[b,xe(z)]}}}function u(t,e){var r=o?n:qa-n,u=0;return-r>t?u|=1:t>r&&(u|=2),-r>e?u|=4:e>r&&(u|=8),u}var i=Math.cos(n),o=i>0,a=ga(i)>Ca,c=gr(n,6*Da);return Le(t,e,c,o?[0,-n]:[-qa,n-qa])}function Oe(n,t,e,r){return function(u){var i,o=u.a,a=u.b,c=o.x,l=o.y,s=a.x,f=a.y,h=0,g=1,p=s-c,v=f-l;if(i=n-c,p||!(i>0)){if(i/=p,0>p){if(h>i)return;g>i&&(g=i)}else if(p>0){if(i>g)return;i>h&&(h=i)}if(i=e-c,p||!(0>i)){if(i/=p,0>p){if(i>g)return;i>h&&(h=i)}else if(p>0){if(h>i)return;g>i&&(g=i)}if(i=t-l,v||!(i>0)){if(i/=v,0>v){if(h>i)return;g>i&&(g=i)}else if(v>0){if(i>g)return;i>h&&(h=i)}if(i=r-l,v||!(0>i)){if(i/=v,0>v){if(i>g)return;i>h&&(h=i)}else if(v>0){if(h>i)return;g>i&&(g=i)}return h>0&&(u.a={x:c+h*p,y:l+h*v}),1>g&&(u.b={x:c+g*p,y:l+g*v}),u}}}}}}function Ie(n,t,e,r){function u(r,u){return ga(r[0]-n)0?0:3:ga(r[0]-e)0?2:1:ga(r[1]-t)0?1:0:u>0?3:2}function i(n,t){return o(n.x,t.x)}function o(n,t){var e=u(n,1),r=u(t,1);return e!==r?e-r:0===e?t[1]-n[1]:1===e?n[0]-t[0]:2===e?n[1]-t[1]:t[0]-n[0]}return function(a){function c(n){for(var t=0,e=d.length,r=n[1],u=0;e>u;++u)for(var i,o=1,a=d[u],c=a.length,l=a[0];c>o;++o)i=a[o],l[1]<=r?i[1]>r&&Q(l,i,n)>0&&++t:i[1]<=r&&Q(l,i,n)<0&&--t,l=i;return 0!==t}function l(i,a,c,l){var s=0,f=0;if(null==i||(s=u(i,c))!==(f=u(a,c))||o(i,a)<0^c>0){do l.point(0===s||3===s?n:e,s>1?r:t);while((s=(s+c+4)%4)!==f)}else l.point(a[0],a[1])}function s(u,i){return u>=n&&e>=u&&i>=t&&r>=i}function f(n,t){s(n,t)&&a.point(n,t)}function h(){C.point=p,d&&d.push(m=[]),S=!0,w=!1,b=_=0/0}function g(){v&&(p(y,M),x&&w&&A.rejoin(),v.push(A.buffer())),C.point=f,w&&a.lineEnd()}function p(n,t){n=Math.max(-Tc,Math.min(Tc,n)),t=Math.max(-Tc,Math.min(Tc,t));var e=s(n,t);if(d&&m.push([n,t]),S)y=n,M=t,x=e,S=!1,e&&(a.lineStart(),a.point(n,t));else if(e&&w)a.point(n,t);else{var r={a:{x:b,y:_},b:{x:n,y:t}};N(r)?(w||(a.lineStart(),a.point(r.a.x,r.a.y)),a.point(r.b.x,r.b.y),e||a.lineEnd(),k=!1):e&&(a.lineStart(),a.point(n,t),k=!1)}b=n,_=t,w=e}var v,d,m,y,M,x,b,_,w,S,k,E=a,A=Re(),N=Oe(n,t,e,r),C={point:f,lineStart:h,lineEnd:g,polygonStart:function(){a=A,v=[],d=[],k=!0},polygonEnd:function(){a=E,v=ta.merge(v);var t=c([n,r]),e=k&&t,u=v.length;(e||u)&&(a.polygonStart(),e&&(a.lineStart(),l(null,null,1,a),a.lineEnd()),u&&Ce(v,i,t,l,a),a.polygonEnd()),v=d=m=null}};return C}}function Ye(n){var t=0,e=qa/3,r=ir(n),u=r(t,e);return u.parallels=function(n){return arguments.length?r(t=n[0]*qa/180,e=n[1]*qa/180):[t/qa*180,e/qa*180]},u}function Ze(n,t){function e(n,t){var e=Math.sqrt(i-2*u*Math.sin(t))/u;return[e*Math.sin(n*=u),o-e*Math.cos(n)]}var r=Math.sin(n),u=(r+Math.sin(t))/2,i=1+r*(2*u-r),o=Math.sqrt(i)/u;return e.invert=function(n,t){var e=o-t;return[Math.atan2(n,e)/u,tt((i-(n*n+e*e)*u*u)/(2*u))]},e}function Ve(){function n(n,t){Dc+=u*n-r*t,r=n,u=t}var t,e,r,u;Hc.point=function(i,o){Hc.point=n,t=r=i,e=u=o},Hc.lineEnd=function(){n(t,e)}}function Xe(n,t){Pc>n&&(Pc=n),n>jc&&(jc=n),Uc>t&&(Uc=t),t>Fc&&(Fc=t)}function $e(){function n(n,t){o.push("M",n,",",t,i)}function t(n,t){o.push("M",n,",",t),a.point=e}function e(n,t){o.push("L",n,",",t)}function r(){a.point=n}function u(){o.push("Z")}var i=Be(4.5),o=[],a={point:n,lineStart:function(){a.point=t},lineEnd:r,polygonStart:function(){a.lineEnd=u},polygonEnd:function(){a.lineEnd=r,a.point=n},pointRadius:function(n){return i=Be(n),a},result:function(){if(o.length){var n=o.join("");return o=[],n}}};return a}function Be(n){return"m0,"+n+"a"+n+","+n+" 0 1,1 0,"+-2*n+"a"+n+","+n+" 0 1,1 0,"+2*n+"z"}function We(n,t){_c+=n,wc+=t,++Sc}function Je(){function n(n,r){var u=n-t,i=r-e,o=Math.sqrt(u*u+i*i);kc+=o*(t+n)/2,Ec+=o*(e+r)/2,Ac+=o,We(t=n,e=r)}var t,e;Ic.point=function(r,u){Ic.point=n,We(t=r,e=u)}}function Ge(){Ic.point=We}function Ke(){function n(n,t){var e=n-r,i=t-u,o=Math.sqrt(e*e+i*i);kc+=o*(r+n)/2,Ec+=o*(u+t)/2,Ac+=o,o=u*n-r*t,Nc+=o*(r+n),Cc+=o*(u+t),zc+=3*o,We(r=n,u=t)}var t,e,r,u;Ic.point=function(i,o){Ic.point=n,We(t=r=i,e=u=o)},Ic.lineEnd=function(){n(t,e)}}function Qe(n){function t(t,e){n.moveTo(t+o,e),n.arc(t,e,o,0,La)}function e(t,e){n.moveTo(t,e),a.point=r}function r(t,e){n.lineTo(t,e)}function u(){a.point=t}function i(){n.closePath()}var o=4.5,a={point:t,lineStart:function(){a.point=e},lineEnd:u,polygonStart:function(){a.lineEnd=i},polygonEnd:function(){a.lineEnd=u,a.point=t},pointRadius:function(n){return o=n,a},result:b};return a}function nr(n){function t(n){return(a?r:e)(n)}function e(t){return rr(t,function(e,r){e=n(e,r),t.point(e[0],e[1])})}function r(t){function e(e,r){e=n(e,r),t.point(e[0],e[1])}function r(){M=0/0,S.point=i,t.lineStart()}function i(e,r){var i=pe([e,r]),o=n(e,r);u(M,x,y,b,_,w,M=o[0],x=o[1],y=e,b=i[0],_=i[1],w=i[2],a,t),t.point(M,x)}function o(){S.point=e,t.lineEnd()}function c(){r(),S.point=l,S.lineEnd=s}function l(n,t){i(f=n,h=t),g=M,p=x,v=b,d=_,m=w,S.point=i}function s(){u(M,x,y,b,_,w,g,p,f,v,d,m,a,t),S.lineEnd=o,o()}var f,h,g,p,v,d,m,y,M,x,b,_,w,S={point:e,lineStart:r,lineEnd:o,polygonStart:function(){t.polygonStart(),S.lineStart=c +},polygonEnd:function(){t.polygonEnd(),S.lineStart=r}};return S}function u(t,e,r,a,c,l,s,f,h,g,p,v,d,m){var y=s-t,M=f-e,x=y*y+M*M;if(x>4*i&&d--){var b=a+g,_=c+p,w=l+v,S=Math.sqrt(b*b+_*_+w*w),k=Math.asin(w/=S),E=ga(ga(w)-1)i||ga((y*z+M*q)/x-.5)>.3||o>a*g+c*p+l*v)&&(u(t,e,r,a,c,l,N,C,E,b/=S,_/=S,w,d,m),m.point(N,C),u(N,C,E,b,_,w,s,f,h,g,p,v,d,m))}}var i=.5,o=Math.cos(30*Da),a=16;return t.precision=function(n){return arguments.length?(a=(i=n*n)>0&&16,t):Math.sqrt(i)},t}function tr(n){var t=nr(function(t,e){return n([t*Pa,e*Pa])});return function(n){return or(t(n))}}function er(n){this.stream=n}function rr(n,t){return{point:t,sphere:function(){n.sphere()},lineStart:function(){n.lineStart()},lineEnd:function(){n.lineEnd()},polygonStart:function(){n.polygonStart()},polygonEnd:function(){n.polygonEnd()}}}function ur(n){return ir(function(){return n})()}function ir(n){function t(n){return n=a(n[0]*Da,n[1]*Da),[n[0]*h+c,l-n[1]*h]}function e(n){return n=a.invert((n[0]-c)/h,(l-n[1])/h),n&&[n[0]*Pa,n[1]*Pa]}function r(){a=Ae(o=lr(m,M,x),i);var n=i(v,d);return c=g-n[0]*h,l=p+n[1]*h,u()}function u(){return s&&(s.valid=!1,s=null),t}var i,o,a,c,l,s,f=nr(function(n,t){return n=i(n,t),[n[0]*h+c,l-n[1]*h]}),h=150,g=480,p=250,v=0,d=0,m=0,M=0,x=0,b=Lc,_=y,w=null,S=null;return t.stream=function(n){return s&&(s.valid=!1),s=or(b(o,f(_(n)))),s.valid=!0,s},t.clipAngle=function(n){return arguments.length?(b=null==n?(w=n,Lc):He((w=+n)*Da),u()):w},t.clipExtent=function(n){return arguments.length?(S=n,_=n?Ie(n[0][0],n[0][1],n[1][0],n[1][1]):y,u()):S},t.scale=function(n){return arguments.length?(h=+n,r()):h},t.translate=function(n){return arguments.length?(g=+n[0],p=+n[1],r()):[g,p]},t.center=function(n){return arguments.length?(v=n[0]%360*Da,d=n[1]%360*Da,r()):[v*Pa,d*Pa]},t.rotate=function(n){return arguments.length?(m=n[0]%360*Da,M=n[1]%360*Da,x=n.length>2?n[2]%360*Da:0,r()):[m*Pa,M*Pa,x*Pa]},ta.rebind(t,f,"precision"),function(){return i=n.apply(this,arguments),t.invert=i.invert&&e,r()}}function or(n){return rr(n,function(t,e){n.point(t*Da,e*Da)})}function ar(n,t){return[n,t]}function cr(n,t){return[n>qa?n-La:-qa>n?n+La:n,t]}function lr(n,t,e){return n?t||e?Ae(fr(n),hr(t,e)):fr(n):t||e?hr(t,e):cr}function sr(n){return function(t,e){return t+=n,[t>qa?t-La:-qa>t?t+La:t,e]}}function fr(n){var t=sr(n);return t.invert=sr(-n),t}function hr(n,t){function e(n,t){var e=Math.cos(t),a=Math.cos(n)*e,c=Math.sin(n)*e,l=Math.sin(t),s=l*r+a*u;return[Math.atan2(c*i-s*o,a*r-l*u),tt(s*i+c*o)]}var r=Math.cos(n),u=Math.sin(n),i=Math.cos(t),o=Math.sin(t);return e.invert=function(n,t){var e=Math.cos(t),a=Math.cos(n)*e,c=Math.sin(n)*e,l=Math.sin(t),s=l*i-c*o;return[Math.atan2(c*i+l*o,a*r+s*u),tt(s*r-a*u)]},e}function gr(n,t){var e=Math.cos(n),r=Math.sin(n);return function(u,i,o,a){var c=o*t;null!=u?(u=pr(e,u),i=pr(e,i),(o>0?i>u:u>i)&&(u+=o*La)):(u=n+o*La,i=n-.5*c);for(var l,s=u;o>0?s>i:i>s;s-=c)a.point((l=xe([e,-r*Math.cos(s),-r*Math.sin(s)]))[0],l[1])}}function pr(n,t){var e=pe(t);e[0]-=n,Me(e);var r=nt(-e[1]);return((-e[2]<0?-r:r)+2*Math.PI-Ca)%(2*Math.PI)}function vr(n,t,e){var r=ta.range(n,t-Ca,e).concat(t);return function(n){return r.map(function(t){return[n,t]})}}function dr(n,t,e){var r=ta.range(n,t-Ca,e).concat(t);return function(n){return r.map(function(t){return[t,n]})}}function mr(n){return n.source}function yr(n){return n.target}function Mr(n,t,e,r){var u=Math.cos(t),i=Math.sin(t),o=Math.cos(r),a=Math.sin(r),c=u*Math.cos(n),l=u*Math.sin(n),s=o*Math.cos(e),f=o*Math.sin(e),h=2*Math.asin(Math.sqrt(it(r-t)+u*o*it(e-n))),g=1/Math.sin(h),p=h?function(n){var t=Math.sin(n*=h)*g,e=Math.sin(h-n)*g,r=e*c+t*s,u=e*l+t*f,o=e*i+t*a;return[Math.atan2(u,r)*Pa,Math.atan2(o,Math.sqrt(r*r+u*u))*Pa]}:function(){return[n*Pa,t*Pa]};return p.distance=h,p}function xr(){function n(n,u){var i=Math.sin(u*=Da),o=Math.cos(u),a=ga((n*=Da)-t),c=Math.cos(a);Yc+=Math.atan2(Math.sqrt((a=o*Math.sin(a))*a+(a=r*i-e*o*c)*a),e*i+r*o*c),t=n,e=i,r=o}var t,e,r;Zc.point=function(u,i){t=u*Da,e=Math.sin(i*=Da),r=Math.cos(i),Zc.point=n},Zc.lineEnd=function(){Zc.point=Zc.lineEnd=b}}function br(n,t){function e(t,e){var r=Math.cos(t),u=Math.cos(e),i=n(r*u);return[i*u*Math.sin(t),i*Math.sin(e)]}return e.invert=function(n,e){var r=Math.sqrt(n*n+e*e),u=t(r),i=Math.sin(u),o=Math.cos(u);return[Math.atan2(n*i,r*o),Math.asin(r&&e*i/r)]},e}function _r(n,t){function e(n,t){o>0?-Ra+Ca>t&&(t=-Ra+Ca):t>Ra-Ca&&(t=Ra-Ca);var e=o/Math.pow(u(t),i);return[e*Math.sin(i*n),o-e*Math.cos(i*n)]}var r=Math.cos(n),u=function(n){return Math.tan(qa/4+n/2)},i=n===t?Math.sin(n):Math.log(r/Math.cos(t))/Math.log(u(t)/u(n)),o=r*Math.pow(u(n),i)/i;return i?(e.invert=function(n,t){var e=o-t,r=K(i)*Math.sqrt(n*n+e*e);return[Math.atan2(n,e)/i,2*Math.atan(Math.pow(o/r,1/i))-Ra]},e):Sr}function wr(n,t){function e(n,t){var e=i-t;return[e*Math.sin(u*n),i-e*Math.cos(u*n)]}var r=Math.cos(n),u=n===t?Math.sin(n):(r-Math.cos(t))/(t-n),i=r/u+n;return ga(u)u;u++){for(;r>1&&Q(n[e[r-2]],n[e[r-1]],n[u])<=0;)--r;e[r++]=u}return e.slice(0,r)}function zr(n,t){return n[0]-t[0]||n[1]-t[1]}function qr(n,t,e){return(e[0]-t[0])*(n[1]-t[1])<(e[1]-t[1])*(n[0]-t[0])}function Lr(n,t,e,r){var u=n[0],i=e[0],o=t[0]-u,a=r[0]-i,c=n[1],l=e[1],s=t[1]-c,f=r[1]-l,h=(a*(c-l)-f*(u-i))/(f*o-a*s);return[u+h*o,c+h*s]}function Tr(n){var t=n[0],e=n[n.length-1];return!(t[0]-e[0]||t[1]-e[1])}function Rr(){tu(this),this.edge=this.site=this.circle=null}function Dr(n){var t=el.pop()||new Rr;return t.site=n,t}function Pr(n){Xr(n),Qc.remove(n),el.push(n),tu(n)}function Ur(n){var t=n.circle,e=t.x,r=t.cy,u={x:e,y:r},i=n.P,o=n.N,a=[n];Pr(n);for(var c=i;c.circle&&ga(e-c.circle.x)s;++s)l=a[s],c=a[s-1],Kr(l.edge,c.site,l.site,u);c=a[0],l=a[f-1],l.edge=Jr(c.site,l.site,null,u),Vr(c),Vr(l)}function jr(n){for(var t,e,r,u,i=n.x,o=n.y,a=Qc._;a;)if(r=Fr(a,o)-i,r>Ca)a=a.L;else{if(u=i-Hr(a,o),!(u>Ca)){r>-Ca?(t=a.P,e=a):u>-Ca?(t=a,e=a.N):t=e=a;break}if(!a.R){t=a;break}a=a.R}var c=Dr(n);if(Qc.insert(t,c),t||e){if(t===e)return Xr(t),e=Dr(t.site),Qc.insert(c,e),c.edge=e.edge=Jr(t.site,c.site),Vr(t),void Vr(e);if(!e)return void(c.edge=Jr(t.site,c.site));Xr(t),Xr(e);var l=t.site,s=l.x,f=l.y,h=n.x-s,g=n.y-f,p=e.site,v=p.x-s,d=p.y-f,m=2*(h*d-g*v),y=h*h+g*g,M=v*v+d*d,x={x:(d*y-g*M)/m+s,y:(h*M-v*y)/m+f};Kr(e.edge,l,p,x),c.edge=Jr(l,n,null,x),e.edge=Jr(n,p,null,x),Vr(t),Vr(e)}}function Fr(n,t){var e=n.site,r=e.x,u=e.y,i=u-t;if(!i)return r;var o=n.P;if(!o)return-1/0;e=o.site;var a=e.x,c=e.y,l=c-t;if(!l)return a;var s=a-r,f=1/i-1/l,h=s/l;return f?(-h+Math.sqrt(h*h-2*f*(s*s/(-2*l)-c+l/2+u-i/2)))/f+r:(r+a)/2}function Hr(n,t){var e=n.N;if(e)return Fr(e,t);var r=n.site;return r.y===t?r.x:1/0}function Or(n){this.site=n,this.edges=[]}function Ir(n){for(var t,e,r,u,i,o,a,c,l,s,f=n[0][0],h=n[1][0],g=n[0][1],p=n[1][1],v=Kc,d=v.length;d--;)if(i=v[d],i&&i.prepare())for(a=i.edges,c=a.length,o=0;c>o;)s=a[o].end(),r=s.x,u=s.y,l=a[++o%c].start(),t=l.x,e=l.y,(ga(r-t)>Ca||ga(u-e)>Ca)&&(a.splice(o,0,new Qr(Gr(i.site,s,ga(r-f)Ca?{x:f,y:ga(t-f)Ca?{x:ga(e-p)Ca?{x:h,y:ga(t-h)Ca?{x:ga(e-g)=-za)){var g=c*c+l*l,p=s*s+f*f,v=(f*g-l*p)/h,d=(c*p-s*g)/h,f=d+a,m=rl.pop()||new Zr;m.arc=n,m.site=u,m.x=v+o,m.y=f+Math.sqrt(v*v+d*d),m.cy=f,n.circle=m;for(var y=null,M=tl._;M;)if(m.yd||d>=a)return;if(h>p){if(i){if(i.y>=l)return}else i={x:d,y:c};e={x:d,y:l}}else{if(i){if(i.yr||r>1)if(h>p){if(i){if(i.y>=l)return}else i={x:(c-u)/r,y:c};e={x:(l-u)/r,y:l}}else{if(i){if(i.yg){if(i){if(i.x>=a)return}else i={x:o,y:r*o+u};e={x:a,y:r*a+u}}else{if(i){if(i.xi||f>o||r>h||u>g)){if(p=n.point){var p,v=t-n.x,d=e-n.y,m=v*v+d*d;if(c>m){var y=Math.sqrt(c=m);r=t-y,u=e-y,i=t+y,o=e+y,a=p}}for(var M=n.nodes,x=.5*(s+h),b=.5*(f+g),_=t>=x,w=e>=b,S=w<<1|_,k=S+4;k>S;++S)if(n=M[3&S])switch(3&S){case 0:l(n,s,f,x,b);break;case 1:l(n,x,f,h,b);break;case 2:l(n,s,b,x,g);break;case 3:l(n,x,b,h,g)}}}(n,r,u,i,o),a}function gu(n,t){n=ta.rgb(n),t=ta.rgb(t);var e=n.r,r=n.g,u=n.b,i=t.r-e,o=t.g-r,a=t.b-u;return function(n){return"#"+xt(Math.round(e+i*n))+xt(Math.round(r+o*n))+xt(Math.round(u+a*n))}}function pu(n,t){var e,r={},u={};for(e in n)e in t?r[e]=mu(n[e],t[e]):u[e]=n[e];for(e in t)e in n||(u[e]=t[e]);return function(n){for(e in r)u[e]=r[e](n);return u}}function vu(n,t){return n=+n,t=+t,function(e){return n*(1-e)+t*e}}function du(n,t){var e,r,u,i=il.lastIndex=ol.lastIndex=0,o=-1,a=[],c=[];for(n+="",t+="";(e=il.exec(n))&&(r=ol.exec(t));)(u=r.index)>i&&(u=t.slice(i,u),a[o]?a[o]+=u:a[++o]=u),(e=e[0])===(r=r[0])?a[o]?a[o]+=r:a[++o]=r:(a[++o]=null,c.push({i:o,x:vu(e,r)})),i=ol.lastIndex;return ir;++r)a[(e=c[r]).i]=e.x(n);return a.join("")})}function mu(n,t){for(var e,r=ta.interpolators.length;--r>=0&&!(e=ta.interpolators[r](n,t)););return e}function yu(n,t){var e,r=[],u=[],i=n.length,o=t.length,a=Math.min(n.length,t.length);for(e=0;a>e;++e)r.push(mu(n[e],t[e]));for(;i>e;++e)u[e]=n[e];for(;o>e;++e)u[e]=t[e];return function(n){for(e=0;a>e;++e)u[e]=r[e](n);return u}}function Mu(n){return function(t){return 0>=t?0:t>=1?1:n(t)}}function xu(n){return function(t){return 1-n(1-t)}}function bu(n){return function(t){return.5*(.5>t?n(2*t):2-n(2-2*t))}}function _u(n){return n*n}function wu(n){return n*n*n}function Su(n){if(0>=n)return 0;if(n>=1)return 1;var t=n*n,e=t*n;return 4*(.5>n?e:3*(n-t)+e-.75)}function ku(n){return function(t){return Math.pow(t,n)}}function Eu(n){return 1-Math.cos(n*Ra)}function Au(n){return Math.pow(2,10*(n-1))}function Nu(n){return 1-Math.sqrt(1-n*n)}function Cu(n,t){var e;return arguments.length<2&&(t=.45),arguments.length?e=t/La*Math.asin(1/n):(n=1,e=t/4),function(r){return 1+n*Math.pow(2,-10*r)*Math.sin((r-e)*La/t)}}function zu(n){return n||(n=1.70158),function(t){return t*t*((n+1)*t-n)}}function qu(n){return 1/2.75>n?7.5625*n*n:2/2.75>n?7.5625*(n-=1.5/2.75)*n+.75:2.5/2.75>n?7.5625*(n-=2.25/2.75)*n+.9375:7.5625*(n-=2.625/2.75)*n+.984375}function Lu(n,t){n=ta.hcl(n),t=ta.hcl(t);var e=n.h,r=n.c,u=n.l,i=t.h-e,o=t.c-r,a=t.l-u;return isNaN(o)&&(o=0,r=isNaN(r)?t.c:r),isNaN(i)?(i=0,e=isNaN(e)?t.h:e):i>180?i-=360:-180>i&&(i+=360),function(n){return st(e+i*n,r+o*n,u+a*n)+""}}function Tu(n,t){n=ta.hsl(n),t=ta.hsl(t);var e=n.h,r=n.s,u=n.l,i=t.h-e,o=t.s-r,a=t.l-u;return isNaN(o)&&(o=0,r=isNaN(r)?t.s:r),isNaN(i)?(i=0,e=isNaN(e)?t.h:e):i>180?i-=360:-180>i&&(i+=360),function(n){return ct(e+i*n,r+o*n,u+a*n)+""}}function Ru(n,t){n=ta.lab(n),t=ta.lab(t);var e=n.l,r=n.a,u=n.b,i=t.l-e,o=t.a-r,a=t.b-u;return function(n){return ht(e+i*n,r+o*n,u+a*n)+""}}function Du(n,t){return t-=n,function(e){return Math.round(n+t*e)}}function Pu(n){var t=[n.a,n.b],e=[n.c,n.d],r=ju(t),u=Uu(t,e),i=ju(Fu(e,t,-u))||0;t[0]*e[1]180?s+=360:s-l>180&&(l+=360),u.push({i:r.push(r.pop()+"rotate(",null,")")-2,x:vu(l,s)})):s&&r.push(r.pop()+"rotate("+s+")"),f!=h?u.push({i:r.push(r.pop()+"skewX(",null,")")-2,x:vu(f,h)}):h&&r.push(r.pop()+"skewX("+h+")"),g[0]!=p[0]||g[1]!=p[1]?(e=r.push(r.pop()+"scale(",null,",",null,")"),u.push({i:e-4,x:vu(g[0],p[0])},{i:e-2,x:vu(g[1],p[1])})):(1!=p[0]||1!=p[1])&&r.push(r.pop()+"scale("+p+")"),e=u.length,function(n){for(var t,i=-1;++i=0;)e.push(u[r])}function Qu(n,t){for(var e=[n],r=[];null!=(n=e.pop());)if(r.push(n),(i=n.children)&&(u=i.length))for(var u,i,o=-1;++oe;++e)(t=n[e][1])>u&&(r=e,u=t);return r}function si(n){return n.reduce(fi,0)}function fi(n,t){return n+t[1]}function hi(n,t){return gi(n,Math.ceil(Math.log(t.length)/Math.LN2+1))}function gi(n,t){for(var e=-1,r=+n[0],u=(n[1]-r)/t,i=[];++e<=t;)i[e]=u*e+r;return i}function pi(n){return[ta.min(n),ta.max(n)]}function vi(n,t){return n.value-t.value}function di(n,t){var e=n._pack_next;n._pack_next=t,t._pack_prev=n,t._pack_next=e,e._pack_prev=t}function mi(n,t){n._pack_next=t,t._pack_prev=n}function yi(n,t){var e=t.x-n.x,r=t.y-n.y,u=n.r+t.r;return.999*u*u>e*e+r*r}function Mi(n){function t(n){s=Math.min(n.x-n.r,s),f=Math.max(n.x+n.r,f),h=Math.min(n.y-n.r,h),g=Math.max(n.y+n.r,g)}if((e=n.children)&&(l=e.length)){var e,r,u,i,o,a,c,l,s=1/0,f=-1/0,h=1/0,g=-1/0;if(e.forEach(xi),r=e[0],r.x=-r.r,r.y=0,t(r),l>1&&(u=e[1],u.x=u.r,u.y=0,t(u),l>2))for(i=e[2],wi(r,u,i),t(i),di(r,i),r._pack_prev=i,di(i,u),u=r._pack_next,o=3;l>o;o++){wi(r,u,i=e[o]);var p=0,v=1,d=1;for(a=u._pack_next;a!==u;a=a._pack_next,v++)if(yi(a,i)){p=1;break}if(1==p)for(c=r._pack_prev;c!==a._pack_prev&&!yi(c,i);c=c._pack_prev,d++);p?(d>v||v==d&&u.ro;o++)i=e[o],i.x-=m,i.y-=y,M=Math.max(M,i.r+Math.sqrt(i.x*i.x+i.y*i.y));n.r=M,e.forEach(bi)}}function xi(n){n._pack_next=n._pack_prev=n}function bi(n){delete n._pack_next,delete n._pack_prev}function _i(n,t,e,r){var u=n.children;if(n.x=t+=r*n.x,n.y=e+=r*n.y,n.r*=r,u)for(var i=-1,o=u.length;++i=0;)t=u[i],t.z+=e,t.m+=e,e+=t.s+(r+=t.c)}function Ci(n,t,e){return n.a.parent===t.parent?n.a:e}function zi(n){return 1+ta.max(n,function(n){return n.y})}function qi(n){return n.reduce(function(n,t){return n+t.x},0)/n.length}function Li(n){var t=n.children;return t&&t.length?Li(t[0]):n}function Ti(n){var t,e=n.children;return e&&(t=e.length)?Ti(e[t-1]):n}function Ri(n){return{x:n.x,y:n.y,dx:n.dx,dy:n.dy}}function Di(n,t){var e=n.x+t[3],r=n.y+t[0],u=n.dx-t[1]-t[3],i=n.dy-t[0]-t[2];return 0>u&&(e+=u/2,u=0),0>i&&(r+=i/2,i=0),{x:e,y:r,dx:u,dy:i}}function Pi(n){var t=n[0],e=n[n.length-1];return e>t?[t,e]:[e,t]}function Ui(n){return n.rangeExtent?n.rangeExtent():Pi(n.range())}function ji(n,t,e,r){var u=e(n[0],n[1]),i=r(t[0],t[1]);return function(n){return i(u(n))}}function Fi(n,t){var e,r=0,u=n.length-1,i=n[r],o=n[u];return i>o&&(e=r,r=u,u=e,e=i,i=o,o=e),n[r]=t.floor(i),n[u]=t.ceil(o),n}function Hi(n){return n?{floor:function(t){return Math.floor(t/n)*n},ceil:function(t){return Math.ceil(t/n)*n}}:ml}function Oi(n,t,e,r){var u=[],i=[],o=0,a=Math.min(n.length,t.length)-1;for(n[a]2?Oi:ji,c=r?Iu:Ou;return o=u(n,t,c,e),a=u(t,n,c,mu),i}function i(n){return o(n)}var o,a;return i.invert=function(n){return a(n)},i.domain=function(t){return arguments.length?(n=t.map(Number),u()):n},i.range=function(n){return arguments.length?(t=n,u()):t},i.rangeRound=function(n){return i.range(n).interpolate(Du)},i.clamp=function(n){return arguments.length?(r=n,u()):r},i.interpolate=function(n){return arguments.length?(e=n,u()):e},i.ticks=function(t){return Xi(n,t)},i.tickFormat=function(t,e){return $i(n,t,e)},i.nice=function(t){return Zi(n,t),u()},i.copy=function(){return Ii(n,t,e,r)},u()}function Yi(n,t){return ta.rebind(n,t,"range","rangeRound","interpolate","clamp")}function Zi(n,t){return Fi(n,Hi(Vi(n,t)[2]))}function Vi(n,t){null==t&&(t=10);var e=Pi(n),r=e[1]-e[0],u=Math.pow(10,Math.floor(Math.log(r/t)/Math.LN10)),i=t/r*u;return.15>=i?u*=10:.35>=i?u*=5:.75>=i&&(u*=2),e[0]=Math.ceil(e[0]/u)*u,e[1]=Math.floor(e[1]/u)*u+.5*u,e[2]=u,e}function Xi(n,t){return ta.range.apply(ta,Vi(n,t))}function $i(n,t,e){var r=Vi(n,t);if(e){var u=ic.exec(e);if(u.shift(),"s"===u[8]){var i=ta.formatPrefix(Math.max(ga(r[0]),ga(r[1])));return u[7]||(u[7]="."+Bi(i.scale(r[2]))),u[8]="f",e=ta.format(u.join("")),function(n){return e(i.scale(n))+i.symbol}}u[7]||(u[7]="."+Wi(u[8],r)),e=u.join("")}else e=",."+Bi(r[2])+"f";return ta.format(e)}function Bi(n){return-Math.floor(Math.log(n)/Math.LN10+.01)}function Wi(n,t){var e=Bi(t[2]);return n in yl?Math.abs(e-Bi(Math.max(ga(t[0]),ga(t[1]))))+ +("e"!==n):e-2*("%"===n)}function Ji(n,t,e,r){function u(n){return(e?Math.log(0>n?0:n):-Math.log(n>0?0:-n))/Math.log(t)}function i(n){return e?Math.pow(t,n):-Math.pow(t,-n)}function o(t){return n(u(t))}return o.invert=function(t){return i(n.invert(t))},o.domain=function(t){return arguments.length?(e=t[0]>=0,n.domain((r=t.map(Number)).map(u)),o):r},o.base=function(e){return arguments.length?(t=+e,n.domain(r.map(u)),o):t},o.nice=function(){var t=Fi(r.map(u),e?Math:xl);return n.domain(t),r=t.map(i),o},o.ticks=function(){var n=Pi(r),o=[],a=n[0],c=n[1],l=Math.floor(u(a)),s=Math.ceil(u(c)),f=t%1?2:t;if(isFinite(s-l)){if(e){for(;s>l;l++)for(var h=1;f>h;h++)o.push(i(l)*h);o.push(i(l))}else for(o.push(i(l));l++0;h--)o.push(i(l)*h);for(l=0;o[l]c;s--);o=o.slice(l,s)}return o},o.tickFormat=function(n,t){if(!arguments.length)return Ml;arguments.length<2?t=Ml:"function"!=typeof t&&(t=ta.format(t));var r,a=Math.max(.1,n/o.ticks().length),c=e?(r=1e-12,Math.ceil):(r=-1e-12,Math.floor);return function(n){return n/i(c(u(n)+r))<=a?t(n):""}},o.copy=function(){return Ji(n.copy(),t,e,r)},Yi(o,n)}function Gi(n,t,e){function r(t){return n(u(t))}var u=Ki(t),i=Ki(1/t);return r.invert=function(t){return i(n.invert(t))},r.domain=function(t){return arguments.length?(n.domain((e=t.map(Number)).map(u)),r):e},r.ticks=function(n){return Xi(e,n)},r.tickFormat=function(n,t){return $i(e,n,t)},r.nice=function(n){return r.domain(Zi(e,n))},r.exponent=function(o){return arguments.length?(u=Ki(t=o),i=Ki(1/t),n.domain(e.map(u)),r):t},r.copy=function(){return Gi(n.copy(),t,e)},Yi(r,n)}function Ki(n){return function(t){return 0>t?-Math.pow(-t,n):Math.pow(t,n)}}function Qi(n,t){function e(e){return i[((u.get(e)||("range"===t.t?u.set(e,n.push(e)):0/0))-1)%i.length]}function r(t,e){return ta.range(n.length).map(function(n){return t+e*n})}var u,i,o;return e.domain=function(r){if(!arguments.length)return n;n=[],u=new l;for(var i,o=-1,a=r.length;++oe?[0/0,0/0]:[e>0?a[e-1]:n[0],et?0/0:t/i+n,[t,t+1/i]},r.copy=function(){return to(n,t,e)},u()}function eo(n,t){function e(e){return e>=e?t[ta.bisect(n,e)]:void 0}return e.domain=function(t){return arguments.length?(n=t,e):n},e.range=function(n){return arguments.length?(t=n,e):t},e.invertExtent=function(e){return e=t.indexOf(e),[n[e-1],n[e]]},e.copy=function(){return eo(n,t)},e}function ro(n){function t(n){return+n}return t.invert=t,t.domain=t.range=function(e){return arguments.length?(n=e.map(t),t):n},t.ticks=function(t){return Xi(n,t)},t.tickFormat=function(t,e){return $i(n,t,e)},t.copy=function(){return ro(n)},t}function uo(){return 0}function io(n){return n.innerRadius}function oo(n){return n.outerRadius}function ao(n){return n.startAngle}function co(n){return n.endAngle}function lo(n){return n&&n.padAngle}function so(n,t,e,r){return(n-e)*t-(t-r)*n>0?0:1}function fo(n,t,e,r,u){var i=n[0]-t[0],o=n[1]-t[1],a=(u?r:-r)/Math.sqrt(i*i+o*o),c=a*o,l=-a*i,s=n[0]+c,f=n[1]+l,h=t[0]+c,g=t[1]+l,p=(s+h)/2,v=(f+g)/2,d=h-s,m=g-f,y=d*d+m*m,M=e-r,x=s*g-h*f,b=(0>m?-1:1)*Math.sqrt(M*M*y-x*x),_=(x*m-d*b)/y,w=(-x*d-m*b)/y,S=(x*m+d*b)/y,k=(-x*d+m*b)/y,E=_-p,A=w-v,N=S-p,C=k-v;return E*E+A*A>N*N+C*C&&(_=S,w=k),[[_-c,w-l],[_*e/M,w*e/M]]}function ho(n){function t(t){function o(){l.push("M",i(n(s),a))}for(var c,l=[],s=[],f=-1,h=t.length,g=Et(e),p=Et(r);++f1&&u.push("H",r[0]),u.join("")}function mo(n){for(var t=0,e=n.length,r=n[0],u=[r[0],",",r[1]];++t1){a=t[1],i=n[c],c++,r+="C"+(u[0]+o[0])+","+(u[1]+o[1])+","+(i[0]-a[0])+","+(i[1]-a[1])+","+i[0]+","+i[1];for(var l=2;l9&&(u=3*t/Math.sqrt(u),o[a]=u*e,o[a+1]=u*r));for(a=-1;++a<=c;)u=(n[Math.min(c,a+1)][0]-n[Math.max(0,a-1)][0])/(6*(1+o[a]*o[a])),i.push([u||0,o[a]*u||0]);return i}function To(n){return n.length<3?go(n):n[0]+_o(n,Lo(n))}function Ro(n){for(var t,e,r,u=-1,i=n.length;++ur)return s();var u=i[i.active];u&&(--i.count,delete i[i.active],u.event&&u.event.interrupt.call(n,n.__data__,u.index)),i.active=r,o.event&&o.event.start.call(n,n.__data__,t),o.tween.forEach(function(e,r){(r=r.call(n,n.__data__,t))&&v.push(r)}),h=o.ease,f=o.duration,ta.timer(function(){return p.c=l(e||1)?Ne:l,1},0,a)}function l(e){if(i.active!==r)return 1;for(var u=e/f,a=h(u),c=v.length;c>0;)v[--c].call(n,a);return u>=1?(o.event&&o.event.end.call(n,n.__data__,t),s()):void 0}function s(){return--i.count?delete i[r]:delete n[e],1}var f,h,g=o.delay,p=ec,v=[];return p.t=g+a,u>=g?c(u-g):void(p.c=c)},0,a)}}function Bo(n,t,e){n.attr("transform",function(n){var r=t(n);return"translate("+(isFinite(r)?r:e(n))+",0)"})}function Wo(n,t,e){n.attr("transform",function(n){var r=t(n);return"translate(0,"+(isFinite(r)?r:e(n))+")"})}function Jo(n){return n.toISOString()}function Go(n,t,e){function r(t){return n(t)}function u(n,e){var r=n[1]-n[0],u=r/e,i=ta.bisect(Vl,u);return i==Vl.length?[t.year,Vi(n.map(function(n){return n/31536e6}),e)[2]]:i?t[u/Vl[i-1]1?{floor:function(t){for(;e(t=n.floor(t));)t=Ko(t-1);return t},ceil:function(t){for(;e(t=n.ceil(t));)t=Ko(+t+1);return t}}:n))},r.ticks=function(n,t){var e=Pi(r.domain()),i=null==n?u(e,10):"number"==typeof n?u(e,n):!n.range&&[{range:n},t];return i&&(n=i[0],t=i[1]),n.range(e[0],Ko(+e[1]+1),1>t?1:t)},r.tickFormat=function(){return e},r.copy=function(){return Go(n.copy(),t,e)},Yi(r,n)}function Ko(n){return new Date(n)}function Qo(n){return JSON.parse(n.responseText)}function na(n){var t=ua.createRange();return t.selectNode(ua.body),t.createContextualFragment(n.responseText)}var ta={version:"3.5.5"},ea=[].slice,ra=function(n){return ea.call(n)},ua=this.document;if(ua)try{ra(ua.documentElement.childNodes)[0].nodeType}catch(ia){ra=function(n){for(var t=n.length,e=new Array(t);t--;)e[t]=n[t];return e}}if(Date.now||(Date.now=function(){return+new Date}),ua)try{ua.createElement("DIV").style.setProperty("opacity",0,"")}catch(oa){var aa=this.Element.prototype,ca=aa.setAttribute,la=aa.setAttributeNS,sa=this.CSSStyleDeclaration.prototype,fa=sa.setProperty;aa.setAttribute=function(n,t){ca.call(this,n,t+"")},aa.setAttributeNS=function(n,t,e){la.call(this,n,t,e+"")},sa.setProperty=function(n,t,e){fa.call(this,n,t+"",e)}}ta.ascending=e,ta.descending=function(n,t){return n>t?-1:t>n?1:t>=n?0:0/0},ta.min=function(n,t){var e,r,u=-1,i=n.length;if(1===arguments.length){for(;++u=r){e=r;break}for(;++ur&&(e=r)}else{for(;++u=r){e=r;break}for(;++ur&&(e=r)}return e},ta.max=function(n,t){var e,r,u=-1,i=n.length;if(1===arguments.length){for(;++u=r){e=r;break}for(;++ue&&(e=r)}else{for(;++u=r){e=r;break}for(;++ue&&(e=r)}return e},ta.extent=function(n,t){var e,r,u,i=-1,o=n.length;if(1===arguments.length){for(;++i=r){e=u=r;break}for(;++ir&&(e=r),r>u&&(u=r))}else{for(;++i=r){e=u=r;break}for(;++ir&&(e=r),r>u&&(u=r))}return[e,u]},ta.sum=function(n,t){var e,r=0,i=n.length,o=-1;if(1===arguments.length)for(;++o1?c/(s-1):void 0},ta.deviation=function(){var n=ta.variance.apply(this,arguments);return n?Math.sqrt(n):n};var ha=i(e);ta.bisectLeft=ha.left,ta.bisect=ta.bisectRight=ha.right,ta.bisector=function(n){return i(1===n.length?function(t,r){return e(n(t),r)}:n)},ta.shuffle=function(n,t,e){(i=arguments.length)<3&&(e=n.length,2>i&&(t=0));for(var r,u,i=e-t;i;)u=Math.random()*i--|0,r=n[i+t],n[i+t]=n[u+t],n[u+t]=r;return n},ta.permute=function(n,t){for(var e=t.length,r=new Array(e);e--;)r[e]=n[t[e]];return r},ta.pairs=function(n){for(var t,e=0,r=n.length-1,u=n[0],i=new Array(0>r?0:r);r>e;)i[e]=[t=u,u=n[++e]];return i},ta.zip=function(){if(!(r=arguments.length))return[];for(var n=-1,t=ta.min(arguments,o),e=new Array(t);++n=0;)for(r=n[u],t=r.length;--t>=0;)e[--o]=r[t];return e};var ga=Math.abs;ta.range=function(n,t,e){if(arguments.length<3&&(e=1,arguments.length<2&&(t=n,n=0)),(t-n)/e===1/0)throw new Error("infinite range");var r,u=[],i=a(ga(e)),o=-1;if(n*=i,t*=i,e*=i,0>e)for(;(r=n+e*++o)>t;)u.push(r/i);else for(;(r=n+e*++o)=i.length)return r?r.call(u,o):e?o.sort(e):o;for(var c,s,f,h,g=-1,p=o.length,v=i[a++],d=new l;++g=i.length)return n;var r=[],u=o[e++];return n.forEach(function(n,u){r.push({key:n,values:t(u,e)})}),u?r.sort(function(n,t){return u(n.key,t.key)}):r}var e,r,u={},i=[],o=[];return u.map=function(t,e){return n(e,t,0)},u.entries=function(e){return t(n(ta.map,e,0),0)},u.key=function(n){return i.push(n),u},u.sortKeys=function(n){return o[i.length-1]=n,u},u.sortValues=function(n){return e=n,u},u.rollup=function(n){return r=n,u},u},ta.set=function(n){var t=new m;if(n)for(var e=0,r=n.length;r>e;++e)t.add(n[e]);return t},c(m,{has:h,add:function(n){return this._[s(n+="")]=!0,n},remove:g,values:p,size:v,empty:d,forEach:function(n){for(var t in this._)n.call(this,f(t))}}),ta.behavior={},ta.rebind=function(n,t){for(var e,r=1,u=arguments.length;++r=0&&(r=n.slice(e+1),n=n.slice(0,e)),n)return arguments.length<2?this[n].on(r):this[n].on(r,t);if(2===arguments.length){if(null==t)for(n in this)this.hasOwnProperty(n)&&this[n].on(r,null);return this}},ta.event=null,ta.requote=function(n){return n.replace(ma,"\\$&")};var ma=/[\\\^\$\*\+\?\|\[\]\(\)\.\{\}]/g,ya={}.__proto__?function(n,t){n.__proto__=t}:function(n,t){for(var e in t)n[e]=t[e]},Ma=function(n,t){return t.querySelector(n)},xa=function(n,t){return t.querySelectorAll(n)},ba=function(n,t){var e=n.matches||n[x(n,"matchesSelector")];return(ba=function(n,t){return e.call(n,t)})(n,t)};"function"==typeof Sizzle&&(Ma=function(n,t){return Sizzle(n,t)[0]||null},xa=Sizzle,ba=Sizzle.matchesSelector),ta.selection=function(){return ta.select(ua.documentElement)};var _a=ta.selection.prototype=[];_a.select=function(n){var t,e,r,u,i=[];n=N(n);for(var o=-1,a=this.length;++o=0&&(e=n.slice(0,t),n=n.slice(t+1)),wa.hasOwnProperty(e)?{space:wa[e],local:n}:n}},_a.attr=function(n,t){if(arguments.length<2){if("string"==typeof n){var e=this.node();return n=ta.ns.qualify(n),n.local?e.getAttributeNS(n.space,n.local):e.getAttribute(n)}for(t in n)this.each(z(t,n[t]));return this}return this.each(z(n,t))},_a.classed=function(n,t){if(arguments.length<2){if("string"==typeof n){var e=this.node(),r=(n=T(n)).length,u=-1;if(t=e.classList){for(;++uu){if("string"!=typeof n){2>u&&(e="");for(r in n)this.each(P(r,n[r],e));return this}if(2>u){var i=this.node();return t(i).getComputedStyle(i,null).getPropertyValue(n)}r=""}return this.each(P(n,e,r))},_a.property=function(n,t){if(arguments.length<2){if("string"==typeof n)return this.node()[n];for(t in n)this.each(U(t,n[t]));return this}return this.each(U(n,t))},_a.text=function(n){return arguments.length?this.each("function"==typeof n?function(){var t=n.apply(this,arguments);this.textContent=null==t?"":t}:null==n?function(){this.textContent=""}:function(){this.textContent=n}):this.node().textContent},_a.html=function(n){return arguments.length?this.each("function"==typeof n?function(){var t=n.apply(this,arguments);this.innerHTML=null==t?"":t}:null==n?function(){this.innerHTML=""}:function(){this.innerHTML=n}):this.node().innerHTML},_a.append=function(n){return n=j(n),this.select(function(){return this.appendChild(n.apply(this,arguments))})},_a.insert=function(n,t){return n=j(n),t=N(t),this.select(function(){return this.insertBefore(n.apply(this,arguments),t.apply(this,arguments)||null)})},_a.remove=function(){return this.each(F)},_a.data=function(n,t){function e(n,e){var r,u,i,o=n.length,f=e.length,h=Math.min(o,f),g=new Array(f),p=new Array(f),v=new Array(o);if(t){var d,m=new l,y=new Array(o);for(r=-1;++rr;++r)p[r]=H(e[r]);for(;o>r;++r)v[r]=n[r]}p.update=g,p.parentNode=g.parentNode=v.parentNode=n.parentNode,a.push(p),c.push(g),s.push(v)}var r,u,i=-1,o=this.length;if(!arguments.length){for(n=new Array(o=(r=this[0]).length);++ii;i++){u.push(t=[]),t.parentNode=(e=this[i]).parentNode;for(var a=0,c=e.length;c>a;a++)(r=e[a])&&n.call(r,r.__data__,a,i)&&t.push(r)}return A(u)},_a.order=function(){for(var n=-1,t=this.length;++n=0;)(e=r[u])&&(i&&i!==e.nextSibling&&i.parentNode.insertBefore(e,i),i=e);return this},_a.sort=function(n){n=I.apply(this,arguments);for(var t=-1,e=this.length;++tn;n++)for(var e=this[n],r=0,u=e.length;u>r;r++){var i=e[r];if(i)return i}return null},_a.size=function(){var n=0;return Y(this,function(){++n}),n};var Sa=[];ta.selection.enter=Z,ta.selection.enter.prototype=Sa,Sa.append=_a.append,Sa.empty=_a.empty,Sa.node=_a.node,Sa.call=_a.call,Sa.size=_a.size,Sa.select=function(n){for(var t,e,r,u,i,o=[],a=-1,c=this.length;++ar){if("string"!=typeof n){2>r&&(t=!1);for(e in n)this.each(X(e,n[e],t));return this}if(2>r)return(r=this.node()["__on"+n])&&r._;e=!1}return this.each(X(n,t,e))};var ka=ta.map({mouseenter:"mouseover",mouseleave:"mouseout"});ua&&ka.forEach(function(n){"on"+n in ua&&ka.remove(n)});var Ea,Aa=0;ta.mouse=function(n){return J(n,k())};var Na=this.navigator&&/WebKit/.test(this.navigator.userAgent)?-1:0;ta.touch=function(n,t,e){if(arguments.length<3&&(e=t,t=k().changedTouches),t)for(var r,u=0,i=t.length;i>u;++u)if((r=t[u]).identifier===e)return J(n,r)},ta.behavior.drag=function(){function n(){this.on("mousedown.drag",i).on("touchstart.drag",o)}function e(n,t,e,i,o){return function(){function a(){var n,e,r=t(h,v);r&&(n=r[0]-M[0],e=r[1]-M[1],p|=n|e,M=r,g({type:"drag",x:r[0]+l[0],y:r[1]+l[1],dx:n,dy:e}))}function c(){t(h,v)&&(m.on(i+d,null).on(o+d,null),y(p&&ta.event.target===f),g({type:"dragend"}))}var l,s=this,f=ta.event.target,h=s.parentNode,g=r.of(s,arguments),p=0,v=n(),d=".drag"+(null==v?"":"-"+v),m=ta.select(e(f)).on(i+d,a).on(o+d,c),y=W(f),M=t(h,v);u?(l=u.apply(s,arguments),l=[l.x-M[0],l.y-M[1]]):l=[0,0],g({type:"dragstart"})}}var r=E(n,"drag","dragstart","dragend"),u=null,i=e(b,ta.mouse,t,"mousemove","mouseup"),o=e(G,ta.touch,y,"touchmove","touchend");return n.origin=function(t){return arguments.length?(u=t,n):u},ta.rebind(n,r,"on")},ta.touches=function(n,t){return arguments.length<2&&(t=k().touches),t?ra(t).map(function(t){var e=J(n,t);return e.identifier=t.identifier,e}):[]};var Ca=1e-6,za=Ca*Ca,qa=Math.PI,La=2*qa,Ta=La-Ca,Ra=qa/2,Da=qa/180,Pa=180/qa,Ua=Math.SQRT2,ja=2,Fa=4;ta.interpolateZoom=function(n,t){function e(n){var t=n*y;if(m){var e=rt(v),o=i/(ja*h)*(e*ut(Ua*t+v)-et(v));return[r+o*l,u+o*s,i*e/rt(Ua*t+v)]}return[r+n*l,u+n*s,i*Math.exp(Ua*t)]}var r=n[0],u=n[1],i=n[2],o=t[0],a=t[1],c=t[2],l=o-r,s=a-u,f=l*l+s*s,h=Math.sqrt(f),g=(c*c-i*i+Fa*f)/(2*i*ja*h),p=(c*c-i*i-Fa*f)/(2*c*ja*h),v=Math.log(Math.sqrt(g*g+1)-g),d=Math.log(Math.sqrt(p*p+1)-p),m=d-v,y=(m||Math.log(c/i))/Ua;return e.duration=1e3*y,e},ta.behavior.zoom=function(){function n(n){n.on(q,f).on(Oa+".zoom",g).on("dblclick.zoom",p).on(R,h)}function e(n){return[(n[0]-k.x)/k.k,(n[1]-k.y)/k.k]}function r(n){return[n[0]*k.k+k.x,n[1]*k.k+k.y]}function u(n){k.k=Math.max(N[0],Math.min(N[1],n))}function i(n,t){t=r(t),k.x+=n[0]-t[0],k.y+=n[1]-t[1]}function o(t,e,r,o){t.__chart__={x:k.x,y:k.y,k:k.k},u(Math.pow(2,o)),i(d=e,r),t=ta.select(t),C>0&&(t=t.transition().duration(C)),t.call(n.event)}function a(){b&&b.domain(x.range().map(function(n){return(n-k.x)/k.k}).map(x.invert)),w&&w.domain(_.range().map(function(n){return(n-k.y)/k.k}).map(_.invert))}function c(n){z++||n({type:"zoomstart"})}function l(n){a(),n({type:"zoom",scale:k.k,translate:[k.x,k.y]})}function s(n){--z||n({type:"zoomend"}),d=null}function f(){function n(){f=1,i(ta.mouse(u),g),l(a)}function r(){h.on(L,null).on(T,null),p(f&&ta.event.target===o),s(a)}var u=this,o=ta.event.target,a=D.of(u,arguments),f=0,h=ta.select(t(u)).on(L,n).on(T,r),g=e(ta.mouse(u)),p=W(u);Dl.call(u),c(a)}function h(){function n(){var n=ta.touches(p);return g=k.k,n.forEach(function(n){n.identifier in d&&(d[n.identifier]=e(n))}),n}function t(){var t=ta.event.target;ta.select(t).on(x,r).on(b,a),_.push(t);for(var e=ta.event.changedTouches,u=0,i=e.length;i>u;++u)d[e[u].identifier]=null;var c=n(),l=Date.now();if(1===c.length){if(500>l-M){var s=c[0];o(p,s,d[s.identifier],Math.floor(Math.log(k.k)/Math.LN2)+1),S()}M=l}else if(c.length>1){var s=c[0],f=c[1],h=s[0]-f[0],g=s[1]-f[1];m=h*h+g*g}}function r(){var n,t,e,r,o=ta.touches(p);Dl.call(p);for(var a=0,c=o.length;c>a;++a,r=null)if(e=o[a],r=d[e.identifier]){if(t)break;n=e,t=r}if(r){var s=(s=e[0]-n[0])*s+(s=e[1]-n[1])*s,f=m&&Math.sqrt(s/m);n=[(n[0]+e[0])/2,(n[1]+e[1])/2],t=[(t[0]+r[0])/2,(t[1]+r[1])/2],u(f*g)}M=null,i(n,t),l(v)}function a(){if(ta.event.touches.length){for(var t=ta.event.changedTouches,e=0,r=t.length;r>e;++e)delete d[t[e].identifier];for(var u in d)return void n()}ta.selectAll(_).on(y,null),w.on(q,f).on(R,h),E(),s(v)}var g,p=this,v=D.of(p,arguments),d={},m=0,y=".zoom-"+ta.event.changedTouches[0].identifier,x="touchmove"+y,b="touchend"+y,_=[],w=ta.select(p),E=W(p);t(),c(v),w.on(q,null).on(R,t)}function g(){var n=D.of(this,arguments);y?clearTimeout(y):(v=e(d=m||ta.mouse(this)),Dl.call(this),c(n)),y=setTimeout(function(){y=null,s(n)},50),S(),u(Math.pow(2,.002*Ha())*k.k),i(d,v),l(n)}function p(){var n=ta.mouse(this),t=Math.log(k.k)/Math.LN2;o(this,n,e(n),ta.event.shiftKey?Math.ceil(t)-1:Math.floor(t)+1)}var v,d,m,y,M,x,b,_,w,k={x:0,y:0,k:1},A=[960,500],N=Ia,C=250,z=0,q="mousedown.zoom",L="mousemove.zoom",T="mouseup.zoom",R="touchstart.zoom",D=E(n,"zoomstart","zoom","zoomend");return Oa||(Oa="onwheel"in ua?(Ha=function(){return-ta.event.deltaY*(ta.event.deltaMode?120:1)},"wheel"):"onmousewheel"in ua?(Ha=function(){return ta.event.wheelDelta},"mousewheel"):(Ha=function(){return-ta.event.detail},"MozMousePixelScroll")),n.event=function(n){n.each(function(){var n=D.of(this,arguments),t=k;Tl?ta.select(this).transition().each("start.zoom",function(){k=this.__chart__||{x:0,y:0,k:1},c(n)}).tween("zoom:zoom",function(){var e=A[0],r=A[1],u=d?d[0]:e/2,i=d?d[1]:r/2,o=ta.interpolateZoom([(u-k.x)/k.k,(i-k.y)/k.k,e/k.k],[(u-t.x)/t.k,(i-t.y)/t.k,e/t.k]);return function(t){var r=o(t),a=e/r[2];this.__chart__=k={x:u-r[0]*a,y:i-r[1]*a,k:a},l(n)}}).each("interrupt.zoom",function(){s(n)}).each("end.zoom",function(){s(n)}):(this.__chart__=k,c(n),l(n),s(n))})},n.translate=function(t){return arguments.length?(k={x:+t[0],y:+t[1],k:k.k},a(),n):[k.x,k.y]},n.scale=function(t){return arguments.length?(k={x:k.x,y:k.y,k:+t},a(),n):k.k},n.scaleExtent=function(t){return arguments.length?(N=null==t?Ia:[+t[0],+t[1]],n):N},n.center=function(t){return arguments.length?(m=t&&[+t[0],+t[1]],n):m},n.size=function(t){return arguments.length?(A=t&&[+t[0],+t[1]],n):A},n.duration=function(t){return arguments.length?(C=+t,n):C},n.x=function(t){return arguments.length?(b=t,x=t.copy(),k={x:0,y:0,k:1},n):b},n.y=function(t){return arguments.length?(w=t,_=t.copy(),k={x:0,y:0,k:1},n):w},ta.rebind(n,D,"on")};var Ha,Oa,Ia=[0,1/0];ta.color=ot,ot.prototype.toString=function(){return this.rgb()+""},ta.hsl=at;var Ya=at.prototype=new ot;Ya.brighter=function(n){return n=Math.pow(.7,arguments.length?n:1),new at(this.h,this.s,this.l/n)},Ya.darker=function(n){return n=Math.pow(.7,arguments.length?n:1),new at(this.h,this.s,n*this.l)},Ya.rgb=function(){return ct(this.h,this.s,this.l)},ta.hcl=lt;var Za=lt.prototype=new ot;Za.brighter=function(n){return new lt(this.h,this.c,Math.min(100,this.l+Va*(arguments.length?n:1)))},Za.darker=function(n){return new lt(this.h,this.c,Math.max(0,this.l-Va*(arguments.length?n:1)))},Za.rgb=function(){return st(this.h,this.c,this.l).rgb()},ta.lab=ft;var Va=18,Xa=.95047,$a=1,Ba=1.08883,Wa=ft.prototype=new ot;Wa.brighter=function(n){return new ft(Math.min(100,this.l+Va*(arguments.length?n:1)),this.a,this.b)},Wa.darker=function(n){return new ft(Math.max(0,this.l-Va*(arguments.length?n:1)),this.a,this.b)},Wa.rgb=function(){return ht(this.l,this.a,this.b)},ta.rgb=mt;var Ja=mt.prototype=new ot;Ja.brighter=function(n){n=Math.pow(.7,arguments.length?n:1);var t=this.r,e=this.g,r=this.b,u=30;return t||e||r?(t&&u>t&&(t=u),e&&u>e&&(e=u),r&&u>r&&(r=u),new mt(Math.min(255,t/n),Math.min(255,e/n),Math.min(255,r/n))):new mt(u,u,u)},Ja.darker=function(n){return n=Math.pow(.7,arguments.length?n:1),new mt(n*this.r,n*this.g,n*this.b)},Ja.hsl=function(){return _t(this.r,this.g,this.b)},Ja.toString=function(){return"#"+xt(this.r)+xt(this.g)+xt(this.b)};var Ga=ta.map({aliceblue:15792383,antiquewhite:16444375,aqua:65535,aquamarine:8388564,azure:15794175,beige:16119260,bisque:16770244,black:0,blanchedalmond:16772045,blue:255,blueviolet:9055202,brown:10824234,burlywood:14596231,cadetblue:6266528,chartreuse:8388352,chocolate:13789470,coral:16744272,cornflowerblue:6591981,cornsilk:16775388,crimson:14423100,cyan:65535,darkblue:139,darkcyan:35723,darkgoldenrod:12092939,darkgray:11119017,darkgreen:25600,darkgrey:11119017,darkkhaki:12433259,darkmagenta:9109643,darkolivegreen:5597999,darkorange:16747520,darkorchid:10040012,darkred:9109504,darksalmon:15308410,darkseagreen:9419919,darkslateblue:4734347,darkslategray:3100495,darkslategrey:3100495,darkturquoise:52945,darkviolet:9699539,deeppink:16716947,deepskyblue:49151,dimgray:6908265,dimgrey:6908265,dodgerblue:2003199,firebrick:11674146,floralwhite:16775920,forestgreen:2263842,fuchsia:16711935,gainsboro:14474460,ghostwhite:16316671,gold:16766720,goldenrod:14329120,gray:8421504,green:32768,greenyellow:11403055,grey:8421504,honeydew:15794160,hotpink:16738740,indianred:13458524,indigo:4915330,ivory:16777200,khaki:15787660,lavender:15132410,lavenderblush:16773365,lawngreen:8190976,lemonchiffon:16775885,lightblue:11393254,lightcoral:15761536,lightcyan:14745599,lightgoldenrodyellow:16448210,lightgray:13882323,lightgreen:9498256,lightgrey:13882323,lightpink:16758465,lightsalmon:16752762,lightseagreen:2142890,lightskyblue:8900346,lightslategray:7833753,lightslategrey:7833753,lightsteelblue:11584734,lightyellow:16777184,lime:65280,limegreen:3329330,linen:16445670,magenta:16711935,maroon:8388608,mediumaquamarine:6737322,mediumblue:205,mediumorchid:12211667,mediumpurple:9662683,mediumseagreen:3978097,mediumslateblue:8087790,mediumspringgreen:64154,mediumturquoise:4772300,mediumvioletred:13047173,midnightblue:1644912,mintcream:16121850,mistyrose:16770273,moccasin:16770229,navajowhite:16768685,navy:128,oldlace:16643558,olive:8421376,olivedrab:7048739,orange:16753920,orangered:16729344,orchid:14315734,palegoldenrod:15657130,palegreen:10025880,paleturquoise:11529966,palevioletred:14381203,papayawhip:16773077,peachpuff:16767673,peru:13468991,pink:16761035,plum:14524637,powderblue:11591910,purple:8388736,rebeccapurple:6697881,red:16711680,rosybrown:12357519,royalblue:4286945,saddlebrown:9127187,salmon:16416882,sandybrown:16032864,seagreen:3050327,seashell:16774638,sienna:10506797,silver:12632256,skyblue:8900331,slateblue:6970061,slategray:7372944,slategrey:7372944,snow:16775930,springgreen:65407,steelblue:4620980,tan:13808780,teal:32896,thistle:14204888,tomato:16737095,turquoise:4251856,violet:15631086,wheat:16113331,white:16777215,whitesmoke:16119285,yellow:16776960,yellowgreen:10145074});Ga.forEach(function(n,t){Ga.set(n,yt(t))}),ta.functor=Et,ta.xhr=At(y),ta.dsv=function(n,t){function e(n,e,i){arguments.length<3&&(i=e,e=null);var o=Nt(n,t,null==e?r:u(e),i);return o.row=function(n){return arguments.length?o.response(null==(e=n)?r:u(n)):e},o}function r(n){return e.parse(n.responseText)}function u(n){return function(t){return e.parse(t.responseText,n)}}function i(t){return t.map(o).join(n)}function o(n){return a.test(n)?'"'+n.replace(/\"/g,'""')+'"':n}var a=new RegExp('["'+n+"\n]"),c=n.charCodeAt(0);return e.parse=function(n,t){var r;return e.parseRows(n,function(n,e){if(r)return r(n,e-1);var u=new Function("d","return {"+n.map(function(n,t){return JSON.stringify(n)+": d["+t+"]"}).join(",")+"}");r=t?function(n,e){return t(u(n),e)}:u})},e.parseRows=function(n,t){function e(){if(s>=l)return o;if(u)return u=!1,i;var t=s;if(34===n.charCodeAt(t)){for(var e=t;e++s;){var r=n.charCodeAt(s++),a=1;if(10===r)u=!0;else if(13===r)u=!0,10===n.charCodeAt(s)&&(++s,++a);else if(r!==c)continue;return n.slice(t,s-a)}return n.slice(t)}for(var r,u,i={},o={},a=[],l=n.length,s=0,f=0;(r=e())!==o;){for(var h=[];r!==i&&r!==o;)h.push(r),r=e();t&&null==(h=t(h,f++))||a.push(h)}return a},e.format=function(t){if(Array.isArray(t[0]))return e.formatRows(t);var r=new m,u=[];return t.forEach(function(n){for(var t in n)r.has(t)||u.push(r.add(t))}),[u.map(o).join(n)].concat(t.map(function(t){return u.map(function(n){return o(t[n])}).join(n)})).join("\n")},e.formatRows=function(n){return n.map(i).join("\n")},e},ta.csv=ta.dsv(",","text/csv"),ta.tsv=ta.dsv(" ","text/tab-separated-values");var Ka,Qa,nc,tc,ec,rc=this[x(this,"requestAnimationFrame")]||function(n){setTimeout(n,17)};ta.timer=function(n,t,e){var r=arguments.length;2>r&&(t=0),3>r&&(e=Date.now());var u=e+t,i={c:n,t:u,f:!1,n:null};Qa?Qa.n=i:Ka=i,Qa=i,nc||(tc=clearTimeout(tc),nc=1,rc(qt))},ta.timer.flush=function(){Lt(),Tt()},ta.round=function(n,t){return t?Math.round(n*(t=Math.pow(10,t)))/t:Math.round(n)};var uc=["y","z","a","f","p","n","\xb5","m","","k","M","G","T","P","E","Z","Y"].map(Dt);ta.formatPrefix=function(n,t){var e=0;return n&&(0>n&&(n*=-1),t&&(n=ta.round(n,Rt(n,t))),e=1+Math.floor(1e-12+Math.log(n)/Math.LN10),e=Math.max(-24,Math.min(24,3*Math.floor((e-1)/3)))),uc[8+e/3]};var ic=/(?:([^{])?([<>=^]))?([+\- ])?([$#])?(0)?(\d+)?(,)?(\.-?\d+)?([a-z%])?/i,oc=ta.map({b:function(n){return n.toString(2)},c:function(n){return String.fromCharCode(n)},o:function(n){return n.toString(8)},x:function(n){return n.toString(16)},X:function(n){return n.toString(16).toUpperCase()},g:function(n,t){return n.toPrecision(t)},e:function(n,t){return n.toExponential(t)},f:function(n,t){return n.toFixed(t)},r:function(n,t){return(n=ta.round(n,Rt(n,t))).toFixed(Math.max(0,Math.min(20,Rt(n*(1+1e-15),t))))}}),ac=ta.time={},cc=Date;jt.prototype={getDate:function(){return this._.getUTCDate()},getDay:function(){return this._.getUTCDay()},getFullYear:function(){return this._.getUTCFullYear()},getHours:function(){return this._.getUTCHours()},getMilliseconds:function(){return this._.getUTCMilliseconds()},getMinutes:function(){return this._.getUTCMinutes()},getMonth:function(){return this._.getUTCMonth()},getSeconds:function(){return this._.getUTCSeconds()},getTime:function(){return this._.getTime()},getTimezoneOffset:function(){return 0},valueOf:function(){return this._.valueOf()},setDate:function(){lc.setUTCDate.apply(this._,arguments)},setDay:function(){lc.setUTCDay.apply(this._,arguments)},setFullYear:function(){lc.setUTCFullYear.apply(this._,arguments)},setHours:function(){lc.setUTCHours.apply(this._,arguments)},setMilliseconds:function(){lc.setUTCMilliseconds.apply(this._,arguments)},setMinutes:function(){lc.setUTCMinutes.apply(this._,arguments)},setMonth:function(){lc.setUTCMonth.apply(this._,arguments)},setSeconds:function(){lc.setUTCSeconds.apply(this._,arguments)},setTime:function(){lc.setTime.apply(this._,arguments)}};var lc=Date.prototype;ac.year=Ft(function(n){return n=ac.day(n),n.setMonth(0,1),n},function(n,t){n.setFullYear(n.getFullYear()+t)},function(n){return n.getFullYear()}),ac.years=ac.year.range,ac.years.utc=ac.year.utc.range,ac.day=Ft(function(n){var t=new cc(2e3,0);return t.setFullYear(n.getFullYear(),n.getMonth(),n.getDate()),t},function(n,t){n.setDate(n.getDate()+t)},function(n){return n.getDate()-1}),ac.days=ac.day.range,ac.days.utc=ac.day.utc.range,ac.dayOfYear=function(n){var t=ac.year(n);return Math.floor((n-t-6e4*(n.getTimezoneOffset()-t.getTimezoneOffset()))/864e5)},["sunday","monday","tuesday","wednesday","thursday","friday","saturday"].forEach(function(n,t){t=7-t;var e=ac[n]=Ft(function(n){return(n=ac.day(n)).setDate(n.getDate()-(n.getDay()+t)%7),n},function(n,t){n.setDate(n.getDate()+7*Math.floor(t))},function(n){var e=ac.year(n).getDay();return Math.floor((ac.dayOfYear(n)+(e+t)%7)/7)-(e!==t)});ac[n+"s"]=e.range,ac[n+"s"].utc=e.utc.range,ac[n+"OfYear"]=function(n){var e=ac.year(n).getDay();return Math.floor((ac.dayOfYear(n)+(e+t)%7)/7)}}),ac.week=ac.sunday,ac.weeks=ac.sunday.range,ac.weeks.utc=ac.sunday.utc.range,ac.weekOfYear=ac.sundayOfYear;var sc={"-":"",_:" ",0:"0"},fc=/^\s*\d+/,hc=/^%/;ta.locale=function(n){return{numberFormat:Pt(n),timeFormat:Ot(n)}};var gc=ta.locale({decimal:".",thousands:",",grouping:[3],currency:["$",""],dateTime:"%a %b %e %X %Y",date:"%m/%d/%Y",time:"%H:%M:%S",periods:["AM","PM"],days:["Sunday","Monday","Tuesday","Wednesday","Thursday","Friday","Saturday"],shortDays:["Sun","Mon","Tue","Wed","Thu","Fri","Sat"],months:["January","February","March","April","May","June","July","August","September","October","November","December"],shortMonths:["Jan","Feb","Mar","Apr","May","Jun","Jul","Aug","Sep","Oct","Nov","Dec"]});ta.format=gc.numberFormat,ta.geo={},ce.prototype={s:0,t:0,add:function(n){le(n,this.t,pc),le(pc.s,this.s,this),this.s?this.t+=pc.t:this.s=pc.t +},reset:function(){this.s=this.t=0},valueOf:function(){return this.s}};var pc=new ce;ta.geo.stream=function(n,t){n&&vc.hasOwnProperty(n.type)?vc[n.type](n,t):se(n,t)};var vc={Feature:function(n,t){se(n.geometry,t)},FeatureCollection:function(n,t){for(var e=n.features,r=-1,u=e.length;++rn?4*qa+n:n,Mc.lineStart=Mc.lineEnd=Mc.point=b}};ta.geo.bounds=function(){function n(n,t){M.push(x=[s=n,h=n]),f>t&&(f=t),t>g&&(g=t)}function t(t,e){var r=pe([t*Da,e*Da]);if(m){var u=de(m,r),i=[u[1],-u[0],0],o=de(i,u);Me(o),o=xe(o);var c=t-p,l=c>0?1:-1,v=o[0]*Pa*l,d=ga(c)>180;if(d^(v>l*p&&l*t>v)){var y=o[1]*Pa;y>g&&(g=y)}else if(v=(v+360)%360-180,d^(v>l*p&&l*t>v)){var y=-o[1]*Pa;f>y&&(f=y)}else f>e&&(f=e),e>g&&(g=e);d?p>t?a(s,t)>a(s,h)&&(h=t):a(t,h)>a(s,h)&&(s=t):h>=s?(s>t&&(s=t),t>h&&(h=t)):t>p?a(s,t)>a(s,h)&&(h=t):a(t,h)>a(s,h)&&(s=t)}else n(t,e);m=r,p=t}function e(){b.point=t}function r(){x[0]=s,x[1]=h,b.point=n,m=null}function u(n,e){if(m){var r=n-p;y+=ga(r)>180?r+(r>0?360:-360):r}else v=n,d=e;Mc.point(n,e),t(n,e)}function i(){Mc.lineStart()}function o(){u(v,d),Mc.lineEnd(),ga(y)>Ca&&(s=-(h=180)),x[0]=s,x[1]=h,m=null}function a(n,t){return(t-=n)<0?t+360:t}function c(n,t){return n[0]-t[0]}function l(n,t){return t[0]<=t[1]?t[0]<=n&&n<=t[1]:nyc?(s=-(h=180),f=-(g=90)):y>Ca?g=90:-Ca>y&&(f=-90),x[0]=s,x[1]=h}};return function(n){g=h=-(s=f=1/0),M=[],ta.geo.stream(n,b);var t=M.length;if(t){M.sort(c);for(var e,r=1,u=M[0],i=[u];t>r;++r)e=M[r],l(e[0],u)||l(e[1],u)?(a(u[0],e[1])>a(u[0],u[1])&&(u[1]=e[1]),a(e[0],u[1])>a(u[0],u[1])&&(u[0]=e[0])):i.push(u=e);for(var o,e,p=-1/0,t=i.length-1,r=0,u=i[t];t>=r;u=e,++r)e=i[r],(o=a(u[1],e[0]))>p&&(p=o,s=e[0],h=u[1])}return M=x=null,1/0===s||1/0===f?[[0/0,0/0],[0/0,0/0]]:[[s,f],[h,g]]}}(),ta.geo.centroid=function(n){xc=bc=_c=wc=Sc=kc=Ec=Ac=Nc=Cc=zc=0,ta.geo.stream(n,qc);var t=Nc,e=Cc,r=zc,u=t*t+e*e+r*r;return za>u&&(t=kc,e=Ec,r=Ac,Ca>bc&&(t=_c,e=wc,r=Sc),u=t*t+e*e+r*r,za>u)?[0/0,0/0]:[Math.atan2(e,t)*Pa,tt(r/Math.sqrt(u))*Pa]};var xc,bc,_c,wc,Sc,kc,Ec,Ac,Nc,Cc,zc,qc={sphere:b,point:_e,lineStart:Se,lineEnd:ke,polygonStart:function(){qc.lineStart=Ee},polygonEnd:function(){qc.lineStart=Se}},Lc=Le(Ne,Pe,je,[-qa,-qa/2]),Tc=1e9;ta.geo.clipExtent=function(){var n,t,e,r,u,i,o={stream:function(n){return u&&(u.valid=!1),u=i(n),u.valid=!0,u},extent:function(a){return arguments.length?(i=Ie(n=+a[0][0],t=+a[0][1],e=+a[1][0],r=+a[1][1]),u&&(u.valid=!1,u=null),o):[[n,t],[e,r]]}};return o.extent([[0,0],[960,500]])},(ta.geo.conicEqualArea=function(){return Ye(Ze)}).raw=Ze,ta.geo.albers=function(){return ta.geo.conicEqualArea().rotate([96,0]).center([-.6,38.7]).parallels([29.5,45.5]).scale(1070)},ta.geo.albersUsa=function(){function n(n){var i=n[0],o=n[1];return t=null,e(i,o),t||(r(i,o),t)||u(i,o),t}var t,e,r,u,i=ta.geo.albers(),o=ta.geo.conicEqualArea().rotate([154,0]).center([-2,58.5]).parallels([55,65]),a=ta.geo.conicEqualArea().rotate([157,0]).center([-3,19.9]).parallels([8,18]),c={point:function(n,e){t=[n,e]}};return n.invert=function(n){var t=i.scale(),e=i.translate(),r=(n[0]-e[0])/t,u=(n[1]-e[1])/t;return(u>=.12&&.234>u&&r>=-.425&&-.214>r?o:u>=.166&&.234>u&&r>=-.214&&-.115>r?a:i).invert(n)},n.stream=function(n){var t=i.stream(n),e=o.stream(n),r=a.stream(n);return{point:function(n,u){t.point(n,u),e.point(n,u),r.point(n,u)},sphere:function(){t.sphere(),e.sphere(),r.sphere()},lineStart:function(){t.lineStart(),e.lineStart(),r.lineStart()},lineEnd:function(){t.lineEnd(),e.lineEnd(),r.lineEnd()},polygonStart:function(){t.polygonStart(),e.polygonStart(),r.polygonStart()},polygonEnd:function(){t.polygonEnd(),e.polygonEnd(),r.polygonEnd()}}},n.precision=function(t){return arguments.length?(i.precision(t),o.precision(t),a.precision(t),n):i.precision()},n.scale=function(t){return arguments.length?(i.scale(t),o.scale(.35*t),a.scale(t),n.translate(i.translate())):i.scale()},n.translate=function(t){if(!arguments.length)return i.translate();var l=i.scale(),s=+t[0],f=+t[1];return e=i.translate(t).clipExtent([[s-.455*l,f-.238*l],[s+.455*l,f+.238*l]]).stream(c).point,r=o.translate([s-.307*l,f+.201*l]).clipExtent([[s-.425*l+Ca,f+.12*l+Ca],[s-.214*l-Ca,f+.234*l-Ca]]).stream(c).point,u=a.translate([s-.205*l,f+.212*l]).clipExtent([[s-.214*l+Ca,f+.166*l+Ca],[s-.115*l-Ca,f+.234*l-Ca]]).stream(c).point,n},n.scale(1070)};var Rc,Dc,Pc,Uc,jc,Fc,Hc={point:b,lineStart:b,lineEnd:b,polygonStart:function(){Dc=0,Hc.lineStart=Ve},polygonEnd:function(){Hc.lineStart=Hc.lineEnd=Hc.point=b,Rc+=ga(Dc/2)}},Oc={point:Xe,lineStart:b,lineEnd:b,polygonStart:b,polygonEnd:b},Ic={point:We,lineStart:Je,lineEnd:Ge,polygonStart:function(){Ic.lineStart=Ke},polygonEnd:function(){Ic.point=We,Ic.lineStart=Je,Ic.lineEnd=Ge}};ta.geo.path=function(){function n(n){return n&&("function"==typeof a&&i.pointRadius(+a.apply(this,arguments)),o&&o.valid||(o=u(i)),ta.geo.stream(n,o)),i.result()}function t(){return o=null,n}var e,r,u,i,o,a=4.5;return n.area=function(n){return Rc=0,ta.geo.stream(n,u(Hc)),Rc},n.centroid=function(n){return _c=wc=Sc=kc=Ec=Ac=Nc=Cc=zc=0,ta.geo.stream(n,u(Ic)),zc?[Nc/zc,Cc/zc]:Ac?[kc/Ac,Ec/Ac]:Sc?[_c/Sc,wc/Sc]:[0/0,0/0]},n.bounds=function(n){return jc=Fc=-(Pc=Uc=1/0),ta.geo.stream(n,u(Oc)),[[Pc,Uc],[jc,Fc]]},n.projection=function(n){return arguments.length?(u=(e=n)?n.stream||tr(n):y,t()):e},n.context=function(n){return arguments.length?(i=null==(r=n)?new $e:new Qe(n),"function"!=typeof a&&i.pointRadius(a),t()):r},n.pointRadius=function(t){return arguments.length?(a="function"==typeof t?t:(i.pointRadius(+t),+t),n):a},n.projection(ta.geo.albersUsa()).context(null)},ta.geo.transform=function(n){return{stream:function(t){var e=new er(t);for(var r in n)e[r]=n[r];return e}}},er.prototype={point:function(n,t){this.stream.point(n,t)},sphere:function(){this.stream.sphere()},lineStart:function(){this.stream.lineStart()},lineEnd:function(){this.stream.lineEnd()},polygonStart:function(){this.stream.polygonStart()},polygonEnd:function(){this.stream.polygonEnd()}},ta.geo.projection=ur,ta.geo.projectionMutator=ir,(ta.geo.equirectangular=function(){return ur(ar)}).raw=ar.invert=ar,ta.geo.rotation=function(n){function t(t){return t=n(t[0]*Da,t[1]*Da),t[0]*=Pa,t[1]*=Pa,t}return n=lr(n[0]%360*Da,n[1]*Da,n.length>2?n[2]*Da:0),t.invert=function(t){return t=n.invert(t[0]*Da,t[1]*Da),t[0]*=Pa,t[1]*=Pa,t},t},cr.invert=ar,ta.geo.circle=function(){function n(){var n="function"==typeof r?r.apply(this,arguments):r,t=lr(-n[0]*Da,-n[1]*Da,0).invert,u=[];return e(null,null,1,{point:function(n,e){u.push(n=t(n,e)),n[0]*=Pa,n[1]*=Pa}}),{type:"Polygon",coordinates:[u]}}var t,e,r=[0,0],u=6;return n.origin=function(t){return arguments.length?(r=t,n):r},n.angle=function(r){return arguments.length?(e=gr((t=+r)*Da,u*Da),n):t},n.precision=function(r){return arguments.length?(e=gr(t*Da,(u=+r)*Da),n):u},n.angle(90)},ta.geo.distance=function(n,t){var e,r=(t[0]-n[0])*Da,u=n[1]*Da,i=t[1]*Da,o=Math.sin(r),a=Math.cos(r),c=Math.sin(u),l=Math.cos(u),s=Math.sin(i),f=Math.cos(i);return Math.atan2(Math.sqrt((e=f*o)*e+(e=l*s-c*f*a)*e),c*s+l*f*a)},ta.geo.graticule=function(){function n(){return{type:"MultiLineString",coordinates:t()}}function t(){return ta.range(Math.ceil(i/d)*d,u,d).map(h).concat(ta.range(Math.ceil(l/m)*m,c,m).map(g)).concat(ta.range(Math.ceil(r/p)*p,e,p).filter(function(n){return ga(n%d)>Ca}).map(s)).concat(ta.range(Math.ceil(a/v)*v,o,v).filter(function(n){return ga(n%m)>Ca}).map(f))}var e,r,u,i,o,a,c,l,s,f,h,g,p=10,v=p,d=90,m=360,y=2.5;return n.lines=function(){return t().map(function(n){return{type:"LineString",coordinates:n}})},n.outline=function(){return{type:"Polygon",coordinates:[h(i).concat(g(c).slice(1),h(u).reverse().slice(1),g(l).reverse().slice(1))]}},n.extent=function(t){return arguments.length?n.majorExtent(t).minorExtent(t):n.minorExtent()},n.majorExtent=function(t){return arguments.length?(i=+t[0][0],u=+t[1][0],l=+t[0][1],c=+t[1][1],i>u&&(t=i,i=u,u=t),l>c&&(t=l,l=c,c=t),n.precision(y)):[[i,l],[u,c]]},n.minorExtent=function(t){return arguments.length?(r=+t[0][0],e=+t[1][0],a=+t[0][1],o=+t[1][1],r>e&&(t=r,r=e,e=t),a>o&&(t=a,a=o,o=t),n.precision(y)):[[r,a],[e,o]]},n.step=function(t){return arguments.length?n.majorStep(t).minorStep(t):n.minorStep()},n.majorStep=function(t){return arguments.length?(d=+t[0],m=+t[1],n):[d,m]},n.minorStep=function(t){return arguments.length?(p=+t[0],v=+t[1],n):[p,v]},n.precision=function(t){return arguments.length?(y=+t,s=vr(a,o,90),f=dr(r,e,y),h=vr(l,c,90),g=dr(i,u,y),n):y},n.majorExtent([[-180,-90+Ca],[180,90-Ca]]).minorExtent([[-180,-80-Ca],[180,80+Ca]])},ta.geo.greatArc=function(){function n(){return{type:"LineString",coordinates:[t||r.apply(this,arguments),e||u.apply(this,arguments)]}}var t,e,r=mr,u=yr;return n.distance=function(){return ta.geo.distance(t||r.apply(this,arguments),e||u.apply(this,arguments))},n.source=function(e){return arguments.length?(r=e,t="function"==typeof e?null:e,n):r},n.target=function(t){return arguments.length?(u=t,e="function"==typeof t?null:t,n):u},n.precision=function(){return arguments.length?n:0},n},ta.geo.interpolate=function(n,t){return Mr(n[0]*Da,n[1]*Da,t[0]*Da,t[1]*Da)},ta.geo.length=function(n){return Yc=0,ta.geo.stream(n,Zc),Yc};var Yc,Zc={sphere:b,point:b,lineStart:xr,lineEnd:b,polygonStart:b,polygonEnd:b},Vc=br(function(n){return Math.sqrt(2/(1+n))},function(n){return 2*Math.asin(n/2)});(ta.geo.azimuthalEqualArea=function(){return ur(Vc)}).raw=Vc;var Xc=br(function(n){var t=Math.acos(n);return t&&t/Math.sin(t)},y);(ta.geo.azimuthalEquidistant=function(){return ur(Xc)}).raw=Xc,(ta.geo.conicConformal=function(){return Ye(_r)}).raw=_r,(ta.geo.conicEquidistant=function(){return Ye(wr)}).raw=wr;var $c=br(function(n){return 1/n},Math.atan);(ta.geo.gnomonic=function(){return ur($c)}).raw=$c,Sr.invert=function(n,t){return[n,2*Math.atan(Math.exp(t))-Ra]},(ta.geo.mercator=function(){return kr(Sr)}).raw=Sr;var Bc=br(function(){return 1},Math.asin);(ta.geo.orthographic=function(){return ur(Bc)}).raw=Bc;var Wc=br(function(n){return 1/(1+n)},function(n){return 2*Math.atan(n)});(ta.geo.stereographic=function(){return ur(Wc)}).raw=Wc,Er.invert=function(n,t){return[-t,2*Math.atan(Math.exp(n))-Ra]},(ta.geo.transverseMercator=function(){var n=kr(Er),t=n.center,e=n.rotate;return n.center=function(n){return n?t([-n[1],n[0]]):(n=t(),[n[1],-n[0]])},n.rotate=function(n){return n?e([n[0],n[1],n.length>2?n[2]+90:90]):(n=e(),[n[0],n[1],n[2]-90])},e([0,0,90])}).raw=Er,ta.geom={},ta.geom.hull=function(n){function t(n){if(n.length<3)return[];var t,u=Et(e),i=Et(r),o=n.length,a=[],c=[];for(t=0;o>t;t++)a.push([+u.call(this,n[t],t),+i.call(this,n[t],t),t]);for(a.sort(zr),t=0;o>t;t++)c.push([a[t][0],-a[t][1]]);var l=Cr(a),s=Cr(c),f=s[0]===l[0],h=s[s.length-1]===l[l.length-1],g=[];for(t=l.length-1;t>=0;--t)g.push(n[a[l[t]][2]]);for(t=+f;t=r&&l.x<=i&&l.y>=u&&l.y<=o?[[r,o],[i,o],[i,u],[r,u]]:[];s.point=n[a]}),t}function e(n){return n.map(function(n,t){return{x:Math.round(i(n,t)/Ca)*Ca,y:Math.round(o(n,t)/Ca)*Ca,i:t}})}var r=Ar,u=Nr,i=r,o=u,a=ul;return n?t(n):(t.links=function(n){return iu(e(n)).edges.filter(function(n){return n.l&&n.r}).map(function(t){return{source:n[t.l.i],target:n[t.r.i]}})},t.triangles=function(n){var t=[];return iu(e(n)).cells.forEach(function(e,r){for(var u,i,o=e.site,a=e.edges.sort(Yr),c=-1,l=a.length,s=a[l-1].edge,f=s.l===o?s.r:s.l;++c=l,h=r>=s,g=h<<1|f;n.leaf=!1,n=n.nodes[g]||(n.nodes[g]=su()),f?u=l:a=l,h?o=s:c=s,i(n,t,e,r,u,o,a,c)}var s,f,h,g,p,v,d,m,y,M=Et(a),x=Et(c);if(null!=t)v=t,d=e,m=r,y=u;else if(m=y=-(v=d=1/0),f=[],h=[],p=n.length,o)for(g=0;p>g;++g)s=n[g],s.xm&&(m=s.x),s.y>y&&(y=s.y),f.push(s.x),h.push(s.y);else for(g=0;p>g;++g){var b=+M(s=n[g],g),_=+x(s,g);v>b&&(v=b),d>_&&(d=_),b>m&&(m=b),_>y&&(y=_),f.push(b),h.push(_)}var w=m-v,S=y-d;w>S?y=d+w:m=v+S;var k=su();if(k.add=function(n){i(k,n,+M(n,++g),+x(n,g),v,d,m,y)},k.visit=function(n){fu(n,k,v,d,m,y)},k.find=function(n){return hu(k,n[0],n[1],v,d,m,y)},g=-1,null==t){for(;++g=0?n.slice(0,t):n,r=t>=0?n.slice(t+1):"in";return e=cl.get(e)||al,r=ll.get(r)||y,Mu(r(e.apply(null,ea.call(arguments,1))))},ta.interpolateHcl=Lu,ta.interpolateHsl=Tu,ta.interpolateLab=Ru,ta.interpolateRound=Du,ta.transform=function(n){var t=ua.createElementNS(ta.ns.prefix.svg,"g");return(ta.transform=function(n){if(null!=n){t.setAttribute("transform",n);var e=t.transform.baseVal.consolidate()}return new Pu(e?e.matrix:sl)})(n)},Pu.prototype.toString=function(){return"translate("+this.translate+")rotate("+this.rotate+")skewX("+this.skew+")scale("+this.scale+")"};var sl={a:1,b:0,c:0,d:1,e:0,f:0};ta.interpolateTransform=Hu,ta.layout={},ta.layout.bundle=function(){return function(n){for(var t=[],e=-1,r=n.length;++ea*a/d){if(p>c){var l=t.charge/c;n.px-=i*l,n.py-=o*l}return!0}if(t.point&&c&&p>c){var l=t.pointCharge/c;n.px-=i*l,n.py-=o*l}}return!t.charge}}function t(n){n.px=ta.event.x,n.py=ta.event.y,a.resume()}var e,r,u,i,o,a={},c=ta.dispatch("start","tick","end"),l=[1,1],s=.9,f=fl,h=hl,g=-30,p=gl,v=.1,d=.64,m=[],M=[];return a.tick=function(){if((r*=.99)<.005)return c.end({type:"end",alpha:r=0}),!0;var t,e,a,f,h,p,d,y,x,b=m.length,_=M.length;for(e=0;_>e;++e)a=M[e],f=a.source,h=a.target,y=h.x-f.x,x=h.y-f.y,(p=y*y+x*x)&&(p=r*i[e]*((p=Math.sqrt(p))-u[e])/p,y*=p,x*=p,h.x-=y*(d=f.weight/(h.weight+f.weight)),h.y-=x*d,f.x+=y*(d=1-d),f.y+=x*d);if((d=r*v)&&(y=l[0]/2,x=l[1]/2,e=-1,d))for(;++e0?n:0:n>0&&(c.start({type:"start",alpha:r=n}),ta.timer(a.tick)),a):r},a.start=function(){function n(n,r){if(!e){for(e=new Array(c),a=0;c>a;++a)e[a]=[];for(a=0;s>a;++a){var u=M[a];e[u.source.index].push(u.target),e[u.target.index].push(u.source)}}for(var i,o=e[t],a=-1,l=o.length;++at;++t)(r=m[t]).index=t,r.weight=0;for(t=0;s>t;++t)r=M[t],"number"==typeof r.source&&(r.source=m[r.source]),"number"==typeof r.target&&(r.target=m[r.target]),++r.source.weight,++r.target.weight;for(t=0;c>t;++t)r=m[t],isNaN(r.x)&&(r.x=n("x",p)),isNaN(r.y)&&(r.y=n("y",v)),isNaN(r.px)&&(r.px=r.x),isNaN(r.py)&&(r.py=r.y);if(u=[],"function"==typeof f)for(t=0;s>t;++t)u[t]=+f.call(this,M[t],t);else for(t=0;s>t;++t)u[t]=f;if(i=[],"function"==typeof h)for(t=0;s>t;++t)i[t]=+h.call(this,M[t],t);else for(t=0;s>t;++t)i[t]=h;if(o=[],"function"==typeof g)for(t=0;c>t;++t)o[t]=+g.call(this,m[t],t);else for(t=0;c>t;++t)o[t]=g;return a.resume()},a.resume=function(){return a.alpha(.1)},a.stop=function(){return a.alpha(0)},a.drag=function(){return e||(e=ta.behavior.drag().origin(y).on("dragstart.force",Xu).on("drag.force",t).on("dragend.force",$u)),arguments.length?void this.on("mouseover.force",Bu).on("mouseout.force",Wu).call(e):e},ta.rebind(a,c,"on")};var fl=20,hl=1,gl=1/0;ta.layout.hierarchy=function(){function n(u){var i,o=[u],a=[];for(u.depth=0;null!=(i=o.pop());)if(a.push(i),(l=e.call(n,i,i.depth))&&(c=l.length)){for(var c,l,s;--c>=0;)o.push(s=l[c]),s.parent=i,s.depth=i.depth+1;r&&(i.value=0),i.children=l}else r&&(i.value=+r.call(n,i,i.depth)||0),delete i.children;return Qu(u,function(n){var e,u;t&&(e=n.children)&&e.sort(t),r&&(u=n.parent)&&(u.value+=n.value)}),a}var t=ei,e=ni,r=ti;return n.sort=function(e){return arguments.length?(t=e,n):t},n.children=function(t){return arguments.length?(e=t,n):e},n.value=function(t){return arguments.length?(r=t,n):r},n.revalue=function(t){return r&&(Ku(t,function(n){n.children&&(n.value=0)}),Qu(t,function(t){var e;t.children||(t.value=+r.call(n,t,t.depth)||0),(e=t.parent)&&(e.value+=t.value)})),t},n},ta.layout.partition=function(){function n(t,e,r,u){var i=t.children;if(t.x=e,t.y=t.depth*u,t.dx=r,t.dy=u,i&&(o=i.length)){var o,a,c,l=-1;for(r=t.value?r/t.value:0;++lf?-1:1),p=(f-c*g)/ta.sum(l),v=ta.range(c),d=[];return null!=e&&v.sort(e===pl?function(n,t){return l[t]-l[n]}:function(n,t){return e(o[n],o[t])}),v.forEach(function(n){d[n]={data:o[n],value:a=l[n],startAngle:s,endAngle:s+=a*p+g,padAngle:h}}),d}var t=Number,e=pl,r=0,u=La,i=0;return n.value=function(e){return arguments.length?(t=e,n):t},n.sort=function(t){return arguments.length?(e=t,n):e},n.startAngle=function(t){return arguments.length?(r=t,n):r},n.endAngle=function(t){return arguments.length?(u=t,n):u},n.padAngle=function(t){return arguments.length?(i=t,n):i},n};var pl={};ta.layout.stack=function(){function n(a,c){if(!(h=a.length))return a;var l=a.map(function(e,r){return t.call(n,e,r)}),s=l.map(function(t){return t.map(function(t,e){return[i.call(n,t,e),o.call(n,t,e)]})}),f=e.call(n,s,c);l=ta.permute(l,f),s=ta.permute(s,f);var h,g,p,v,d=r.call(n,s,c),m=l[0].length;for(p=0;m>p;++p)for(u.call(n,l[0][p],v=d[p],s[0][p][1]),g=1;h>g;++g)u.call(n,l[g][p],v+=s[g-1][p][1],s[g][p][1]);return a}var t=y,e=ai,r=ci,u=oi,i=ui,o=ii;return n.values=function(e){return arguments.length?(t=e,n):t},n.order=function(t){return arguments.length?(e="function"==typeof t?t:vl.get(t)||ai,n):e},n.offset=function(t){return arguments.length?(r="function"==typeof t?t:dl.get(t)||ci,n):r},n.x=function(t){return arguments.length?(i=t,n):i},n.y=function(t){return arguments.length?(o=t,n):o},n.out=function(t){return arguments.length?(u=t,n):u},n};var vl=ta.map({"inside-out":function(n){var t,e,r=n.length,u=n.map(li),i=n.map(si),o=ta.range(r).sort(function(n,t){return u[n]-u[t]}),a=0,c=0,l=[],s=[];for(t=0;r>t;++t)e=o[t],c>a?(a+=i[e],l.push(e)):(c+=i[e],s.push(e));return s.reverse().concat(l)},reverse:function(n){return ta.range(n.length).reverse()},"default":ai}),dl=ta.map({silhouette:function(n){var t,e,r,u=n.length,i=n[0].length,o=[],a=0,c=[];for(e=0;i>e;++e){for(t=0,r=0;u>t;t++)r+=n[t][e][1];r>a&&(a=r),o.push(r)}for(e=0;i>e;++e)c[e]=(a-o[e])/2;return c},wiggle:function(n){var t,e,r,u,i,o,a,c,l,s=n.length,f=n[0],h=f.length,g=[];for(g[0]=c=l=0,e=1;h>e;++e){for(t=0,u=0;s>t;++t)u+=n[t][e][1];for(t=0,i=0,a=f[e][0]-f[e-1][0];s>t;++t){for(r=0,o=(n[t][e][1]-n[t][e-1][1])/(2*a);t>r;++r)o+=(n[r][e][1]-n[r][e-1][1])/a;i+=o*n[t][e][1]}g[e]=c-=u?i/u*a:0,l>c&&(l=c)}for(e=0;h>e;++e)g[e]-=l;return g},expand:function(n){var t,e,r,u=n.length,i=n[0].length,o=1/u,a=[];for(e=0;i>e;++e){for(t=0,r=0;u>t;t++)r+=n[t][e][1];if(r)for(t=0;u>t;t++)n[t][e][1]/=r;else for(t=0;u>t;t++)n[t][e][1]=o}for(e=0;i>e;++e)a[e]=0;return a},zero:ci});ta.layout.histogram=function(){function n(n,i){for(var o,a,c=[],l=n.map(e,this),s=r.call(this,l,i),f=u.call(this,s,l,i),i=-1,h=l.length,g=f.length-1,p=t?1:1/h;++i0)for(i=-1;++i=s[0]&&a<=s[1]&&(o=c[ta.bisect(f,a,1,g)-1],o.y+=p,o.push(n[i]));return c}var t=!0,e=Number,r=pi,u=hi;return n.value=function(t){return arguments.length?(e=t,n):e},n.range=function(t){return arguments.length?(r=Et(t),n):r},n.bins=function(t){return arguments.length?(u="number"==typeof t?function(n){return gi(n,t)}:Et(t),n):u},n.frequency=function(e){return arguments.length?(t=!!e,n):t},n},ta.layout.pack=function(){function n(n,i){var o=e.call(this,n,i),a=o[0],c=u[0],l=u[1],s=null==t?Math.sqrt:"function"==typeof t?t:function(){return t};if(a.x=a.y=0,Qu(a,function(n){n.r=+s(n.value)}),Qu(a,Mi),r){var f=r*(t?1:Math.max(2*a.r/c,2*a.r/l))/2;Qu(a,function(n){n.r+=f}),Qu(a,Mi),Qu(a,function(n){n.r-=f})}return _i(a,c/2,l/2,t?1:1/Math.max(2*a.r/c,2*a.r/l)),o}var t,e=ta.layout.hierarchy().sort(vi),r=0,u=[1,1];return n.size=function(t){return arguments.length?(u=t,n):u},n.radius=function(e){return arguments.length?(t=null==e||"function"==typeof e?e:+e,n):t},n.padding=function(t){return arguments.length?(r=+t,n):r},Gu(n,e)},ta.layout.tree=function(){function n(n,u){var s=o.call(this,n,u),f=s[0],h=t(f);if(Qu(h,e),h.parent.m=-h.z,Ku(h,r),l)Ku(f,i);else{var g=f,p=f,v=f;Ku(f,function(n){n.xp.x&&(p=n),n.depth>v.depth&&(v=n)});var d=a(g,p)/2-g.x,m=c[0]/(p.x+a(p,g)/2+d),y=c[1]/(v.depth||1);Ku(f,function(n){n.x=(n.x+d)*m,n.y=n.depth*y})}return s}function t(n){for(var t,e={A:null,children:[n]},r=[e];null!=(t=r.pop());)for(var u,i=t.children,o=0,a=i.length;a>o;++o)r.push((i[o]=u={_:i[o],parent:t,children:(u=i[o].children)&&u.slice()||[],A:null,a:null,z:0,m:0,c:0,s:0,t:null,i:o}).a=u);return e.children[0]}function e(n){var t=n.children,e=n.parent.children,r=n.i?e[n.i-1]:null;if(t.length){Ni(n);var i=(t[0].z+t[t.length-1].z)/2;r?(n.z=r.z+a(n._,r._),n.m=n.z-i):n.z=i}else r&&(n.z=r.z+a(n._,r._));n.parent.A=u(n,r,n.parent.A||e[0])}function r(n){n._.x=n.z+n.parent.m,n.m+=n.parent.m}function u(n,t,e){if(t){for(var r,u=n,i=n,o=t,c=u.parent.children[0],l=u.m,s=i.m,f=o.m,h=c.m;o=Ei(o),u=ki(u),o&&u;)c=ki(c),i=Ei(i),i.a=n,r=o.z+f-u.z-l+a(o._,u._),r>0&&(Ai(Ci(o,n,e),n,r),l+=r,s+=r),f+=o.m,l+=u.m,h+=c.m,s+=i.m;o&&!Ei(i)&&(i.t=o,i.m+=f-s),u&&!ki(c)&&(c.t=u,c.m+=l-h,e=n)}return e}function i(n){n.x*=c[0],n.y=n.depth*c[1]}var o=ta.layout.hierarchy().sort(null).value(null),a=Si,c=[1,1],l=null;return n.separation=function(t){return arguments.length?(a=t,n):a},n.size=function(t){return arguments.length?(l=null==(c=t)?i:null,n):l?null:c},n.nodeSize=function(t){return arguments.length?(l=null==(c=t)?null:i,n):l?c:null},Gu(n,o)},ta.layout.cluster=function(){function n(n,i){var o,a=t.call(this,n,i),c=a[0],l=0;Qu(c,function(n){var t=n.children;t&&t.length?(n.x=qi(t),n.y=zi(t)):(n.x=o?l+=e(n,o):0,n.y=0,o=n)});var s=Li(c),f=Ti(c),h=s.x-e(s,f)/2,g=f.x+e(f,s)/2;return Qu(c,u?function(n){n.x=(n.x-c.x)*r[0],n.y=(c.y-n.y)*r[1]}:function(n){n.x=(n.x-h)/(g-h)*r[0],n.y=(1-(c.y?n.y/c.y:1))*r[1]}),a}var t=ta.layout.hierarchy().sort(null).value(null),e=Si,r=[1,1],u=!1;return n.separation=function(t){return arguments.length?(e=t,n):e},n.size=function(t){return arguments.length?(u=null==(r=t),n):u?null:r},n.nodeSize=function(t){return arguments.length?(u=null!=(r=t),n):u?r:null},Gu(n,t)},ta.layout.treemap=function(){function n(n,t){for(var e,r,u=-1,i=n.length;++ut?0:t),e.area=isNaN(r)||0>=r?0:r}function t(e){var i=e.children;if(i&&i.length){var o,a,c,l=f(e),s=[],h=i.slice(),p=1/0,v="slice"===g?l.dx:"dice"===g?l.dy:"slice-dice"===g?1&e.depth?l.dy:l.dx:Math.min(l.dx,l.dy);for(n(h,l.dx*l.dy/e.value),s.area=0;(c=h.length)>0;)s.push(o=h[c-1]),s.area+=o.area,"squarify"!==g||(a=r(s,v))<=p?(h.pop(),p=a):(s.area-=s.pop().area,u(s,v,l,!1),v=Math.min(l.dx,l.dy),s.length=s.area=0,p=1/0);s.length&&(u(s,v,l,!0),s.length=s.area=0),i.forEach(t)}}function e(t){var r=t.children;if(r&&r.length){var i,o=f(t),a=r.slice(),c=[];for(n(a,o.dx*o.dy/t.value),c.area=0;i=a.pop();)c.push(i),c.area+=i.area,null!=i.z&&(u(c,i.z?o.dx:o.dy,o,!a.length),c.length=c.area=0);r.forEach(e)}}function r(n,t){for(var e,r=n.area,u=0,i=1/0,o=-1,a=n.length;++oe&&(i=e),e>u&&(u=e));return r*=r,t*=t,r?Math.max(t*u*p/r,r/(t*i*p)):1/0}function u(n,t,e,r){var u,i=-1,o=n.length,a=e.x,l=e.y,s=t?c(n.area/t):0;if(t==e.dx){for((r||s>e.dy)&&(s=e.dy);++ie.dx)&&(s=e.dx);++ie&&(t=1),1>e&&(n=0),function(){var e,r,u;do e=2*Math.random()-1,r=2*Math.random()-1,u=e*e+r*r;while(!u||u>1);return n+t*e*Math.sqrt(-2*Math.log(u)/u)}},logNormal:function(){var n=ta.random.normal.apply(ta,arguments);return function(){return Math.exp(n())}},bates:function(n){var t=ta.random.irwinHall(n);return function(){return t()/n}},irwinHall:function(n){return function(){for(var t=0,e=0;n>e;e++)t+=Math.random();return t}}},ta.scale={};var ml={floor:y,ceil:y};ta.scale.linear=function(){return Ii([0,1],[0,1],mu,!1)};var yl={s:1,g:1,p:1,r:1,e:1};ta.scale.log=function(){return Ji(ta.scale.linear().domain([0,1]),10,!0,[1,10])};var Ml=ta.format(".0e"),xl={floor:function(n){return-Math.ceil(-n)},ceil:function(n){return-Math.floor(-n)}};ta.scale.pow=function(){return Gi(ta.scale.linear(),1,[0,1])},ta.scale.sqrt=function(){return ta.scale.pow().exponent(.5)},ta.scale.ordinal=function(){return Qi([],{t:"range",a:[[]]})},ta.scale.category10=function(){return ta.scale.ordinal().range(bl)},ta.scale.category20=function(){return ta.scale.ordinal().range(_l)},ta.scale.category20b=function(){return ta.scale.ordinal().range(wl)},ta.scale.category20c=function(){return ta.scale.ordinal().range(Sl)};var bl=[2062260,16744206,2924588,14034728,9725885,9197131,14907330,8355711,12369186,1556175].map(Mt),_l=[2062260,11454440,16744206,16759672,2924588,10018698,14034728,16750742,9725885,12955861,9197131,12885140,14907330,16234194,8355711,13092807,12369186,14408589,1556175,10410725].map(Mt),wl=[3750777,5395619,7040719,10264286,6519097,9216594,11915115,13556636,9202993,12426809,15186514,15190932,8666169,11356490,14049643,15177372,8077683,10834324,13528509,14589654].map(Mt),Sl=[3244733,7057110,10406625,13032431,15095053,16616764,16625259,16634018,3253076,7652470,10607003,13101504,7695281,10394312,12369372,14342891,6513507,9868950,12434877,14277081].map(Mt);ta.scale.quantile=function(){return no([],[])},ta.scale.quantize=function(){return to(0,1,[0,1])},ta.scale.threshold=function(){return eo([.5],[0,1])},ta.scale.identity=function(){return ro([0,1])},ta.svg={},ta.svg.arc=function(){function n(){var n=Math.max(0,+e.apply(this,arguments)),l=Math.max(0,+r.apply(this,arguments)),s=o.apply(this,arguments)-Ra,f=a.apply(this,arguments)-Ra,h=Math.abs(f-s),g=s>f?0:1;if(n>l&&(p=l,l=n,n=p),h>=Ta)return t(l,g)+(n?t(n,1-g):"")+"Z";var p,v,d,m,y,M,x,b,_,w,S,k,E=0,A=0,N=[];if((m=(+c.apply(this,arguments)||0)/2)&&(d=i===kl?Math.sqrt(n*n+l*l):+i.apply(this,arguments),g||(A*=-1),l&&(A=tt(d/l*Math.sin(m))),n&&(E=tt(d/n*Math.sin(m)))),l){y=l*Math.cos(s+A),M=l*Math.sin(s+A),x=l*Math.cos(f-A),b=l*Math.sin(f-A);var C=Math.abs(f-s-2*A)<=qa?0:1;if(A&&so(y,M,x,b)===g^C){var z=(s+f)/2;y=l*Math.cos(z),M=l*Math.sin(z),x=b=null}}else y=M=0;if(n){_=n*Math.cos(f-E),w=n*Math.sin(f-E),S=n*Math.cos(s+E),k=n*Math.sin(s+E);var q=Math.abs(s-f+2*E)<=qa?0:1;if(E&&so(_,w,S,k)===1-g^q){var L=(s+f)/2;_=n*Math.cos(L),w=n*Math.sin(L),S=k=null}}else _=w=0;if((p=Math.min(Math.abs(l-n)/2,+u.apply(this,arguments)))>.001){v=l>n^g?0:1;var T=null==S?[_,w]:null==x?[y,M]:Lr([y,M],[S,k],[x,b],[_,w]),R=y-T[0],D=M-T[1],P=x-T[0],U=b-T[1],j=1/Math.sin(Math.acos((R*P+D*U)/(Math.sqrt(R*R+D*D)*Math.sqrt(P*P+U*U)))/2),F=Math.sqrt(T[0]*T[0]+T[1]*T[1]);if(null!=x){var H=Math.min(p,(l-F)/(j+1)),O=fo(null==S?[_,w]:[S,k],[y,M],l,H,g),I=fo([x,b],[_,w],l,H,g);p===H?N.push("M",O[0],"A",H,",",H," 0 0,",v," ",O[1],"A",l,",",l," 0 ",1-g^so(O[1][0],O[1][1],I[1][0],I[1][1]),",",g," ",I[1],"A",H,",",H," 0 0,",v," ",I[0]):N.push("M",O[0],"A",H,",",H," 0 1,",v," ",I[0])}else N.push("M",y,",",M);if(null!=S){var Y=Math.min(p,(n-F)/(j-1)),Z=fo([y,M],[S,k],n,-Y,g),V=fo([_,w],null==x?[y,M]:[x,b],n,-Y,g);p===Y?N.push("L",V[0],"A",Y,",",Y," 0 0,",v," ",V[1],"A",n,",",n," 0 ",g^so(V[1][0],V[1][1],Z[1][0],Z[1][1]),",",1-g," ",Z[1],"A",Y,",",Y," 0 0,",v," ",Z[0]):N.push("L",V[0],"A",Y,",",Y," 0 0,",v," ",Z[0])}else N.push("L",_,",",w)}else N.push("M",y,",",M),null!=x&&N.push("A",l,",",l," 0 ",C,",",g," ",x,",",b),N.push("L",_,",",w),null!=S&&N.push("A",n,",",n," 0 ",q,",",1-g," ",S,",",k);return N.push("Z"),N.join("")}function t(n,t){return"M0,"+n+"A"+n+","+n+" 0 1,"+t+" 0,"+-n+"A"+n+","+n+" 0 1,"+t+" 0,"+n}var e=io,r=oo,u=uo,i=kl,o=ao,a=co,c=lo;return n.innerRadius=function(t){return arguments.length?(e=Et(t),n):e},n.outerRadius=function(t){return arguments.length?(r=Et(t),n):r},n.cornerRadius=function(t){return arguments.length?(u=Et(t),n):u},n.padRadius=function(t){return arguments.length?(i=t==kl?kl:Et(t),n):i},n.startAngle=function(t){return arguments.length?(o=Et(t),n):o},n.endAngle=function(t){return arguments.length?(a=Et(t),n):a},n.padAngle=function(t){return arguments.length?(c=Et(t),n):c},n.centroid=function(){var n=(+e.apply(this,arguments)+ +r.apply(this,arguments))/2,t=(+o.apply(this,arguments)+ +a.apply(this,arguments))/2-Ra;return[Math.cos(t)*n,Math.sin(t)*n]},n};var kl="auto";ta.svg.line=function(){return ho(y)};var El=ta.map({linear:go,"linear-closed":po,step:vo,"step-before":mo,"step-after":yo,basis:So,"basis-open":ko,"basis-closed":Eo,bundle:Ao,cardinal:bo,"cardinal-open":Mo,"cardinal-closed":xo,monotone:To});El.forEach(function(n,t){t.key=n,t.closed=/-closed$/.test(n)});var Al=[0,2/3,1/3,0],Nl=[0,1/3,2/3,0],Cl=[0,1/6,2/3,1/6];ta.svg.line.radial=function(){var n=ho(Ro);return n.radius=n.x,delete n.x,n.angle=n.y,delete n.y,n},mo.reverse=yo,yo.reverse=mo,ta.svg.area=function(){return Do(y)},ta.svg.area.radial=function(){var n=Do(Ro);return n.radius=n.x,delete n.x,n.innerRadius=n.x0,delete n.x0,n.outerRadius=n.x1,delete n.x1,n.angle=n.y,delete n.y,n.startAngle=n.y0,delete n.y0,n.endAngle=n.y1,delete n.y1,n},ta.svg.chord=function(){function n(n,a){var c=t(this,i,n,a),l=t(this,o,n,a);return"M"+c.p0+r(c.r,c.p1,c.a1-c.a0)+(e(c,l)?u(c.r,c.p1,c.r,c.p0):u(c.r,c.p1,l.r,l.p0)+r(l.r,l.p1,l.a1-l.a0)+u(l.r,l.p1,c.r,c.p0))+"Z"}function t(n,t,e,r){var u=t.call(n,e,r),i=a.call(n,u,r),o=c.call(n,u,r)-Ra,s=l.call(n,u,r)-Ra;return{r:i,a0:o,a1:s,p0:[i*Math.cos(o),i*Math.sin(o)],p1:[i*Math.cos(s),i*Math.sin(s)]}}function e(n,t){return n.a0==t.a0&&n.a1==t.a1}function r(n,t,e){return"A"+n+","+n+" 0 "+ +(e>qa)+",1 "+t}function u(n,t,e,r){return"Q 0,0 "+r}var i=mr,o=yr,a=Po,c=ao,l=co;return n.radius=function(t){return arguments.length?(a=Et(t),n):a},n.source=function(t){return arguments.length?(i=Et(t),n):i},n.target=function(t){return arguments.length?(o=Et(t),n):o},n.startAngle=function(t){return arguments.length?(c=Et(t),n):c},n.endAngle=function(t){return arguments.length?(l=Et(t),n):l},n},ta.svg.diagonal=function(){function n(n,u){var i=t.call(this,n,u),o=e.call(this,n,u),a=(i.y+o.y)/2,c=[i,{x:i.x,y:a},{x:o.x,y:a},o];return c=c.map(r),"M"+c[0]+"C"+c[1]+" "+c[2]+" "+c[3]}var t=mr,e=yr,r=Uo;return n.source=function(e){return arguments.length?(t=Et(e),n):t},n.target=function(t){return arguments.length?(e=Et(t),n):e},n.projection=function(t){return arguments.length?(r=t,n):r},n},ta.svg.diagonal.radial=function(){var n=ta.svg.diagonal(),t=Uo,e=n.projection;return n.projection=function(n){return arguments.length?e(jo(t=n)):t},n},ta.svg.symbol=function(){function n(n,r){return(zl.get(t.call(this,n,r))||Oo)(e.call(this,n,r))}var t=Ho,e=Fo;return n.type=function(e){return arguments.length?(t=Et(e),n):t},n.size=function(t){return arguments.length?(e=Et(t),n):e},n};var zl=ta.map({circle:Oo,cross:function(n){var t=Math.sqrt(n/5)/2;return"M"+-3*t+","+-t+"H"+-t+"V"+-3*t+"H"+t+"V"+-t+"H"+3*t+"V"+t+"H"+t+"V"+3*t+"H"+-t+"V"+t+"H"+-3*t+"Z"},diamond:function(n){var t=Math.sqrt(n/(2*Ll)),e=t*Ll;return"M0,"+-t+"L"+e+",0 0,"+t+" "+-e+",0Z"},square:function(n){var t=Math.sqrt(n)/2;return"M"+-t+","+-t+"L"+t+","+-t+" "+t+","+t+" "+-t+","+t+"Z"},"triangle-down":function(n){var t=Math.sqrt(n/ql),e=t*ql/2;return"M0,"+e+"L"+t+","+-e+" "+-t+","+-e+"Z"},"triangle-up":function(n){var t=Math.sqrt(n/ql),e=t*ql/2;return"M0,"+-e+"L"+t+","+e+" "+-t+","+e+"Z"}});ta.svg.symbolTypes=zl.keys();var ql=Math.sqrt(3),Ll=Math.tan(30*Da);_a.transition=function(n){for(var t,e,r=Tl||++Ul,u=Xo(n),i=[],o=Rl||{time:Date.now(),ease:Su,delay:0,duration:250},a=-1,c=this.length;++ai;i++){u.push(t=[]);for(var e=this[i],a=0,c=e.length;c>a;a++)(r=e[a])&&n.call(r,r.__data__,a,i)&&t.push(r)}return Yo(u,this.namespace,this.id)},Pl.tween=function(n,t){var e=this.id,r=this.namespace;return arguments.length<2?this.node()[r][e].tween.get(n):Y(this,null==t?function(t){t[r][e].tween.remove(n)}:function(u){u[r][e].tween.set(n,t)})},Pl.attr=function(n,t){function e(){this.removeAttribute(a)}function r(){this.removeAttributeNS(a.space,a.local)}function u(n){return null==n?e:(n+="",function(){var t,e=this.getAttribute(a);return e!==n&&(t=o(e,n),function(n){this.setAttribute(a,t(n))})})}function i(n){return null==n?r:(n+="",function(){var t,e=this.getAttributeNS(a.space,a.local);return e!==n&&(t=o(e,n),function(n){this.setAttributeNS(a.space,a.local,t(n))})})}if(arguments.length<2){for(t in n)this.attr(t,n[t]);return this}var o="transform"==n?Hu:mu,a=ta.ns.qualify(n);return Zo(this,"attr."+n,t,a.local?i:u)},Pl.attrTween=function(n,t){function e(n,e){var r=t.call(this,n,e,this.getAttribute(u));return r&&function(n){this.setAttribute(u,r(n))}}function r(n,e){var r=t.call(this,n,e,this.getAttributeNS(u.space,u.local));return r&&function(n){this.setAttributeNS(u.space,u.local,r(n))}}var u=ta.ns.qualify(n);return this.tween("attr."+n,u.local?r:e)},Pl.style=function(n,e,r){function u(){this.style.removeProperty(n)}function i(e){return null==e?u:(e+="",function(){var u,i=t(this).getComputedStyle(this,null).getPropertyValue(n);return i!==e&&(u=mu(i,e),function(t){this.style.setProperty(n,u(t),r)})})}var o=arguments.length;if(3>o){if("string"!=typeof n){2>o&&(e="");for(r in n)this.style(r,n[r],e);return this}r=""}return Zo(this,"style."+n,e,i)},Pl.styleTween=function(n,e,r){function u(u,i){var o=e.call(this,u,i,t(this).getComputedStyle(this,null).getPropertyValue(n));return o&&function(t){this.style.setProperty(n,o(t),r)}}return arguments.length<3&&(r=""),this.tween("style."+n,u)},Pl.text=function(n){return Zo(this,"text",n,Vo)},Pl.remove=function(){var n=this.namespace;return this.each("end.transition",function(){var t;this[n].count<2&&(t=this.parentNode)&&t.removeChild(this)})},Pl.ease=function(n){var t=this.id,e=this.namespace;return arguments.length<1?this.node()[e][t].ease:("function"!=typeof n&&(n=ta.ease.apply(ta,arguments)),Y(this,function(r){r[e][t].ease=n}))},Pl.delay=function(n){var t=this.id,e=this.namespace;return arguments.length<1?this.node()[e][t].delay:Y(this,"function"==typeof n?function(r,u,i){r[e][t].delay=+n.call(r,r.__data__,u,i)}:(n=+n,function(r){r[e][t].delay=n}))},Pl.duration=function(n){var t=this.id,e=this.namespace;return arguments.length<1?this.node()[e][t].duration:Y(this,"function"==typeof n?function(r,u,i){r[e][t].duration=Math.max(1,n.call(r,r.__data__,u,i))}:(n=Math.max(1,n),function(r){r[e][t].duration=n}))},Pl.each=function(n,t){var e=this.id,r=this.namespace;if(arguments.length<2){var u=Rl,i=Tl;try{Tl=e,Y(this,function(t,u,i){Rl=t[r][e],n.call(t,t.__data__,u,i)})}finally{Rl=u,Tl=i}}else Y(this,function(u){var i=u[r][e];(i.event||(i.event=ta.dispatch("start","end","interrupt"))).on(n,t)});return this},Pl.transition=function(){for(var n,t,e,r,u=this.id,i=++Ul,o=this.namespace,a=[],c=0,l=this.length;l>c;c++){a.push(n=[]);for(var t=this[c],s=0,f=t.length;f>s;s++)(e=t[s])&&(r=e[o][u],$o(e,s,o,i,{time:r.time,ease:r.ease,delay:r.delay+r.duration,duration:r.duration})),n.push(e)}return Yo(a,o,i)},ta.svg.axis=function(){function n(n){n.each(function(){var n,l=ta.select(this),s=this.__chart__||e,f=this.__chart__=e.copy(),h=null==c?f.ticks?f.ticks.apply(f,a):f.domain():c,g=null==t?f.tickFormat?f.tickFormat.apply(f,a):y:t,p=l.selectAll(".tick").data(h,f),v=p.enter().insert("g",".domain").attr("class","tick").style("opacity",Ca),d=ta.transition(p.exit()).style("opacity",Ca).remove(),m=ta.transition(p.order()).style("opacity",1),M=Math.max(u,0)+o,x=Ui(f),b=l.selectAll(".domain").data([0]),_=(b.enter().append("path").attr("class","domain"),ta.transition(b));v.append("line"),v.append("text");var w,S,k,E,A=v.select("line"),N=m.select("line"),C=p.select("text").text(g),z=v.select("text"),q=m.select("text"),L="top"===r||"left"===r?-1:1;if("bottom"===r||"top"===r?(n=Bo,w="x",k="y",S="x2",E="y2",C.attr("dy",0>L?"0em":".71em").style("text-anchor","middle"),_.attr("d","M"+x[0]+","+L*i+"V0H"+x[1]+"V"+L*i)):(n=Wo,w="y",k="x",S="y2",E="x2",C.attr("dy",".32em").style("text-anchor",0>L?"end":"start"),_.attr("d","M"+L*i+","+x[0]+"H0V"+x[1]+"H"+L*i)),A.attr(E,L*u),z.attr(k,L*M),N.attr(S,0).attr(E,L*u),q.attr(w,0).attr(k,L*M),f.rangeBand){var T=f,R=T.rangeBand()/2;s=f=function(n){return T(n)+R}}else s.rangeBand?s=f:d.call(n,f,s);v.call(n,s,f),m.call(n,f,f)})}var t,e=ta.scale.linear(),r=jl,u=6,i=6,o=3,a=[10],c=null;return n.scale=function(t){return arguments.length?(e=t,n):e},n.orient=function(t){return arguments.length?(r=t in Fl?t+"":jl,n):r},n.ticks=function(){return arguments.length?(a=arguments,n):a},n.tickValues=function(t){return arguments.length?(c=t,n):c},n.tickFormat=function(e){return arguments.length?(t=e,n):t},n.tickSize=function(t){var e=arguments.length;return e?(u=+t,i=+arguments[e-1],n):u},n.innerTickSize=function(t){return arguments.length?(u=+t,n):u},n.outerTickSize=function(t){return arguments.length?(i=+t,n):i},n.tickPadding=function(t){return arguments.length?(o=+t,n):o},n.tickSubdivide=function(){return arguments.length&&n},n};var jl="bottom",Fl={top:1,right:1,bottom:1,left:1};ta.svg.brush=function(){function n(t){t.each(function(){var t=ta.select(this).style("pointer-events","all").style("-webkit-tap-highlight-color","rgba(0,0,0,0)").on("mousedown.brush",i).on("touchstart.brush",i),o=t.selectAll(".background").data([0]);o.enter().append("rect").attr("class","background").style("visibility","hidden").style("cursor","crosshair"),t.selectAll(".extent").data([0]).enter().append("rect").attr("class","extent").style("cursor","move");var a=t.selectAll(".resize").data(v,y);a.exit().remove(),a.enter().append("g").attr("class",function(n){return"resize "+n}).style("cursor",function(n){return Hl[n]}).append("rect").attr("x",function(n){return/[ew]$/.test(n)?-3:null}).attr("y",function(n){return/^[ns]/.test(n)?-3:null}).attr("width",6).attr("height",6).style("visibility","hidden"),a.style("display",n.empty()?"none":null);var c,f=ta.transition(t),h=ta.transition(o);l&&(c=Ui(l),h.attr("x",c[0]).attr("width",c[1]-c[0]),r(f)),s&&(c=Ui(s),h.attr("y",c[0]).attr("height",c[1]-c[0]),u(f)),e(f)})}function e(n){n.selectAll(".resize").attr("transform",function(n){return"translate("+f[+/e$/.test(n)]+","+h[+/^s/.test(n)]+")"})}function r(n){n.select(".extent").attr("x",f[0]),n.selectAll(".extent,.n>rect,.s>rect").attr("width",f[1]-f[0])}function u(n){n.select(".extent").attr("y",h[0]),n.selectAll(".extent,.e>rect,.w>rect").attr("height",h[1]-h[0])}function i(){function i(){32==ta.event.keyCode&&(C||(M=null,q[0]-=f[1],q[1]-=h[1],C=2),S())}function v(){32==ta.event.keyCode&&2==C&&(q[0]+=f[1],q[1]+=h[1],C=0,S())}function d(){var n=ta.mouse(b),t=!1;x&&(n[0]+=x[0],n[1]+=x[1]),C||(ta.event.altKey?(M||(M=[(f[0]+f[1])/2,(h[0]+h[1])/2]),q[0]=f[+(n[0]s?(u=r,r=s):u=s),v[0]!=r||v[1]!=u?(e?a=null:o=null,v[0]=r,v[1]=u,!0):void 0}function y(){d(),k.style("pointer-events","all").selectAll(".resize").style("display",n.empty()?"none":null),ta.select("body").style("cursor",null),L.on("mousemove.brush",null).on("mouseup.brush",null).on("touchmove.brush",null).on("touchend.brush",null).on("keydown.brush",null).on("keyup.brush",null),z(),w({type:"brushend"})}var M,x,b=this,_=ta.select(ta.event.target),w=c.of(b,arguments),k=ta.select(b),E=_.datum(),A=!/^(n|s)$/.test(E)&&l,N=!/^(e|w)$/.test(E)&&s,C=_.classed("extent"),z=W(b),q=ta.mouse(b),L=ta.select(t(b)).on("keydown.brush",i).on("keyup.brush",v);if(ta.event.changedTouches?L.on("touchmove.brush",d).on("touchend.brush",y):L.on("mousemove.brush",d).on("mouseup.brush",y),k.interrupt().selectAll("*").interrupt(),C)q[0]=f[0]-q[0],q[1]=h[0]-q[1];else if(E){var T=+/w$/.test(E),R=+/^n/.test(E);x=[f[1-T]-q[0],h[1-R]-q[1]],q[0]=f[T],q[1]=h[R]}else ta.event.altKey&&(M=q.slice());k.style("pointer-events","none").selectAll(".resize").style("display",null),ta.select("body").style("cursor",_.style("cursor")),w({type:"brushstart"}),d()}var o,a,c=E(n,"brushstart","brush","brushend"),l=null,s=null,f=[0,0],h=[0,0],g=!0,p=!0,v=Ol[0];return n.event=function(n){n.each(function(){var n=c.of(this,arguments),t={x:f,y:h,i:o,j:a},e=this.__chart__||t;this.__chart__=t,Tl?ta.select(this).transition().each("start.brush",function(){o=e.i,a=e.j,f=e.x,h=e.y,n({type:"brushstart"})}).tween("brush:brush",function(){var e=yu(f,t.x),r=yu(h,t.y);return o=a=null,function(u){f=t.x=e(u),h=t.y=r(u),n({type:"brush",mode:"resize"})}}).each("end.brush",function(){o=t.i,a=t.j,n({type:"brush",mode:"resize"}),n({type:"brushend"})}):(n({type:"brushstart"}),n({type:"brush",mode:"resize"}),n({type:"brushend"}))})},n.x=function(t){return arguments.length?(l=t,v=Ol[!l<<1|!s],n):l},n.y=function(t){return arguments.length?(s=t,v=Ol[!l<<1|!s],n):s},n.clamp=function(t){return arguments.length?(l&&s?(g=!!t[0],p=!!t[1]):l?g=!!t:s&&(p=!!t),n):l&&s?[g,p]:l?g:s?p:null},n.extent=function(t){var e,r,u,i,c;return arguments.length?(l&&(e=t[0],r=t[1],s&&(e=e[0],r=r[0]),o=[e,r],l.invert&&(e=l(e),r=l(r)),e>r&&(c=e,e=r,r=c),(e!=f[0]||r!=f[1])&&(f=[e,r])),s&&(u=t[0],i=t[1],l&&(u=u[1],i=i[1]),a=[u,i],s.invert&&(u=s(u),i=s(i)),u>i&&(c=u,u=i,i=c),(u!=h[0]||i!=h[1])&&(h=[u,i])),n):(l&&(o?(e=o[0],r=o[1]):(e=f[0],r=f[1],l.invert&&(e=l.invert(e),r=l.invert(r)),e>r&&(c=e,e=r,r=c))),s&&(a?(u=a[0],i=a[1]):(u=h[0],i=h[1],s.invert&&(u=s.invert(u),i=s.invert(i)),u>i&&(c=u,u=i,i=c))),l&&s?[[e,u],[r,i]]:l?[e,r]:s&&[u,i])},n.clear=function(){return n.empty()||(f=[0,0],h=[0,0],o=a=null),n},n.empty=function(){return!!l&&f[0]==f[1]||!!s&&h[0]==h[1]},ta.rebind(n,c,"on")};var Hl={n:"ns-resize",e:"ew-resize",s:"ns-resize",w:"ew-resize",nw:"nwse-resize",ne:"nesw-resize",se:"nwse-resize",sw:"nesw-resize"},Ol=[["n","e","s","w","nw","ne","se","sw"],["e","w"],["n","s"],[]],Il=ac.format=gc.timeFormat,Yl=Il.utc,Zl=Yl("%Y-%m-%dT%H:%M:%S.%LZ");Il.iso=Date.prototype.toISOString&&+new Date("2000-01-01T00:00:00.000Z")?Jo:Zl,Jo.parse=function(n){var t=new Date(n);return isNaN(t)?null:t},Jo.toString=Zl.toString,ac.second=Ft(function(n){return new cc(1e3*Math.floor(n/1e3))},function(n,t){n.setTime(n.getTime()+1e3*Math.floor(t))},function(n){return n.getSeconds()}),ac.seconds=ac.second.range,ac.seconds.utc=ac.second.utc.range,ac.minute=Ft(function(n){return new cc(6e4*Math.floor(n/6e4))},function(n,t){n.setTime(n.getTime()+6e4*Math.floor(t))},function(n){return n.getMinutes()}),ac.minutes=ac.minute.range,ac.minutes.utc=ac.minute.utc.range,ac.hour=Ft(function(n){var t=n.getTimezoneOffset()/60;return new cc(36e5*(Math.floor(n/36e5-t)+t))},function(n,t){n.setTime(n.getTime()+36e5*Math.floor(t))},function(n){return n.getHours()}),ac.hours=ac.hour.range,ac.hours.utc=ac.hour.utc.range,ac.month=Ft(function(n){return n=ac.day(n),n.setDate(1),n},function(n,t){n.setMonth(n.getMonth()+t)},function(n){return n.getMonth()}),ac.months=ac.month.range,ac.months.utc=ac.month.utc.range;var Vl=[1e3,5e3,15e3,3e4,6e4,3e5,9e5,18e5,36e5,108e5,216e5,432e5,864e5,1728e5,6048e5,2592e6,7776e6,31536e6],Xl=[[ac.second,1],[ac.second,5],[ac.second,15],[ac.second,30],[ac.minute,1],[ac.minute,5],[ac.minute,15],[ac.minute,30],[ac.hour,1],[ac.hour,3],[ac.hour,6],[ac.hour,12],[ac.day,1],[ac.day,2],[ac.week,1],[ac.month,1],[ac.month,3],[ac.year,1]],$l=Il.multi([[".%L",function(n){return n.getMilliseconds()}],[":%S",function(n){return n.getSeconds()}],["%I:%M",function(n){return n.getMinutes()}],["%I %p",function(n){return n.getHours()}],["%a %d",function(n){return n.getDay()&&1!=n.getDate()}],["%b %d",function(n){return 1!=n.getDate()}],["%B",function(n){return n.getMonth()}],["%Y",Ne]]),Bl={range:function(n,t,e){return ta.range(Math.ceil(n/e)*e,+t,e).map(Ko)},floor:y,ceil:y};Xl.year=ac.year,ac.scale=function(){return Go(ta.scale.linear(),Xl,$l)};var Wl=Xl.map(function(n){return[n[0].utc,n[1]]}),Jl=Yl.multi([[".%L",function(n){return n.getUTCMilliseconds()}],[":%S",function(n){return n.getUTCSeconds()}],["%I:%M",function(n){return n.getUTCMinutes()}],["%I %p",function(n){return n.getUTCHours()}],["%a %d",function(n){return n.getUTCDay()&&1!=n.getUTCDate()}],["%b %d",function(n){return 1!=n.getUTCDate()}],["%B",function(n){return n.getUTCMonth()}],["%Y",Ne]]);Wl.year=ac.year.utc,ac.scale.utc=function(){return Go(ta.scale.linear(),Wl,Jl)},ta.text=At(function(n){return n.responseText}),ta.json=function(n,t){return Nt(n,"application/json",Qo,t)},ta.html=function(n,t){return Nt(n,"text/html",na,t)},ta.xml=At(function(n){return n.responseXML}),"function"==typeof define&&define.amd?define(ta):"object"==typeof module&&module.exports&&(module.exports=ta),this.d3=ta}(); diff --git a/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js b/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js new file mode 100644 index 0000000000000..6d2da25024a83 --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js @@ -0,0 +1,29 @@ +/*v0.4.3 with 1 additional commit (see http://github.com/andrewor14/dagre-d3)*/(function(f){if(typeof exports==="object"&&typeof module!=="undefined"){module.exports=f()}else if(typeof define==="function"&&define.amd){define([],f)}else{var g;if(typeof window!=="undefined"){g=window}else if(typeof global!=="undefined"){g=global}else if(typeof self!=="undefined"){g=self}else{g=this}g.dagreD3=f()}})(function(){var define,module,exports;return function e(t,n,r){function s(o,u){if(!n[o]){if(!t[o]){var a=typeof require=="function"&&require;if(!u&&a)return a(o,!0);if(i)return i(o,!0);var f=new Error("Cannot find module '"+o+"'");throw f.code="MODULE_NOT_FOUND",f}var l=n[o]={exports:{}};t[o][0].call(l.exports,function(e){var n=t[o][1][e];return s(n?n:e)},l,l.exports,e,t,n,r)}return n[o].exports}var i=typeof require=="function"&&require;for(var o=0;o0}},{}],14:[function(require,module,exports){module.exports=intersectNode;function intersectNode(node,point){return node.intersect(point)}},{}],15:[function(require,module,exports){var intersectLine=require("./intersect-line");module.exports=intersectPolygon;function intersectPolygon(node,polyPoints,point){var x1=node.x;var y1=node.y;var intersections=[];var minX=Number.POSITIVE_INFINITY,minY=Number.POSITIVE_INFINITY;polyPoints.forEach(function(entry){minX=Math.min(minX,entry.x);minY=Math.min(minY,entry.y)});var left=x1-node.width/2-minX;var top=y1-node.height/2-minY;for(var i=0;i1){intersections.sort(function(p,q){var pdx=p.x-point.x,pdy=p.y-point.y,distp=Math.sqrt(pdx*pdx+pdy*pdy),qdx=q.x-point.x,qdy=q.y-point.y,distq=Math.sqrt(qdx*qdx+qdy*qdy);return distpMath.abs(dx)*h){if(dy<0){h=-h}sx=dy===0?0:h*dx/dy;sy=h}else{if(dx<0){w=-w}sx=w;sy=dx===0?0:w*dy/dx}return{x:x+sx,y:y+sy}}},{}],17:[function(require,module,exports){var util=require("../util");module.exports=addHtmlLabel;function addHtmlLabel(root,node){var fo=root.append("foreignObject").attr("width","100000");var div=fo.append("xhtml:div");var label=node.label;switch(typeof label){case"function":div.insert(label);break;case"object":div.insert(function(){return label});break;default:div.html(label)}util.applyStyle(div,node.labelStyle);div.style("display","inline-block");div.style("white-space","nowrap");var w,h;div.each(function(){w=this.clientWidth;h=this.clientHeight});fo.attr("width",w).attr("height",h);return fo}},{"../util":25}],18:[function(require,module,exports){var addTextLabel=require("./add-text-label"),addHtmlLabel=require("./add-html-label");module.exports=addLabel;function addLabel(root,node){var label=node.label;var labelSvg=root.append("g");if(typeof label!=="string"||node.labelType==="html"){addHtmlLabel(labelSvg,node)}else{addTextLabel(labelSvg,node)}var labelBBox=labelSvg.node().getBBox();labelSvg.attr("transform","translate("+-labelBBox.width/2+","+-labelBBox.height/2+")");return labelSvg}},{"./add-html-label":17,"./add-text-label":19}],19:[function(require,module,exports){var util=require("../util");module.exports=addTextLabel;function addTextLabel(root,node){var domNode=root.append("text");var lines=processEscapeSequences(node.label).split("\n");for(var i=0;i0;--i){entry=buckets[i].dequeue();if(entry){results=results.concat(removeNode(g,buckets,zeroIdx,entry,true));break}}}}return results}function removeNode(g,buckets,zeroIdx,entry,collectPredecessors){var results=collectPredecessors?[]:undefined;_.each(g.inEdges(entry.v),function(edge){var weight=g.edge(edge),uEntry=g.node(edge.v);if(collectPredecessors){results.push({v:edge.v,w:edge.w})}uEntry.out-=weight;assignBucket(buckets,zeroIdx,uEntry)});_.each(g.outEdges(entry.v),function(edge){var weight=g.edge(edge),w=edge.w,wEntry=g.node(w);wEntry["in"]-=weight;assignBucket(buckets,zeroIdx,wEntry)});g.removeNode(entry.v);return results}function buildState(g,weightFn){var fasGraph=new Graph,maxIn=0,maxOut=0;_.each(g.nodes(),function(v){fasGraph.setNode(v,{v:v,"in":0,out:0})});_.each(g.edges(),function(e){var prevWeight=fasGraph.edge(e.v,e.w)||0,weight=weightFn(e),edgeWeight=prevWeight+weight;fasGraph.setEdge(e.v,e.w,edgeWeight);maxOut=Math.max(maxOut,fasGraph.node(e.v).out+=weight);maxIn=Math.max(maxIn,fasGraph.node(e.w)["in"]+=weight)});var buckets=_.range(maxOut+maxIn+3).map(function(){return new List});var zeroIdx=maxIn+1;_.each(fasGraph.nodes(),function(v){assignBucket(buckets,zeroIdx,fasGraph.node(v))});return{graph:fasGraph,buckets:buckets,zeroIdx:zeroIdx}}function assignBucket(buckets,zeroIdx,entry){if(!entry.out){buckets[0].enqueue(entry)}else if(!entry["in"]){buckets[buckets.length-1].enqueue(entry)}else{buckets[entry.out-entry["in"]+zeroIdx].enqueue(entry)}}},{"./data/list":31,"./graphlib":33,"./lodash":36}],35:[function(require,module,exports){"use strict";var _=require("./lodash"),acyclic=require("./acyclic"),normalize=require("./normalize"),rank=require("./rank"),normalizeRanks=require("./util").normalizeRanks,parentDummyChains=require("./parent-dummy-chains"),removeEmptyRanks=require("./util").removeEmptyRanks,nestingGraph=require("./nesting-graph"),addBorderSegments=require("./add-border-segments"),coordinateSystem=require("./coordinate-system"),order=require("./order"),position=require("./position"),util=require("./util"),Graph=require("./graphlib").Graph;module.exports=layout;function layout(g,opts){var time=opts&&opts.debugTiming?util.time:util.notime;time("layout",function(){var layoutGraph=time(" buildLayoutGraph",function(){return buildLayoutGraph(g)});time(" runLayout",function(){runLayout(layoutGraph,time)});time(" updateInputGraph",function(){updateInputGraph(g,layoutGraph)})})}function runLayout(g,time){time(" makeSpaceForEdgeLabels",function(){makeSpaceForEdgeLabels(g)});time(" removeSelfEdges",function(){removeSelfEdges(g)});time(" acyclic",function(){acyclic.run(g)});time(" nestingGraph.run",function(){nestingGraph.run(g)});time(" rank",function(){rank(util.asNonCompoundGraph(g))});time(" injectEdgeLabelProxies",function(){injectEdgeLabelProxies(g)});time(" removeEmptyRanks",function(){removeEmptyRanks(g)});time(" nestingGraph.cleanup",function(){nestingGraph.cleanup(g)});time(" normalizeRanks",function(){normalizeRanks(g)});time(" assignRankMinMax",function(){assignRankMinMax(g)});time(" removeEdgeLabelProxies",function(){removeEdgeLabelProxies(g)});time(" normalize.run",function(){normalize.run(g)});time(" parentDummyChains",function(){parentDummyChains(g)});time(" addBorderSegments",function(){addBorderSegments(g)});time(" order",function(){order(g)});time(" insertSelfEdges",function(){insertSelfEdges(g)});time(" adjustCoordinateSystem",function(){coordinateSystem.adjust(g)});time(" position",function(){position(g)});time(" positionSelfEdges",function(){positionSelfEdges(g)});time(" removeBorderNodes",function(){removeBorderNodes(g)});time(" normalize.undo",function(){normalize.undo(g)});time(" fixupEdgeLabelCoords",function(){fixupEdgeLabelCoords(g)});time(" undoCoordinateSystem",function(){coordinateSystem.undo(g)});time(" translateGraph",function(){translateGraph(g)});time(" assignNodeIntersects",function(){assignNodeIntersects(g)});time(" reversePoints",function(){reversePointsForReversedEdges(g)});time(" acyclic.undo",function(){acyclic.undo(g)})}function updateInputGraph(inputGraph,layoutGraph){_.each(inputGraph.nodes(),function(v){var inputLabel=inputGraph.node(v),layoutLabel=layoutGraph.node(v);if(inputLabel){inputLabel.x=layoutLabel.x;inputLabel.y=layoutLabel.y;if(layoutGraph.children(v).length){inputLabel.width=layoutLabel.width;inputLabel.height=layoutLabel.height}}});_.each(inputGraph.edges(),function(e){var inputLabel=inputGraph.edge(e),layoutLabel=layoutGraph.edge(e);inputLabel.points=layoutLabel.points;if(_.has(layoutLabel,"x")){inputLabel.x=layoutLabel.x;inputLabel.y=layoutLabel.y}});inputGraph.graph().width=layoutGraph.graph().width;inputGraph.graph().height=layoutGraph.graph().height}var graphNumAttrs=["nodesep","edgesep","ranksep","marginx","marginy"],graphDefaults={ranksep:50,edgesep:20,nodesep:50,rankdir:"tb"},graphAttrs=["acyclicer","ranker","rankdir","align"],nodeNumAttrs=["width","height"],nodeDefaults={width:0,height:0},edgeNumAttrs=["minlen","weight","width","height","labeloffset"],edgeDefaults={minlen:1,weight:1,width:0,height:0,labeloffset:10,labelpos:"r"},edgeAttrs=["labelpos"];function buildLayoutGraph(inputGraph){var g=new Graph({multigraph:true,compound:true}),graph=canonicalize(inputGraph.graph());g.setGraph(_.merge({},graphDefaults,selectNumberAttrs(graph,graphNumAttrs),_.pick(graph,graphAttrs)));_.each(inputGraph.nodes(),function(v){var node=canonicalize(inputGraph.node(v));g.setNode(v,_.defaults(selectNumberAttrs(node,nodeNumAttrs),nodeDefaults));g.setParent(v,inputGraph.parent(v))});_.each(inputGraph.edges(),function(e){var edge=canonicalize(inputGraph.edge(e));g.setEdge(e,_.merge({},edgeDefaults,selectNumberAttrs(edge,edgeNumAttrs),_.pick(edge,edgeAttrs)))});return g}function makeSpaceForEdgeLabels(g){var graph=g.graph(); + +graph.ranksep/=2;_.each(g.edges(),function(e){var edge=g.edge(e);edge.minlen*=2;if(edge.labelpos.toLowerCase()!=="c"){if(graph.rankdir==="TB"||graph.rankdir==="BT"){edge.width+=edge.labeloffset}else{edge.height+=edge.labeloffset}}})}function injectEdgeLabelProxies(g){_.each(g.edges(),function(e){var edge=g.edge(e);if(edge.width&&edge.height){var v=g.node(e.v),w=g.node(e.w),label={rank:(w.rank-v.rank)/2+v.rank,e:e};util.addDummyNode(g,"edge-proxy",label,"_ep")}})}function assignRankMinMax(g){var maxRank=0;_.each(g.nodes(),function(v){var node=g.node(v);if(node.borderTop){node.minRank=g.node(node.borderTop).rank;node.maxRank=g.node(node.borderBottom).rank;maxRank=_.max(maxRank,node.maxRank)}});g.graph().maxRank=maxRank}function removeEdgeLabelProxies(g){_.each(g.nodes(),function(v){var node=g.node(v);if(node.dummy==="edge-proxy"){g.edge(node.e).labelRank=node.rank;g.removeNode(v)}})}function translateGraph(g){var minX=Number.POSITIVE_INFINITY,maxX=0,minY=Number.POSITIVE_INFINITY,maxY=0,graphLabel=g.graph(),marginX=graphLabel.marginx||0,marginY=graphLabel.marginy||0;function getExtremes(attrs){var x=attrs.x,y=attrs.y,w=attrs.width,h=attrs.height;minX=Math.min(minX,x-w/2);maxX=Math.max(maxX,x+w/2);minY=Math.min(minY,y-h/2);maxY=Math.max(maxY,y+h/2)}_.each(g.nodes(),function(v){getExtremes(g.node(v))});_.each(g.edges(),function(e){var edge=g.edge(e);if(_.has(edge,"x")){getExtremes(edge)}});minX-=marginX;minY-=marginY;_.each(g.nodes(),function(v){var node=g.node(v);node.x-=minX;node.y-=minY});_.each(g.edges(),function(e){var edge=g.edge(e);_.each(edge.points,function(p){p.x-=minX;p.y-=minY});if(_.has(edge,"x")){edge.x-=minX}if(_.has(edge,"y")){edge.y-=minY}});graphLabel.width=maxX-minX+marginX;graphLabel.height=maxY-minY+marginY}function assignNodeIntersects(g){_.each(g.edges(),function(e){var edge=g.edge(e),nodeV=g.node(e.v),nodeW=g.node(e.w),p1,p2;if(!edge.points){edge.points=[];p1=nodeW;p2=nodeV}else{p1=edge.points[0];p2=edge.points[edge.points.length-1]}edge.points.unshift(util.intersectRect(nodeV,p1));edge.points.push(util.intersectRect(nodeW,p2))})}function fixupEdgeLabelCoords(g){_.each(g.edges(),function(e){var edge=g.edge(e);if(_.has(edge,"x")){if(edge.labelpos==="l"||edge.labelpos==="r"){edge.width-=edge.labeloffset}switch(edge.labelpos){case"l":edge.x-=edge.width/2+edge.labeloffset;break;case"r":edge.x+=edge.width/2+edge.labeloffset;break}}})}function reversePointsForReversedEdges(g){_.each(g.edges(),function(e){var edge=g.edge(e);if(edge.reversed){edge.points.reverse()}})}function removeBorderNodes(g){_.each(g.nodes(),function(v){if(g.children(v).length){var node=g.node(v),t=g.node(node.borderTop),b=g.node(node.borderBottom),l=g.node(_.last(node.borderLeft)),r=g.node(_.last(node.borderRight));node.width=Math.abs(r.x-l.x);node.height=Math.abs(b.y-t.y);node.x=l.x+node.width/2;node.y=t.y+node.height/2}});_.each(g.nodes(),function(v){if(g.node(v).dummy==="border"){g.removeNode(v)}})}function removeSelfEdges(g){_.each(g.edges(),function(e){if(e.v===e.w){var node=g.node(e.v);if(!node.selfEdges){node.selfEdges=[]}node.selfEdges.push({e:e,label:g.edge(e)});g.removeEdge(e)}})}function insertSelfEdges(g){var layers=util.buildLayerMatrix(g);_.each(layers,function(layer){var orderShift=0;_.each(layer,function(v,i){var node=g.node(v);node.order=i+orderShift;_.each(node.selfEdges,function(selfEdge){util.addDummyNode(g,"selfedge",{width:selfEdge.label.width,height:selfEdge.label.height,rank:node.rank,order:i+ ++orderShift,e:selfEdge.e,label:selfEdge.label},"_se")});delete node.selfEdges})})}function positionSelfEdges(g){_.each(g.nodes(),function(v){var node=g.node(v);if(node.dummy==="selfedge"){var selfNode=g.node(node.e.v),x=selfNode.x+selfNode.width/2,y=selfNode.y,dx=node.x-x,dy=selfNode.height/2;g.setEdge(node.e,node.label);g.removeNode(v);node.label.points=[{x:x+2*dx/3,y:y-dy},{x:x+5*dx/6,y:y-dy},{x:x+dx,y:y},{x:x+5*dx/6,y:y+dy},{x:x+2*dx/3,y:y+dy}];node.label.x=node.x;node.label.y=node.y}})}function selectNumberAttrs(obj,attrs){return _.mapValues(_.pick(obj,attrs),Number)}function canonicalize(attrs){var newAttrs={};_.each(attrs,function(v,k){newAttrs[k.toLowerCase()]=v});return newAttrs}},{"./acyclic":28,"./add-border-segments":29,"./coordinate-system":30,"./graphlib":33,"./lodash":36,"./nesting-graph":37,"./normalize":38,"./order":43,"./parent-dummy-chains":48,"./position":50,"./rank":52,"./util":55}],36:[function(require,module,exports){arguments[4][20][0].apply(exports,arguments)},{dup:20,lodash:77}],37:[function(require,module,exports){var _=require("./lodash"),util=require("./util");module.exports={run:run,cleanup:cleanup};function run(g){var root=util.addDummyNode(g,"root",{},"_root"),depths=treeDepths(g),height=_.max(depths)-1,nodeSep=2*height+1;g.graph().nestingRoot=root;_.each(g.edges(),function(e){g.edge(e).minlen*=nodeSep});var weight=sumWeights(g)+1;_.each(g.children(),function(child){dfs(g,root,nodeSep,weight,height,depths,child)});g.graph().nodeRankFactor=nodeSep}function dfs(g,root,nodeSep,weight,height,depths,v){var children=g.children(v);if(!children.length){if(v!==root){g.setEdge(root,v,{weight:0,minlen:nodeSep})}return}var top=util.addBorderNode(g,"_bt"),bottom=util.addBorderNode(g,"_bb"),label=g.node(v);g.setParent(top,v);label.borderTop=top;g.setParent(bottom,v);label.borderBottom=bottom;_.each(children,function(child){dfs(g,root,nodeSep,weight,height,depths,child);var childNode=g.node(child),childTop=childNode.borderTop?childNode.borderTop:child,childBottom=childNode.borderBottom?childNode.borderBottom:child,thisWeight=childNode.borderTop?weight:2*weight,minlen=childTop!==childBottom?1:height-depths[v]+1;g.setEdge(top,childTop,{weight:thisWeight,minlen:minlen,nestingEdge:true});g.setEdge(childBottom,bottom,{weight:thisWeight,minlen:minlen,nestingEdge:true})});if(!g.parent(v)){g.setEdge(root,top,{weight:0,minlen:height+depths[v]})}}function treeDepths(g){var depths={};function dfs(v,depth){var children=g.children(v);if(children&&children.length){_.each(children,function(child){dfs(child,depth+1)})}depths[v]=depth}_.each(g.children(),function(v){dfs(v,1)});return depths}function sumWeights(g){return _.reduce(g.edges(),function(acc,e){return acc+g.edge(e).weight},0)}function cleanup(g){var graphLabel=g.graph();g.removeNode(graphLabel.nestingRoot);delete graphLabel.nestingRoot;_.each(g.edges(),function(e){var edge=g.edge(e);if(edge.nestingEdge){g.removeEdge(e)}})}},{"./lodash":36,"./util":55}],38:[function(require,module,exports){"use strict";var _=require("./lodash"),util=require("./util");module.exports={run:run,undo:undo};function run(g){g.graph().dummyChains=[];_.each(g.edges(),function(edge){normalizeEdge(g,edge)})}function normalizeEdge(g,e){var v=e.v,vRank=g.node(v).rank,w=e.w,wRank=g.node(w).rank,name=e.name,edgeLabel=g.edge(e),labelRank=edgeLabel.labelRank;if(wRank===vRank+1)return;g.removeEdge(e);var dummy,attrs,i;for(i=0,++vRank;vRank0){if(index%2){weightSum+=tree[index+1]}index=index-1>>1;tree[index]+=entry.weight}cc+=entry.weight*weightSum}));return cc}},{"../lodash":36}],43:[function(require,module,exports){"use strict";var _=require("../lodash"),initOrder=require("./init-order"),crossCount=require("./cross-count"),sortSubgraph=require("./sort-subgraph"),buildLayerGraph=require("./build-layer-graph"),addSubgraphConstraints=require("./add-subgraph-constraints"),Graph=require("../graphlib").Graph,util=require("../util");module.exports=order;function order(g){var maxRank=util.maxRank(g),downLayerGraphs=buildLayerGraphs(g,_.range(1,maxRank+1),"inEdges"),upLayerGraphs=buildLayerGraphs(g,_.range(maxRank-1,-1,-1),"outEdges");var layering=initOrder(g);assignOrder(g,layering);var bestCC=Number.POSITIVE_INFINITY,best;for(var i=0,lastBest=0;lastBest<4;++i,++lastBest){sweepLayerGraphs(i%2?downLayerGraphs:upLayerGraphs,i%4>=2);layering=util.buildLayerMatrix(g);var cc=crossCount(g,layering);if(cc=vEntry.barycenter){mergeEntries(vEntry,uEntry)}}}function handleOut(vEntry){return function(wEntry){wEntry["in"].push(vEntry);if(--wEntry.indegree===0){sourceSet.push(wEntry)}}}while(sourceSet.length){var entry=sourceSet.pop();entries.push(entry);_.each(entry["in"].reverse(),handleIn(entry));_.each(entry.out,handleOut(entry))}return _.chain(entries).filter(function(entry){return!entry.merged}).map(function(entry){return _.pick(entry,["vs","i","barycenter","weight"])}).value()}function mergeEntries(target,source){var sum=0,weight=0;if(target.weight){sum+=target.barycenter*target.weight;weight+=target.weight}if(source.weight){sum+=source.barycenter*source.weight;weight+=source.weight}target.vs=source.vs.concat(target.vs);target.barycenter=sum/weight;target.weight=weight;target.i=Math.min(source.i,target.i);source.merged=true}},{"../lodash":36}],46:[function(require,module,exports){var _=require("../lodash"),barycenter=require("./barycenter"),resolveConflicts=require("./resolve-conflicts"),sort=require("./sort");module.exports=sortSubgraph;function sortSubgraph(g,v,cg,biasRight){var movable=g.children(v),node=g.node(v),bl=node?node.borderLeft:undefined,br=node?node.borderRight:undefined,subgraphs={};if(bl){movable=_.filter(movable,function(w){return w!==bl&&w!==br})}var barycenters=barycenter(g,movable);_.each(barycenters,function(entry){if(g.children(entry.v).length){var subgraphResult=sortSubgraph(g,entry.v,cg,biasRight);subgraphs[entry.v]=subgraphResult;if(_.has(subgraphResult,"barycenter")){mergeBarycenters(entry,subgraphResult)}}});var entries=resolveConflicts(barycenters,cg);expandSubgraphs(entries,subgraphs);var result=sort(entries,biasRight);if(bl){result.vs=_.flatten([bl,result.vs,br],true);if(g.predecessors(bl).length){var blPred=g.node(g.predecessors(bl)[0]),brPred=g.node(g.predecessors(br)[0]);if(!_.has(result,"barycenter")){result.barycenter=0;result.weight=0}result.barycenter=(result.barycenter*result.weight+blPred.order+brPred.order)/(result.weight+2);result.weight+=2}}return result}function expandSubgraphs(entries,subgraphs){_.each(entries,function(entry){entry.vs=_.flatten(entry.vs.map(function(v){if(subgraphs[v]){return subgraphs[v].vs}return v}),true)})}function mergeBarycenters(target,other){if(!_.isUndefined(target.barycenter)){target.barycenter=(target.barycenter*target.weight+other.barycenter*other.weight)/(target.weight+other.weight);target.weight+=other.weight}else{target.barycenter=other.barycenter;target.weight=other.weight}}},{"../lodash":36,"./barycenter":40,"./resolve-conflicts":45,"./sort":47}],47:[function(require,module,exports){var _=require("../lodash"),util=require("../util");module.exports=sort;function sort(entries,biasRight){var parts=util.partition(entries,function(entry){return _.has(entry,"barycenter")});var sortable=parts.lhs,unsortable=_.sortBy(parts.rhs,function(entry){return-entry.i}),vs=[],sum=0,weight=0,vsIndex=0;sortable.sort(compareWithBias(!!biasRight));vsIndex=consumeUnsortable(vs,unsortable,vsIndex);_.each(sortable,function(entry){vsIndex+=entry.vs.length;vs.push(entry.vs);sum+=entry.barycenter*entry.weight;weight+=entry.weight;vsIndex=consumeUnsortable(vs,unsortable,vsIndex)});var result={vs:_.flatten(vs,true)};if(weight){result.barycenter=sum/weight;result.weight=weight}return result}function consumeUnsortable(vs,unsortable,index){var last;while(unsortable.length&&(last=_.last(unsortable)).i<=index){unsortable.pop();vs.push(last.vs);index++}return index}function compareWithBias(bias){return function(entryV,entryW){if(entryV.barycenterentryW.barycenter){return 1}return!bias?entryV.i-entryW.i:entryW.i-entryV.i}}},{"../lodash":36,"../util":55}],48:[function(require,module,exports){var _=require("./lodash");module.exports=parentDummyChains;function parentDummyChains(g){var postorderNums=postorder(g);_.each(g.graph().dummyChains,function(v){var node=g.node(v),edgeObj=node.edgeObj,pathData=findPath(g,postorderNums,edgeObj.v,edgeObj.w),path=pathData.path,lca=pathData.lca,pathIdx=0,pathV=path[pathIdx],ascending=true;while(v!==edgeObj.w){node=g.node(v);if(ascending){while((pathV=path[pathIdx])!==lca&&g.node(pathV).maxRanklow||lim>postorderNums[parent].lim));lca=parent;parent=w;while((parent=g.parent(parent))!==lca){wPath.push(parent)}return{path:vPath.concat(wPath.reverse()),lca:lca}}function postorder(g){var result={},lim=0;function dfs(v){var low=lim;_.each(g.children(v),dfs);result[v]={low:low,lim:lim++}}_.each(g.children(),dfs);return result}},{"./lodash":36}],49:[function(require,module,exports){"use strict";var _=require("../lodash"),Graph=require("../graphlib").Graph,util=require("../util");module.exports={positionX:positionX,findType1Conflicts:findType1Conflicts,findType2Conflicts:findType2Conflicts,addConflict:addConflict,hasConflict:hasConflict,verticalAlignment:verticalAlignment,horizontalCompaction:horizontalCompaction,alignCoordinates:alignCoordinates,findSmallestWidthAlignment:findSmallestWidthAlignment,balance:balance};function findType1Conflicts(g,layering){var conflicts={};function visitLayer(prevLayer,layer){var k0=0,scanPos=0,prevLayerLength=prevLayer.length,lastNode=_.last(layer);_.each(layer,function(v,i){var w=findOtherInnerSegmentNode(g,v),k1=w?g.node(w).order:prevLayerLength;if(w||v===lastNode){_.each(layer.slice(scanPos,i+1),function(scanNode){_.each(g.predecessors(scanNode),function(u){var uLabel=g.node(u),uPos=uLabel.order;if((uPosnextNorthBorder)){addConflict(conflicts,u,v)}})}})}function visitLayer(north,south){var prevNorthPos=-1,nextNorthPos,southPos=0;_.each(south,function(v,southLookahead){if(g.node(v).dummy==="border"){var predecessors=g.predecessors(v);if(predecessors.length){nextNorthPos=g.node(predecessors[0]).order;scan(south,southPos,southLookahead,prevNorthPos,nextNorthPos);southPos=southLookahead;prevNorthPos=nextNorthPos}}scan(south,southPos,south.length,nextNorthPos,north.length)});return south}_.reduce(layering,visitLayer);return conflicts}function findOtherInnerSegmentNode(g,v){if(g.node(v).dummy){return _.find(g.predecessors(v),function(u){return g.node(u).dummy})}}function addConflict(conflicts,v,w){if(v>w){var tmp=v;v=w;w=tmp}var conflictsV=conflicts[v];if(!conflictsV){conflicts[v]=conflictsV={}}conflictsV[w]=true}function hasConflict(conflicts,v,w){if(v>w){var tmp=v;v=w;w=tmp}return _.has(conflicts[v],w)}function verticalAlignment(g,layering,conflicts,neighborFn){var root={},align={},pos={};_.each(layering,function(layer){_.each(layer,function(v,order){root[v]=v;align[v]=v;pos[v]=order})});_.each(layering,function(layer){var prevIdx=-1;_.each(layer,function(v){var ws=neighborFn(v);if(ws.length){ws=_.sortBy(ws,function(w){return pos[w]});var mp=(ws.length-1)/2;for(var i=Math.floor(mp),il=Math.ceil(mp);i<=il;++i){var w=ws[i];if(align[v]===v&&prevIdxwLabel.lim){tailLabel=wLabel;flip=true}var candidates=_.filter(g.edges(),function(edge){return flip===isDescendant(t,t.node(edge.v),tailLabel)&&flip!==isDescendant(t,t.node(edge.w),tailLabel)});return _.min(candidates,function(edge){return slack(g,edge)})}function exchangeEdges(t,g,e,f){var v=e.v,w=e.w;t.removeEdge(v,w);t.setEdge(f.v,f.w,{});initLowLimValues(t);initCutValues(t,g);updateRanks(t,g)}function updateRanks(t,g){var root=_.find(t.nodes(),function(v){return!g.node(v).parent}),vs=preorder(t,root);vs=vs.slice(1);_.each(vs,function(v){var parent=t.node(v).parent,edge=g.edge(v,parent),flipped=false;if(!edge){edge=g.edge(parent,v);flipped=true}g.node(v).rank=g.node(parent).rank+(flipped?edge.minlen:-edge.minlen)})}function isTreeEdge(tree,u,v){return tree.hasEdge(u,v)}function isDescendant(tree,vLabel,rootLabel){return rootLabel.low<=vLabel.lim&&vLabel.lim<=rootLabel.lim}},{"../graphlib":33,"../lodash":36,"../util":55,"./feasible-tree":51,"./util":54}],54:[function(require,module,exports){"use strict";var _=require("../lodash");module.exports={longestPath:longestPath,slack:slack};function longestPath(g){var visited={};function dfs(v){var label=g.node(v);if(_.has(visited,v)){return label.rank}visited[v]=true;var rank=_.min(_.map(g.outEdges(v),function(e){return dfs(e.w)-g.edge(e).minlen}));if(rank===Number.POSITIVE_INFINITY){rank=0}return label.rank=rank}_.each(g.sources(),dfs)}function slack(g,e){return g.node(e.w).rank-g.node(e.v).rank-g.edge(e).minlen}},{"../lodash":36}],55:[function(require,module,exports){"use strict";var _=require("./lodash"),Graph=require("./graphlib").Graph;module.exports={addDummyNode:addDummyNode,simplify:simplify,asNonCompoundGraph:asNonCompoundGraph,successorWeights:successorWeights,predecessorWeights:predecessorWeights,intersectRect:intersectRect,buildLayerMatrix:buildLayerMatrix,normalizeRanks:normalizeRanks,removeEmptyRanks:removeEmptyRanks,addBorderNode:addBorderNode,maxRank:maxRank,partition:partition,time:time,notime:notime};function addDummyNode(g,type,attrs,name){var v;do{v=_.uniqueId(name)}while(g.hasNode(v));attrs.dummy=type;g.setNode(v,attrs);return v}function simplify(g){var simplified=(new Graph).setGraph(g.graph());_.each(g.nodes(),function(v){simplified.setNode(v,g.node(v))});_.each(g.edges(),function(e){var simpleLabel=simplified.edge(e.v,e.w)||{weight:0,minlen:1},label=g.edge(e);simplified.setEdge(e.v,e.w,{weight:simpleLabel.weight+label.weight,minlen:Math.max(simpleLabel.minlen,label.minlen)})});return simplified}function asNonCompoundGraph(g){ +var simplified=new Graph({multigraph:g.isMultigraph()}).setGraph(g.graph());_.each(g.nodes(),function(v){if(!g.children(v).length){simplified.setNode(v,g.node(v))}});_.each(g.edges(),function(e){simplified.setEdge(e,g.edge(e))});return simplified}function successorWeights(g){var weightMap=_.map(g.nodes(),function(v){var sucs={};_.each(g.outEdges(v),function(e){sucs[e.w]=(sucs[e.w]||0)+g.edge(e).weight});return sucs});return _.zipObject(g.nodes(),weightMap)}function predecessorWeights(g){var weightMap=_.map(g.nodes(),function(v){var preds={};_.each(g.inEdges(v),function(e){preds[e.v]=(preds[e.v]||0)+g.edge(e).weight});return preds});return _.zipObject(g.nodes(),weightMap)}function intersectRect(rect,point){var x=rect.x;var y=rect.y;var dx=point.x-x;var dy=point.y-y;var w=rect.width/2;var h=rect.height/2;if(!dx&&!dy){throw new Error("Not possible to find intersection inside of the rectangle")}var sx,sy;if(Math.abs(dy)*w>Math.abs(dx)*h){if(dy<0){h=-h}sx=h*dx/dy;sy=h}else{if(dx<0){w=-w}sx=w;sy=w*dy/dx}return{x:x+sx,y:y+sy}}function buildLayerMatrix(g){var layering=_.map(_.range(maxRank(g)+1),function(){return[]});_.each(g.nodes(),function(v){var node=g.node(v),rank=node.rank;if(!_.isUndefined(rank)){layering[rank][node.order]=v}});return layering}function normalizeRanks(g){var min=_.min(_.map(g.nodes(),function(v){return g.node(v).rank}));_.each(g.nodes(),function(v){var node=g.node(v);if(_.has(node,"rank")){node.rank-=min}})}function removeEmptyRanks(g){var offset=_.min(_.map(g.nodes(),function(v){return g.node(v).rank}));var layers=[];_.each(g.nodes(),function(v){var rank=g.node(v).rank-offset;if(!_.has(layers,rank)){layers[rank]=[]}layers[rank].push(v)});var delta=0,nodeRankFactor=g.graph().nodeRankFactor;_.each(layers,function(vs,i){if(_.isUndefined(vs)&&i%nodeRankFactor!==0){--delta}else if(delta){_.each(vs,function(v){g.node(v).rank+=delta})}})}function addBorderNode(g,prefix,rank,order){var node={width:0,height:0};if(arguments.length>=4){node.rank=rank;node.order=order}return addDummyNode(g,"border",node,prefix)}function maxRank(g){return _.max(_.map(g.nodes(),function(v){var rank=g.node(v).rank;if(!_.isUndefined(rank)){return rank}}))}function partition(collection,fn){var result={lhs:[],rhs:[]};_.each(collection,function(value){if(fn(value)){result.lhs.push(value)}else{result.rhs.push(value)}});return result}function time(name,fn){var start=_.now();try{return fn()}finally{console.log(name+" time: "+(_.now()-start)+"ms")}}function notime(name,fn){return fn()}},{"./graphlib":33,"./lodash":36}],56:[function(require,module,exports){module.exports="0.7.1"},{}],57:[function(require,module,exports){var lib=require("./lib");module.exports={Graph:lib.Graph,json:require("./lib/json"),alg:require("./lib/alg"),version:lib.version}},{"./lib":73,"./lib/alg":64,"./lib/json":74}],58:[function(require,module,exports){var _=require("../lodash");module.exports=components;function components(g){var visited={},cmpts=[],cmpt;function dfs(v){if(_.has(visited,v))return;visited[v]=true;cmpt.push(v);_.each(g.successors(v),dfs);_.each(g.predecessors(v),dfs)}_.each(g.nodes(),function(v){cmpt=[];dfs(v);if(cmpt.length){cmpts.push(cmpt)}});return cmpts}},{"../lodash":75}],59:[function(require,module,exports){var _=require("../lodash");module.exports=dfs;function dfs(g,vs,order){if(!_.isArray(vs)){vs=[vs]}var acc=[],visited={};_.each(vs,function(v){if(!g.hasNode(v)){throw new Error("Graph does not have node: "+v)}doDfs(g,v,order==="post",visited,acc)});return acc}function doDfs(g,v,postorder,visited,acc){if(!_.has(visited,v)){visited[v]=true;if(!postorder){acc.push(v)}_.each(g.neighbors(v),function(w){doDfs(g,w,postorder,visited,acc)});if(postorder){acc.push(v)}}}},{"../lodash":75}],60:[function(require,module,exports){var dijkstra=require("./dijkstra"),_=require("../lodash");module.exports=dijkstraAll;function dijkstraAll(g,weightFunc,edgeFunc){return _.transform(g.nodes(),function(acc,v){acc[v]=dijkstra(g,v,weightFunc,edgeFunc)},{})}},{"../lodash":75,"./dijkstra":61}],61:[function(require,module,exports){var _=require("../lodash"),PriorityQueue=require("../data/priority-queue");module.exports=dijkstra;var DEFAULT_WEIGHT_FUNC=_.constant(1);function dijkstra(g,source,weightFn,edgeFn){return runDijkstra(g,String(source),weightFn||DEFAULT_WEIGHT_FUNC,edgeFn||function(v){return g.outEdges(v)})}function runDijkstra(g,source,weightFn,edgeFn){var results={},pq=new PriorityQueue,v,vEntry;var updateNeighbors=function(edge){var w=edge.v!==v?edge.v:edge.w,wEntry=results[w],weight=weightFn(edge),distance=vEntry.distance+weight;if(weight<0){throw new Error("dijkstra does not allow negative edge weights. "+"Bad edge: "+edge+" Weight: "+weight)}if(distance0){v=pq.removeMin();vEntry=results[v];if(vEntry.distance===Number.POSITIVE_INFINITY){break}edgeFn(v).forEach(updateNeighbors)}return results}},{"../data/priority-queue":71,"../lodash":75}],62:[function(require,module,exports){var _=require("../lodash"),tarjan=require("./tarjan");module.exports=findCycles;function findCycles(g){return _.filter(tarjan(g),function(cmpt){return cmpt.length>1})}},{"../lodash":75,"./tarjan":69}],63:[function(require,module,exports){var _=require("../lodash");module.exports=floydWarshall;var DEFAULT_WEIGHT_FUNC=_.constant(1);function floydWarshall(g,weightFn,edgeFn){return runFloydWarshall(g,weightFn||DEFAULT_WEIGHT_FUNC,edgeFn||function(v){return g.outEdges(v)})}function runFloydWarshall(g,weightFn,edgeFn){var results={},nodes=g.nodes();nodes.forEach(function(v){results[v]={};results[v][v]={distance:0};nodes.forEach(function(w){if(v!==w){results[v][w]={distance:Number.POSITIVE_INFINITY}}});edgeFn(v).forEach(function(edge){var w=edge.v===v?edge.w:edge.v,d=weightFn(edge);results[v][w]={distance:d,predecessor:v}})});nodes.forEach(function(k){var rowK=results[k];nodes.forEach(function(i){var rowI=results[i];nodes.forEach(function(j){var ik=rowI[k];var kj=rowK[j];var ij=rowI[j];var altDistance=ik.distance+kj.distance;if(altDistance0){v=pq.removeMin();if(_.has(parents,v)){result.setEdge(v,parents[v])}else if(init){throw new Error("Input graph is not connected: "+g)}else{init=true}g.nodeEdges(v).forEach(updateNeighbors)}return result}},{"../data/priority-queue":71,"../graph":72,"../lodash":75}],69:[function(require,module,exports){var _=require("../lodash");module.exports=tarjan;function tarjan(g){var index=0,stack=[],visited={},results=[];function dfs(v){var entry=visited[v]={onStack:true,lowlink:index,index:index++};stack.push(v);g.successors(v).forEach(function(w){if(!_.has(visited,w)){dfs(w);entry.lowlink=Math.min(entry.lowlink,visited[w].lowlink)}else if(visited[w].onStack){entry.lowlink=Math.min(entry.lowlink,visited[w].index)}});if(entry.lowlink===entry.index){var cmpt=[],w;do{w=stack.pop();visited[w].onStack=false;cmpt.push(w)}while(v!==w);results.push(cmpt)}}g.nodes().forEach(function(v){if(!_.has(visited,v)){dfs(v)}});return results}},{"../lodash":75}],70:[function(require,module,exports){var _=require("../lodash");module.exports=topsort;topsort.CycleException=CycleException;function topsort(g){var visited={},stack={},results=[];function visit(node){if(_.has(stack,node)){throw new CycleException}if(!_.has(visited,node)){stack[node]=true;visited[node]=true;_.each(g.predecessors(node),visit);delete stack[node];results.push(node)}}_.each(g.sinks(),visit);if(_.size(visited)!==g.nodeCount()){throw new CycleException}return results}function CycleException(){}},{"../lodash":75}],71:[function(require,module,exports){var _=require("../lodash");module.exports=PriorityQueue;function PriorityQueue(){this._arr=[];this._keyIndices={}}PriorityQueue.prototype.size=function(){return this._arr.length};PriorityQueue.prototype.keys=function(){return this._arr.map(function(x){return x.key})};PriorityQueue.prototype.has=function(key){return _.has(this._keyIndices,key)};PriorityQueue.prototype.priority=function(key){var index=this._keyIndices[key];if(index!==undefined){return this._arr[index].priority}};PriorityQueue.prototype.min=function(){if(this.size()===0){throw new Error("Queue underflow")}return this._arr[0].key};PriorityQueue.prototype.add=function(key,priority){var keyIndices=this._keyIndices;key=String(key);if(!_.has(keyIndices,key)){var arr=this._arr;var index=arr.length;keyIndices[key]=index;arr.push({key:key,priority:priority});this._decrease(index);return true}return false};PriorityQueue.prototype.removeMin=function(){this._swap(0,this._arr.length-1);var min=this._arr.pop();delete this._keyIndices[min.key];this._heapify(0);return min.key};PriorityQueue.prototype.decrease=function(key,priority){var index=this._keyIndices[key];if(priority>this._arr[index].priority){throw new Error("New priority is greater than current priority. "+"Key: "+key+" Old: "+this._arr[index].priority+" New: "+priority)}this._arr[index].priority=priority;this._decrease(index)};PriorityQueue.prototype._heapify=function(i){var arr=this._arr;var l=2*i,r=l+1,largest=i;if(l>1;if(arr[parent].priority1){this.setNode(v,value)}else{this.setNode(v)}},this);return this};Graph.prototype.setNode=function(v,value){if(_.has(this._nodes,v)){if(arguments.length>1){this._nodes[v]=value}return this}this._nodes[v]=arguments.length>1?value:this._defaultNodeLabelFn(v);if(this._isCompound){this._parent[v]=GRAPH_NODE;this._children[v]={};this._children[GRAPH_NODE][v]=true}this._in[v]={};this._preds[v]={};this._out[v]={};this._sucs[v]={};++this._nodeCount;return this};Graph.prototype.node=function(v){return this._nodes[v]};Graph.prototype.hasNode=function(v){return _.has(this._nodes,v)};Graph.prototype.removeNode=function(v){var self=this;if(_.has(this._nodes,v)){var removeEdge=function(e){self.removeEdge(self._edgeObjs[e])};delete this._nodes[v];if(this._isCompound){this._removeFromParentsChildList(v);delete this._parent[v];_.each(this.children(v),function(child){this.setParent(child)},this);delete this._children[v]}_.each(_.keys(this._in[v]),removeEdge);delete this._in[v];delete this._preds[v];_.each(_.keys(this._out[v]),removeEdge);delete this._out[v];delete this._sucs[v];--this._nodeCount}return this};Graph.prototype.setParent=function(v,parent){if(!this._isCompound){throw new Error("Cannot set parent in a non-compound graph")}if(_.isUndefined(parent)){parent=GRAPH_NODE}else{for(var ancestor=parent;!_.isUndefined(ancestor);ancestor=this.parent(ancestor)){if(ancestor===v){throw new Error("Setting "+parent+" as parent of "+v+" would create create a cycle")}}this.setNode(parent)}this.setNode(v);this._removeFromParentsChildList(v);this._parent[v]=parent;this._children[parent][v]=true;return this};Graph.prototype._removeFromParentsChildList=function(v){delete this._children[this._parent[v]][v]};Graph.prototype.parent=function(v){if(this._isCompound){var parent=this._parent[v];if(parent!==GRAPH_NODE){return parent}}};Graph.prototype.children=function(v){if(_.isUndefined(v)){v=GRAPH_NODE}if(this._isCompound){var children=this._children[v];if(children){return _.keys(children)}}else if(v===GRAPH_NODE){return this.nodes()}else if(this.hasNode(v)){return[]}};Graph.prototype.predecessors=function(v){var predsV=this._preds[v];if(predsV){return _.keys(predsV)}};Graph.prototype.successors=function(v){var sucsV=this._sucs[v];if(sucsV){return _.keys(sucsV)}};Graph.prototype.neighbors=function(v){var preds=this.predecessors(v);if(preds){return _.union(preds,this.successors(v))}};Graph.prototype.setDefaultEdgeLabel=function(newDefault){if(!_.isFunction(newDefault)){newDefault=_.constant(newDefault)}this._defaultEdgeLabelFn=newDefault;return this};Graph.prototype.edgeCount=function(){return this._edgeCount};Graph.prototype.edges=function(){return _.values(this._edgeObjs)};Graph.prototype.setPath=function(vs,value){var self=this,args=arguments;_.reduce(vs,function(v,w){if(args.length>1){self.setEdge(v,w,value)}else{self.setEdge(v,w)}return w});return this};Graph.prototype.setEdge=function(){var v,w,name,value,valueSpecified=false;if(_.isPlainObject(arguments[0])){v=arguments[0].v;w=arguments[0].w;name=arguments[0].name;if(arguments.length===2){value=arguments[1];valueSpecified=true}}else{v=arguments[0];w=arguments[1];name=arguments[3];if(arguments.length>2){value=arguments[2];valueSpecified=true}}v=""+v;w=""+w;if(!_.isUndefined(name)){name=""+name}var e=edgeArgsToId(this._isDirected,v,w,name);if(_.has(this._edgeLabels,e)){if(valueSpecified){this._edgeLabels[e]=value}return this}if(!_.isUndefined(name)&&!this._isMultigraph){throw new Error("Cannot set a named edge when isMultigraph = false")}this.setNode(v);this.setNode(w);this._edgeLabels[e]=valueSpecified?value:this._defaultEdgeLabelFn(v,w,name);var edgeObj=edgeArgsToObj(this._isDirected,v,w,name);v=edgeObj.v;w=edgeObj.w;Object.freeze(edgeObj);this._edgeObjs[e]=edgeObj;incrementOrInitEntry(this._preds[w],v);incrementOrInitEntry(this._sucs[v],w);this._in[w][e]=edgeObj;this._out[v][e]=edgeObj;this._edgeCount++;return this};Graph.prototype.edge=function(v,w,name){var e=arguments.length===1?edgeObjToId(this._isDirected,arguments[0]):edgeArgsToId(this._isDirected,v,w,name);return this._edgeLabels[e]};Graph.prototype.hasEdge=function(v,w,name){var e=arguments.length===1?edgeObjToId(this._isDirected,arguments[0]):edgeArgsToId(this._isDirected,v,w,name);return _.has(this._edgeLabels,e)};Graph.prototype.removeEdge=function(v,w,name){var e=arguments.length===1?edgeObjToId(this._isDirected,arguments[0]):edgeArgsToId(this._isDirected,v,w,name),edge=this._edgeObjs[e];if(edge){v=edge.v;w=edge.w;delete this._edgeLabels[e];delete this._edgeObjs[e];decrementOrRemoveEntry(this._preds[w],v);decrementOrRemoveEntry(this._sucs[v],w);delete this._in[w][e];delete this._out[v][e];this._edgeCount--}return this};Graph.prototype.inEdges=function(v,u){var inV=this._in[v];if(inV){var edges=_.values(inV);if(!u){return edges}return _.filter(edges,function(edge){return edge.v===u})}};Graph.prototype.outEdges=function(v,w){var outV=this._out[v];if(outV){var edges=_.values(outV);if(!w){return edges}return _.filter(edges,function(edge){return edge.w===w})}};Graph.prototype.nodeEdges=function(v,w){var inEdges=this.inEdges(v,w);if(inEdges){return inEdges.concat(this.outEdges(v,w))}};function incrementOrInitEntry(map,k){if(_.has(map,k)){map[k]++}else{map[k]=1}}function decrementOrRemoveEntry(map,k){if(!--map[k]){delete map[k]}}function edgeArgsToId(isDirected,v,w,name){if(!isDirected&&v>w){var tmp=v;v=w;w=tmp}return v+EDGE_KEY_DELIM+w+EDGE_KEY_DELIM+(_.isUndefined(name)?DEFAULT_EDGE_NAME:name)}function edgeArgsToObj(isDirected,v,w,name){if(!isDirected&&v>w){var tmp=v;v=w;w=tmp}var edgeObj={v:v,w:w};if(name){edgeObj.name=name}return edgeObj}function edgeObjToId(isDirected,edgeObj){return edgeArgsToId(isDirected,edgeObj.v,edgeObj.w,edgeObj.name)}},{"./lodash":75}],73:[function(require,module,exports){module.exports={Graph:require("./graph"),version:require("./version")}},{"./graph":72,"./version":76}],74:[function(require,module,exports){var _=require("./lodash"),Graph=require("./graph");module.exports={write:write,read:read};function write(g){var json={options:{directed:g.isDirected(),multigraph:g.isMultigraph(),compound:g.isCompound()},nodes:writeNodes(g),edges:writeEdges(g)};if(!_.isUndefined(g.graph())){json.value=_.clone(g.graph())}return json}function writeNodes(g){return _.map(g.nodes(),function(v){var nodeValue=g.node(v),parent=g.parent(v),node={v:v};if(!_.isUndefined(nodeValue)){node.value=nodeValue}if(!_.isUndefined(parent)){node.parent=parent}return node})}function writeEdges(g){return _.map(g.edges(),function(e){var edgeValue=g.edge(e),edge={v:e.v,w:e.w};if(!_.isUndefined(e.name)){edge.name=e.name}if(!_.isUndefined(edgeValue)){edge.value=edgeValue}return edge})}function read(json){var g=new Graph(json.options).setGraph(json.value);_.each(json.nodes,function(entry){g.setNode(entry.v,entry.value);if(entry.parent){g.setParent(entry.v,entry.parent)}});_.each(json.edges,function(entry){g.setEdge({v:entry.v,w:entry.w,name:entry.name},entry.value)});return g}},{"./graph":72,"./lodash":75}],75:[function(require,module,exports){arguments[4][20][0].apply(exports,arguments)},{dup:20,lodash:77}],76:[function(require,module,exports){module.exports="1.0.1"},{}],77:[function(require,module,exports){(function(global){(function(){var undefined;var arrayPool=[],objectPool=[];var idCounter=0;var keyPrefix=+new Date+"";var largeArraySize=75;var maxPoolSize=40;var whitespace=" \f \ufeff"+"\n\r\u2028\u2029"+" ᠎              ";var reEmptyStringLeading=/\b__p \+= '';/g,reEmptyStringMiddle=/\b(__p \+=) '' \+/g,reEmptyStringTrailing=/(__e\(.*?\)|\b__t\)) \+\n'';/g;var reEsTemplate=/\$\{([^\\}]*(?:\\.[^\\}]*)*)\}/g;var reFlags=/\w*$/;var reFuncName=/^\s*function[ \n\r\t]+\w/;var reInterpolate=/<%=([\s\S]+?)%>/g;var reLeadingSpacesAndZeros=RegExp("^["+whitespace+"]*0+(?=.$)");var reNoMatch=/($^)/;var reThis=/\bthis\b/;var reUnescapedString=/['\n\r\t\u2028\u2029\\]/g;var contextProps=["Array","Boolean","Date","Function","Math","Number","Object","RegExp","String","_","attachEvent","clearTimeout","isFinite","isNaN","parseInt","setTimeout"];var templateCounter=0;var argsClass="[object Arguments]",arrayClass="[object Array]",boolClass="[object Boolean]",dateClass="[object Date]",funcClass="[object Function]",numberClass="[object Number]",objectClass="[object Object]",regexpClass="[object RegExp]",stringClass="[object String]";var cloneableClasses={};cloneableClasses[funcClass]=false;cloneableClasses[argsClass]=cloneableClasses[arrayClass]=cloneableClasses[boolClass]=cloneableClasses[dateClass]=cloneableClasses[numberClass]=cloneableClasses[objectClass]=cloneableClasses[regexpClass]=cloneableClasses[stringClass]=true;var debounceOptions={leading:false,maxWait:0,trailing:false};var descriptor={configurable:false,enumerable:false,value:null,writable:false};var objectTypes={"boolean":false,"function":true,object:true,number:false,string:false,undefined:false};var stringEscapes={"\\":"\\","'":"'","\n":"n","\r":"r"," ":"t","\u2028":"u2028","\u2029":"u2029"};var root=objectTypes[typeof window]&&window||this;var freeExports=objectTypes[typeof exports]&&exports&&!exports.nodeType&&exports;var freeModule=objectTypes[typeof module]&&module&&!module.nodeType&&module;var moduleExports=freeModule&&freeModule.exports===freeExports&&freeExports;var freeGlobal=objectTypes[typeof global]&&global;if(freeGlobal&&(freeGlobal.global===freeGlobal||freeGlobal.window===freeGlobal)){root=freeGlobal}function baseIndexOf(array,value,fromIndex){var index=(fromIndex||0)-1,length=array?array.length:0;while(++index-1?0:-1:cache?0:-1}function cachePush(value){var cache=this.cache,type=typeof value;if(type=="boolean"||value==null){cache[value]=true}else{if(type!="number"&&type!="string"){type="object"}var key=type=="number"?value:keyPrefix+value,typeCache=cache[type]||(cache[type]={});if(type=="object"){(typeCache[key]||(typeCache[key]=[])).push(value)}else{typeCache[key]=true}}}function charAtCallback(value){return value.charCodeAt(0)}function compareAscending(a,b){var ac=a.criteria,bc=b.criteria,index=-1,length=ac.length;while(++indexother||typeof value=="undefined"){return 1}if(value/g,evaluate:/<%([\s\S]+?)%>/g,interpolate:reInterpolate,variable:"",imports:{_:lodash}};function baseBind(bindData){var func=bindData[0],partialArgs=bindData[2],thisArg=bindData[4];function bound(){if(partialArgs){var args=slice(partialArgs);push.apply(args,arguments)}if(this instanceof bound){var thisBinding=baseCreate(func.prototype),result=func.apply(thisBinding,args||arguments);return isObject(result)?result:thisBinding}return func.apply(thisArg,args||arguments)}setBindData(bound,bindData);return bound}function baseClone(value,isDeep,callback,stackA,stackB){if(callback){var result=callback(value);if(typeof result!="undefined"){return result}}var isObj=isObject(value);if(isObj){var className=toString.call(value);if(!cloneableClasses[className]){return value}var ctor=ctorByClass[className];switch(className){case boolClass:case dateClass:return new ctor(+value);case numberClass:case stringClass:return new ctor(value);case regexpClass:result=ctor(value.source,reFlags.exec(value));result.lastIndex=value.lastIndex;return result}}else{return value}var isArr=isArray(value);if(isDeep){var initedStack=!stackA;stackA||(stackA=getArray());stackB||(stackB=getArray());var length=stackA.length;while(length--){if(stackA[length]==value){return stackB[length]}}result=isArr?ctor(value.length):{}}else{result=isArr?slice(value):assign({},value)}if(isArr){if(hasOwnProperty.call(value,"index")){result.index=value.index}if(hasOwnProperty.call(value,"input")){result.input=value.input}}if(!isDeep){return result}stackA.push(value);stackB.push(result);(isArr?forEach:forOwn)(value,function(objValue,key){result[key]=baseClone(objValue,isDeep,callback,stackA,stackB)});if(initedStack){releaseArray(stackA);releaseArray(stackB)}return result}function baseCreate(prototype,properties){return isObject(prototype)?nativeCreate(prototype):{}}if(!nativeCreate){baseCreate=function(){function Object(){}return function(prototype){if(isObject(prototype)){Object.prototype=prototype;var result=new Object;Object.prototype=null}return result||context.Object()}}()}function baseCreateCallback(func,thisArg,argCount){if(typeof func!="function"){return identity}if(typeof thisArg=="undefined"||!("prototype"in func)){return func}var bindData=func.__bindData__;if(typeof bindData=="undefined"){if(support.funcNames){bindData=!func.name}bindData=bindData||!support.funcDecomp;if(!bindData){var source=fnToString.call(func);if(!support.funcNames){bindData=!reFuncName.test(source)}if(!bindData){bindData=reThis.test(source);setBindData(func,bindData)}}}if(bindData===false||bindData!==true&&bindData[1]&1){return func}switch(argCount){case 1:return function(value){return func.call(thisArg,value)};case 2:return function(a,b){return func.call(thisArg,a,b)};case 3:return function(value,index,collection){return func.call(thisArg,value,index,collection)};case 4:return function(accumulator,value,index,collection){return func.call(thisArg,accumulator,value,index,collection)}}return bind(func,thisArg)}function baseCreateWrapper(bindData){var func=bindData[0],bitmask=bindData[1],partialArgs=bindData[2],partialRightArgs=bindData[3],thisArg=bindData[4],arity=bindData[5];var isBind=bitmask&1,isBindKey=bitmask&2,isCurry=bitmask&4,isCurryBound=bitmask&8,key=func;function bound(){var thisBinding=isBind?thisArg:this;if(partialArgs){var args=slice(partialArgs);push.apply(args,arguments)}if(partialRightArgs||isCurry){args||(args=slice(arguments));if(partialRightArgs){push.apply(args,partialRightArgs)}if(isCurry&&args.length=largeArraySize&&indexOf===baseIndexOf,result=[];if(isLarge){var cache=createCache(values);if(cache){indexOf=cacheIndexOf;values=cache}else{isLarge=false}}while(++index-1}})}}stackA.pop();stackB.pop();if(initedStack){releaseArray(stackA);releaseArray(stackB)}return result}function baseMerge(object,source,callback,stackA,stackB){(isArray(source)?forEach:forOwn)(source,function(source,key){var found,isArr,result=source,value=object[key];if(source&&((isArr=isArray(source))||isPlainObject(source))){var stackLength=stackA.length;while(stackLength--){if(found=stackA[stackLength]==source){value=stackB[stackLength];break}}if(!found){var isShallow;if(callback){result=callback(value,source);if(isShallow=typeof result!="undefined"){value=result}}if(!isShallow){value=isArr?isArray(value)?value:[]:isPlainObject(value)?value:{}}stackA.push(source);stackB.push(value);if(!isShallow){baseMerge(value,source,callback,stackA,stackB)}}}else{if(callback){result=callback(value,source);if(typeof result=="undefined"){result=source}}if(typeof result!="undefined"){value=result}}object[key]=value})}function baseRandom(min,max){return min+floor(nativeRandom()*(max-min+1))}function baseUniq(array,isSorted,callback){var index=-1,indexOf=getIndexOf(),length=array?array.length:0,result=[];var isLarge=!isSorted&&length>=largeArraySize&&indexOf===baseIndexOf,seen=callback||isLarge?getArray():result;if(isLarge){var cache=createCache(seen);indexOf=cacheIndexOf;seen=cache}while(++index":">",'"':""","'":"'"};var htmlUnescapes=invert(htmlEscapes);var reEscapedHtml=RegExp("("+keys(htmlUnescapes).join("|")+")","g"),reUnescapedHtml=RegExp("["+keys(htmlEscapes).join("")+"]","g");var assign=function(object,source,guard){var index,iterable=object,result=iterable;if(!iterable)return result;var args=arguments,argsIndex=0,argsLength=typeof guard=="number"?2:args.length;if(argsLength>3&&typeof args[argsLength-2]=="function"){var callback=baseCreateCallback(args[--argsLength-1],args[argsLength--],2)}else if(argsLength>2&&typeof args[argsLength-1]=="function"){callback=args[--argsLength]}while(++argsIndex3&&typeof args[length-2]=="function"){var callback=baseCreateCallback(args[--length-1],args[length--],2)}else if(length>2&&typeof args[length-1]=="function"){callback=args[--length]}var sources=slice(arguments,1,length),index=-1,stackA=getArray(),stackB=getArray();while(++index-1}else if(typeof length=="number"){result=(isString(collection)?collection.indexOf(target,fromIndex):indexOf(collection,target,fromIndex))>-1}else{forOwn(collection,function(value){if(++index>=fromIndex){return!(result=value===target)}})}return result}var countBy=createAggregator(function(result,value,key){hasOwnProperty.call(result,key)?result[key]++:result[key]=1});function every(collection,callback,thisArg){var result=true;callback=lodash.createCallback(callback,thisArg,3);var index=-1,length=collection?collection.length:0;if(typeof length=="number"){while(++indexresult){result=value}}}else{callback=callback==null&&isString(collection)?charAtCallback:lodash.createCallback(callback,thisArg,3);forEach(collection,function(value,index,collection){var current=callback(value,index,collection);if(current>computed){computed=current;result=value}})}return result}function min(collection,callback,thisArg){var computed=Infinity,result=computed;if(typeof callback!="function"&&thisArg&&thisArg[callback]===collection){callback=null}if(callback==null&&isArray(collection)){var index=-1,length=collection.length;while(++index=largeArraySize&&createCache(argsIndex?args[argsIndex]:seen))}}var array=args[0],index=-1,length=array?array.length:0,result=[];outer:while(++index>>1;callback(array[mid])1?arguments:arguments[0],index=-1,length=array?max(pluck(array,"length")):0,result=Array(length<0?0:length);while(++index2?createWrapper(func,17,slice(arguments,2),null,thisArg):createWrapper(func,1,null,null,thisArg)}function bindAll(object){var funcs=arguments.length>1?baseFlatten(arguments,true,false,1):functions(object),index=-1,length=funcs.length;while(++index2?createWrapper(key,19,slice(arguments,2),null,object):createWrapper(key,3,null,null,object)}function compose(){var funcs=arguments,length=funcs.length;while(length--){if(!isFunction(funcs[length])){throw new TypeError}}return function(){var args=arguments,length=funcs.length;while(length--){args=[funcs[length].apply(this,args)]}return args[0]}}function curry(func,arity){arity=typeof arity=="number"?arity:+arity||func.length;return createWrapper(func,4,null,null,null,arity)}function debounce(func,wait,options){var args,maxTimeoutId,result,stamp,thisArg,timeoutId,trailingCall,lastCalled=0,maxWait=false,trailing=true;if(!isFunction(func)){throw new TypeError}wait=nativeMax(0,wait)||0;if(options===true){var leading=true;trailing=false}else if(isObject(options)){leading=options.leading;maxWait="maxWait"in options&&(nativeMax(wait,options.maxWait)||0);trailing="trailing"in options?options.trailing:trailing}var delayed=function(){var remaining=wait-(now()-stamp);if(remaining<=0){if(maxTimeoutId){clearTimeout(maxTimeoutId)}var isCalled=trailingCall;maxTimeoutId=timeoutId=trailingCall=undefined;if(isCalled){lastCalled=now();result=func.apply(thisArg,args);if(!timeoutId&&!maxTimeoutId){args=thisArg=null}}}else{timeoutId=setTimeout(delayed,remaining)}};var maxDelayed=function(){if(timeoutId){clearTimeout(timeoutId)}maxTimeoutId=timeoutId=trailingCall=undefined;if(trailing||maxWait!==wait){lastCalled=now();result=func.apply(thisArg,args);if(!timeoutId&&!maxTimeoutId){args=thisArg=null}}};return function(){args=arguments; + +stamp=now();thisArg=this;trailingCall=trailing&&(timeoutId||!leading);if(maxWait===false){var leadingCall=leading&&!timeoutId}else{if(!maxTimeoutId&&!leading){lastCalled=stamp}var remaining=maxWait-(stamp-lastCalled),isCalled=remaining<=0;if(isCalled){if(maxTimeoutId){maxTimeoutId=clearTimeout(maxTimeoutId)}lastCalled=stamp;result=func.apply(thisArg,args)}else if(!maxTimeoutId){maxTimeoutId=setTimeout(maxDelayed,remaining)}}if(isCalled&&timeoutId){timeoutId=clearTimeout(timeoutId)}else if(!timeoutId&&wait!==maxWait){timeoutId=setTimeout(delayed,wait)}if(leadingCall){isCalled=true;result=func.apply(thisArg,args)}if(isCalled&&!timeoutId&&!maxTimeoutId){args=thisArg=null}return result}}function defer(func){if(!isFunction(func)){throw new TypeError}var args=slice(arguments,1);return setTimeout(function(){func.apply(undefined,args)},1)}function delay(func,wait){if(!isFunction(func)){throw new TypeError}var args=slice(arguments,2);return setTimeout(function(){func.apply(undefined,args)},wait)}function memoize(func,resolver){if(!isFunction(func)){throw new TypeError}var memoized=function(){var cache=memoized.cache,key=resolver?resolver.apply(this,arguments):keyPrefix+arguments[0];return hasOwnProperty.call(cache,key)?cache[key]:cache[key]=func.apply(this,arguments)};memoized.cache={};return memoized}function once(func){var ran,result;if(!isFunction(func)){throw new TypeError}return function(){if(ran){return result}ran=true;result=func.apply(this,arguments);func=null;return result}}function partial(func){return createWrapper(func,16,slice(arguments,1))}function partialRight(func){return createWrapper(func,32,null,slice(arguments,1))}function throttle(func,wait,options){var leading=true,trailing=true;if(!isFunction(func)){throw new TypeError}if(options===false){leading=false}else if(isObject(options)){leading="leading"in options?options.leading:leading;trailing="trailing"in options?options.trailing:trailing}debounceOptions.leading=leading;debounceOptions.maxWait=wait;debounceOptions.trailing=trailing;return debounce(func,wait,debounceOptions)}function wrap(value,wrapper){return createWrapper(wrapper,16,[value])}function constant(value){return function(){return value}}function createCallback(func,thisArg,argCount){var type=typeof func;if(func==null||type=="function"){return baseCreateCallback(func,thisArg,argCount)}if(type!="object"){return property(func)}var props=keys(func),key=props[0],a=func[key];if(props.length==1&&a===a&&!isObject(a)){return function(object){var b=object[key];return a===b&&(a!==0||1/a==1/b)}}return function(object){var length=props.length,result=false;while(length--){if(!(result=baseIsEqual(object[props[length]],func[props[length]],null,true))){break}}return result}}function escape(string){return string==null?"":String(string).replace(reUnescapedHtml,escapeHtmlChar)}function identity(value){return value}function mixin(object,source,options){var chain=true,methodNames=source&&functions(source);if(!source||!options&&!methodNames.length){if(options==null){options=source}ctor=lodashWrapper;source=object;object=lodash;methodNames=functions(source)}if(options===false){chain=false}else if(isObject(options)&&"chain"in options){chain=options.chain}var ctor=object,isFunc=isFunction(ctor);forEach(methodNames,function(methodName){var func=object[methodName]=source[methodName];if(isFunc){ctor.prototype[methodName]=function(){var chainAll=this.__chain__,value=this.__wrapped__,args=[value];push.apply(args,arguments);var result=func.apply(object,args);if(chain||chainAll){if(value===result&&isObject(result)){return this}result=new ctor(result);result.__chain__=chainAll}return result}}})}function noConflict(){context._=oldDash;return this}function noop(){}var now=isNative(now=Date.now)&&now||function(){return(new Date).getTime()};var parseInt=nativeParseInt(whitespace+"08")==8?nativeParseInt:function(value,radix){return nativeParseInt(isString(value)?value.replace(reLeadingSpacesAndZeros,""):value,radix||0)};function property(key){return function(object){return object[key]}}function random(min,max,floating){var noMin=min==null,noMax=max==null;if(floating==null){if(typeof min=="boolean"&&noMax){floating=min;min=1}else if(!noMax&&typeof max=="boolean"){floating=max;noMax=true}}if(noMin&&noMax){max=1}min=+min||0;if(noMax){max=min;min=0}else{max=+max||0}if(floating||min%1||max%1){var rand=nativeRandom();return nativeMin(min+rand*(max-min+parseFloat("1e-"+((rand+"").length-1))),max)}return baseRandom(min,max)}function result(object,key){if(object){var value=object[key];return isFunction(value)?object[key]():value}}function template(text,data,options){var settings=lodash.templateSettings;text=String(text||"");options=defaults({},options,settings);var imports=defaults({},options.imports,settings.imports),importsKeys=keys(imports),importsValues=values(imports);var isEvaluating,index=0,interpolate=options.interpolate||reNoMatch,source="__p += '";var reDelimiters=RegExp((options.escape||reNoMatch).source+"|"+interpolate.source+"|"+(interpolate===reInterpolate?reEsTemplate:reNoMatch).source+"|"+(options.evaluate||reNoMatch).source+"|$","g");text.replace(reDelimiters,function(match,escapeValue,interpolateValue,esTemplateValue,evaluateValue,offset){interpolateValue||(interpolateValue=esTemplateValue);source+=text.slice(index,offset).replace(reUnescapedString,escapeStringChar);if(escapeValue){source+="' +\n__e("+escapeValue+") +\n'"}if(evaluateValue){isEvaluating=true;source+="';\n"+evaluateValue+";\n__p += '"}if(interpolateValue){source+="' +\n((__t = ("+interpolateValue+")) == null ? '' : __t) +\n'"}index=offset+match.length;return match});source+="';\n";var variable=options.variable,hasVariable=variable;if(!hasVariable){variable="obj";source="with ("+variable+") {\n"+source+"\n}\n"}source=(isEvaluating?source.replace(reEmptyStringLeading,""):source).replace(reEmptyStringMiddle,"$1").replace(reEmptyStringTrailing,"$1;");source="function("+variable+") {\n"+(hasVariable?"":variable+" || ("+variable+" = {});\n")+"var __t, __p = '', __e = _.escape"+(isEvaluating?", __j = Array.prototype.join;\n"+"function print() { __p += __j.call(arguments, '') }\n":";\n")+source+"return __p\n}";var sourceURL="\n/*\n//# sourceURL="+(options.sourceURL||"/lodash/template/source["+templateCounter++ +"]")+"\n*/";try{var result=Function(importsKeys,"return "+source+sourceURL).apply(undefined,importsValues)}catch(e){e.source=source;throw e}if(data){return result(data)}result.source=source;return result}function times(n,callback,thisArg){n=(n=+n)>-1?n:0;var index=-1,result=Array(n);callback=baseCreateCallback(callback,thisArg,1);while(++index1?arguments[1]:{},peg$FAILED={},peg$startRuleFunctions={start:peg$parsestart,graphStmt:peg$parsegraphStmt},peg$startRuleFunction=peg$parsestart,peg$c0=[],peg$c1=peg$FAILED,peg$c2=null,peg$c3="{",peg$c4={type:"literal",value:"{",description:'"{"'},peg$c5="}",peg$c6={type:"literal",value:"}",description:'"}"'},peg$c7=function(strict,type,id,stmts){return{type:type,id:id,strict:strict!==null,stmts:stmts}},peg$c8=";",peg$c9={type:"literal",value:";",description:'";"'},peg$c10=function(first,rest){var result=[first];for(var i=0;i",description:'"->"'},peg$c33=function(rhs,rest){var result=[rhs];if(rest){for(var i=0;ipos){peg$cachedPos=0;peg$cachedPosDetails={line:1,column:1,seenCR:false}}advance(peg$cachedPosDetails,peg$cachedPos,pos);peg$cachedPos=pos}return peg$cachedPosDetails}function peg$fail(expected){if(peg$currPospeg$maxFailPos){peg$maxFailPos=peg$currPos;peg$maxFailExpected=[]}peg$maxFailExpected.push(expected)}function peg$buildException(message,expected,pos){function cleanupExpected(expected){var i=1;expected.sort(function(a,b){if(a.descriptionb.description){return 1}else{return 0}});while(i1?expectedDescs.slice(0,-1).join(", ")+" or "+expectedDescs[expected.length-1]:expectedDescs[0];foundDesc=found?'"'+stringEscape(found)+'"':"end of input";return"Expected "+expectedDesc+" but "+foundDesc+" found."}var posDetails=peg$computePosDetails(pos),found=pospeg$currPos){s5=input.charAt(peg$currPos);peg$currPos++}else{s5=peg$FAILED;if(peg$silentFails===0){peg$fail(peg$c110)}}if(s5!==peg$FAILED){s4=[s4,s5];s3=s4}else{peg$currPos=s3;s3=peg$c1}}else{peg$currPos=s3;s3=peg$c1}while(s3!==peg$FAILED){s2.push(s3);s3=peg$currPos;s4=peg$currPos;peg$silentFails++;if(input.substr(peg$currPos,2)===peg$c108){s5=peg$c108;peg$currPos+=2}else{s5=peg$FAILED;if(peg$silentFails===0){peg$fail(peg$c109)}}peg$silentFails--;if(s5===peg$FAILED){s4=peg$c30}else{peg$currPos=s4;s4=peg$c1}if(s4!==peg$FAILED){if(input.length>peg$currPos){s5=input.charAt(peg$currPos);peg$currPos++}else{s5=peg$FAILED;if(peg$silentFails===0){peg$fail(peg$c110)}}if(s5!==peg$FAILED){s4=[s4,s5];s3=s4}else{peg$currPos=s3;s3=peg$c1}}else{peg$currPos=s3;s3=peg$c1}}if(s2!==peg$FAILED){if(input.substr(peg$currPos,2)===peg$c108){s3=peg$c108;peg$currPos+=2}else{s3=peg$FAILED;if(peg$silentFails===0){peg$fail(peg$c109)}}if(s3!==peg$FAILED){s1=[s1,s2,s3];s0=s1}else{peg$currPos=s0;s0=peg$c1}}else{peg$currPos=s0;s0=peg$c1}}else{peg$currPos=s0;s0=peg$c1}}peg$silentFails--;if(s0===peg$FAILED){s1=peg$FAILED;if(peg$silentFails===0){peg$fail(peg$c101)}}return s0}function peg$parse_(){var s0;s0=peg$parsewhitespace();if(s0===peg$FAILED){s0=peg$parsecomment()}return s0}var _=require("lodash");var directed;peg$result=peg$startRuleFunction();if(peg$result!==peg$FAILED&&peg$currPos===input.length){return peg$result}else{if(peg$result!==peg$FAILED&&peg$currPos":"--",writer=new Writer;if(!g.isMultigraph()){writer.write("strict ")}writer.writeLine((g.isDirected()?"digraph":"graph")+" {");writer.indent();var graphAttrs=g.graph();if(_.isObject(graphAttrs)){_.each(graphAttrs,function(v,k){writer.writeLine(id(k)+"="+id(v)+";")})}writeSubgraph(g,undefined,writer);g.edges().forEach(function(edge){writeEdge(g,edge,ec,writer)});writer.unindent();writer.writeLine("}");return writer.toString()}function writeSubgraph(g,v,writer){var children=g.isCompound()?g.children(v):g.nodes();_.each(children,function(w){if(!g.isCompound()||!g.children(w).length){writeNode(g,w,writer)}else{writer.writeLine("subgraph "+id(w)+" {");writer.indent();if(_.isObject(g.node(w))){_.map(g.node(w),function(val,key){writer.writeLine(id(key)+"="+id(val)+";")})}writeSubgraph(g,w,writer);writer.unindent();writer.writeLine("}")}})}function writeNode(g,v,writer){writer.write(id(v));writeAttrs(g.node(v),writer);writer.writeLine()}function writeEdge(g,edge,ec,writer){var v=edge.v,w=edge.w,attrs=g.edge(edge);writer.write(id(v)+" "+ec+" "+id(w));writeAttrs(attrs,writer);writer.writeLine()}function writeAttrs(attrs,writer){if(_.isObject(attrs)){var attrStrs=_.map(attrs,function(val,key){return id(key)+"="+id(val)});if(attrStrs.length){writer.write(" ["+attrStrs.join(",")+"]")}}}function id(obj){if(typeof obj==="number"||obj.toString().match(UNESCAPED_ID_PATTERN)){return obj}return'"'+obj.toString().replace(/"/g,'\\"')+'"'}function Writer(){this._indent="";this._content="";this._shouldIndent=true}Writer.prototype.INDENT=" ";Writer.prototype.indent=function(){this._indent+=this.INDENT};Writer.prototype.unindent=function(){this._indent=this._indent.slice(this.INDENT.length)};Writer.prototype.writeLine=function(line){this.write((line||"")+"\n");this._shouldIndent=true};Writer.prototype.write=function(str){if(this._shouldIndent){this._shouldIndent=false;this._content+=this._indent}this._content+=str};Writer.prototype.toString=function(){return this._content}},{lodash:28}],9:[function(require,module,exports){var _=require("lodash");module.exports=_.clone(require("./lib"));module.exports.json=require("./lib/json");module.exports.alg=require("./lib/alg")},{"./lib":25,"./lib/alg":16,"./lib/json":26,lodash:28}],10:[function(require,module,exports){var _=require("lodash");module.exports=components;function components(g){var visited={},cmpts=[],cmpt;function dfs(v){if(_.has(visited,v))return;visited[v]=true;cmpt.push(v);_.each(g.successors(v),dfs);_.each(g.predecessors(v),dfs)}_.each(g.nodes(),function(v){cmpt=[];dfs(v);if(cmpt.length){cmpts.push(cmpt)}});return cmpts}},{lodash:28}],11:[function(require,module,exports){var _=require("lodash");module.exports=dfs;function dfs(g,vs,order){if(!_.isArray(vs)){vs=[vs]}var acc=[],visited={};_.each(vs,function(v){if(!g.hasNode(v)){throw new Error("Graph does not have node: "+v)}doDfs(g,v,order==="post",visited,acc)});return acc}function doDfs(g,v,postorder,visited,acc){if(!_.has(visited,v)){visited[v]=true;if(!postorder){acc.push(v)}_.each(g.neighbors(v),function(w){doDfs(g,w,postorder,visited,acc)});if(postorder){acc.push(v)}}}},{lodash:28}],12:[function(require,module,exports){var dijkstra=require("./dijkstra"),_=require("lodash");module.exports=dijkstraAll;function dijkstraAll(g,weightFunc,edgeFunc){return _.transform(g.nodes(),function(acc,v){acc[v]=dijkstra(g,v,weightFunc,edgeFunc)},{})}},{"./dijkstra":13,lodash:28}],13:[function(require,module,exports){var _=require("lodash"),PriorityQueue=require("../data/priority-queue");module.exports=dijkstra;var DEFAULT_WEIGHT_FUNC=_.constant(1);function dijkstra(g,source,weightFn,edgeFn){return runDijkstra(g,String(source),weightFn||DEFAULT_WEIGHT_FUNC,edgeFn||function(v){return g.outEdges(v)})}function runDijkstra(g,source,weightFn,edgeFn){var results={},pq=new PriorityQueue,v,vEntry;var updateNeighbors=function(edge){var w=edge.v!==v?edge.v:edge.w,wEntry=results[w],weight=weightFn(edge),distance=vEntry.distance+weight;if(weight<0){throw new Error("dijkstra does not allow negative edge weights. "+"Bad edge: "+edge+" Weight: "+weight)}if(distance0){v=pq.removeMin();vEntry=results[v];if(vEntry.distance===Number.POSITIVE_INFINITY){break}edgeFn(v).forEach(updateNeighbors)}return results}},{"../data/priority-queue":23,lodash:28}],14:[function(require,module,exports){var _=require("lodash"),tarjan=require("./tarjan");module.exports=findCycles;function findCycles(g){return _.filter(tarjan(g),function(cmpt){return cmpt.length>1})}},{"./tarjan":21,lodash:28}],15:[function(require,module,exports){var _=require("lodash");module.exports=floydWarshall;var DEFAULT_WEIGHT_FUNC=_.constant(1);function floydWarshall(g,weightFn,edgeFn){return runFloydWarshall(g,weightFn||DEFAULT_WEIGHT_FUNC,edgeFn||function(v){return g.outEdges(v)})}function runFloydWarshall(g,weightFn,edgeFn){var results={},nodes=g.nodes();nodes.forEach(function(v){results[v]={};results[v][v]={distance:0};nodes.forEach(function(w){if(v!==w){results[v][w]={distance:Number.POSITIVE_INFINITY}}});edgeFn(v).forEach(function(edge){var w=edge.v===v?edge.w:edge.v,d=weightFn(edge);results[v][w]={distance:d,predecessor:v}})});nodes.forEach(function(k){var rowK=results[k];nodes.forEach(function(i){var rowI=results[i];nodes.forEach(function(j){var ik=rowI[k];var kj=rowK[j];var ij=rowI[j];var altDistance=ik.distance+kj.distance;if(altDistance0){v=pq.removeMin();if(_.has(parents,v)){result.setEdge(v,parents[v])}else if(init){throw new Error("Input graph is not connected: "+g)}else{init=true}g.nodeEdges(v).forEach(updateNeighbors)}return result}},{"../data/priority-queue":23,"../graph":24,lodash:28}],21:[function(require,module,exports){var _=require("lodash");module.exports=tarjan;function tarjan(g){var index=0,stack=[],visited={},results=[];function dfs(v){var entry=visited[v]={onStack:true,lowlink:index,index:index++};stack.push(v);g.successors(v).forEach(function(w){if(!_.has(visited,w)){dfs(w);entry.lowlink=Math.min(entry.lowlink,visited[w].lowlink)}else if(visited[w].onStack){entry.lowlink=Math.min(entry.lowlink,visited[w].index)}});if(entry.lowlink===entry.index){var cmpt=[],w;do{w=stack.pop();visited[w].onStack=false;cmpt.push(w)}while(v!==w);results.push(cmpt)}}g.nodes().forEach(function(v){if(!_.has(visited,v)){dfs(v)}});return results}},{lodash:28}],22:[function(require,module,exports){var _=require("lodash");module.exports=topsort;topsort.CycleException=CycleException;function topsort(g){var visited={},stack={},results=[];function visit(node){if(_.has(stack,node)){throw new CycleException}if(!_.has(visited,node)){stack[node]=true;visited[node]=true;_.each(g.predecessors(node),visit);delete stack[node];results.push(node)}}_.each(g.sinks(),visit);if(_.size(visited)!==g.nodeCount()){throw new CycleException}return results}function CycleException(){}},{lodash:28}],23:[function(require,module,exports){var _=require("lodash");module.exports=PriorityQueue;function PriorityQueue(){this._arr=[];this._keyIndices={}}PriorityQueue.prototype.size=function(){return this._arr.length};PriorityQueue.prototype.keys=function(){return this._arr.map(function(x){return x.key})};PriorityQueue.prototype.has=function(key){return _.has(this._keyIndices,key)};PriorityQueue.prototype.priority=function(key){var index=this._keyIndices[key];if(index!==undefined){return this._arr[index].priority}};PriorityQueue.prototype.min=function(){if(this.size()===0){throw new Error("Queue underflow")}return this._arr[0].key};PriorityQueue.prototype.add=function(key,priority){var keyIndices=this._keyIndices;key=String(key);if(!_.has(keyIndices,key)){var arr=this._arr;var index=arr.length;keyIndices[key]=index;arr.push({key:key,priority:priority});this._decrease(index);return true}return false};PriorityQueue.prototype.removeMin=function(){this._swap(0,this._arr.length-1);var min=this._arr.pop();delete this._keyIndices[min.key];this._heapify(0);return min.key};PriorityQueue.prototype.decrease=function(key,priority){var index=this._keyIndices[key];if(priority>this._arr[index].priority){throw new Error("New priority is greater than current priority. "+"Key: "+key+" Old: "+this._arr[index].priority+" New: "+priority)}this._arr[index].priority=priority;this._decrease(index)};PriorityQueue.prototype._heapify=function(i){var arr=this._arr;var l=2*i,r=l+1,largest=i;if(l>1;if(arr[parent].priority1){this.setNode(v,value)}else{this.setNode(v)}},this);return this};Graph.prototype.setNode=function(v,value){if(_.has(this._nodes,v)){if(arguments.length>1){this._nodes[v]=value}return this}this._nodes[v]=arguments.length>1?value:this._defaultNodeLabelFn(v);if(this._isCompound){this._parent[v]=GRAPH_NODE;this._children[v]={};this._children[GRAPH_NODE][v]=true}this._in[v]={};this._preds[v]={};this._out[v]={};this._sucs[v]={};++this._nodeCount;return this};Graph.prototype.node=function(v){return this._nodes[v]};Graph.prototype.hasNode=function(v){return _.has(this._nodes,v)};Graph.prototype.removeNode=function(v){var self=this;if(_.has(this._nodes,v)){var removeEdge=function(e){self.removeEdge(self._edgeObjs[e])};delete this._nodes[v];if(this._isCompound){this._removeFromParentsChildList(v);delete this._parent[v];_.each(this.children(v),function(child){this.setParent(child)},this);delete this._children[v]}_.each(_.keys(this._in[v]),removeEdge);delete this._in[v];delete this._preds[v];_.each(_.keys(this._out[v]),removeEdge);delete this._out[v];delete this._sucs[v];--this._nodeCount}return this};Graph.prototype.setParent=function(v,parent){if(!this._isCompound){throw new Error("Cannot set parent in a non-compound graph")}if(_.isUndefined(parent)){parent=GRAPH_NODE}else{for(var ancestor=parent;!_.isUndefined(ancestor);ancestor=this.parent(ancestor)){if(ancestor===v){throw new Error("Setting "+parent+" as parent of "+v+" would create create a cycle")}}this.setNode(parent)}this.setNode(v);this._removeFromParentsChildList(v);this._parent[v]=parent;this._children[parent][v]=true;return this};Graph.prototype._removeFromParentsChildList=function(v){delete this._children[this._parent[v]][v]};Graph.prototype.parent=function(v){if(this._isCompound){var parent=this._parent[v];if(parent!==GRAPH_NODE){return parent}}};Graph.prototype.children=function(v){if(_.isUndefined(v)){v=GRAPH_NODE}if(this._isCompound){var children=this._children[v];if(children){return _.keys(children)}}else if(v===GRAPH_NODE){return this.nodes()}else if(this.hasNode(v)){return[]}};Graph.prototype.predecessors=function(v){var predsV=this._preds[v];if(predsV){return _.keys(predsV)}};Graph.prototype.successors=function(v){var sucsV=this._sucs[v];if(sucsV){return _.keys(sucsV)}};Graph.prototype.neighbors=function(v){var preds=this.predecessors(v);if(preds){return _.union(preds,this.successors(v))}};Graph.prototype.setDefaultEdgeLabel=function(newDefault){if(!_.isFunction(newDefault)){newDefault=_.constant(newDefault)}this._defaultEdgeLabelFn=newDefault;return this};Graph.prototype.edgeCount=function(){return this._edgeCount};Graph.prototype.edges=function(){return _.values(this._edgeObjs)};Graph.prototype.setPath=function(vs,value){var self=this,args=arguments;_.reduce(vs,function(v,w){if(args.length>1){self.setEdge(v,w,value)}else{self.setEdge(v,w)}return w});return this};Graph.prototype.setEdge=function(v,w,value,name){var valueSpecified=arguments.length>2;if(_.isPlainObject(arguments[0])){v=arguments[0].v;w=arguments[0].w;name=arguments[0].name;if(arguments.length===2){value=arguments[1];valueSpecified=true}}var e=edgeArgsToId(this._isDirected,v,w,name);if(_.has(this._edgeLabels,e)){if(valueSpecified){this._edgeLabels[e]=value}return this}if(!_.isUndefined(name)&&!this._isMultigraph){throw new Error("Cannot set a named edge when isMultigraph = false")}this.setNode(v);this.setNode(w);this._edgeLabels[e]=valueSpecified?value:this._defaultEdgeLabelFn(v,w,name);var edgeObj=edgeArgsToObj(this._isDirected,v,w,name);v=edgeObj.v;w=edgeObj.w;Object.freeze(edgeObj);this._edgeObjs[e]=edgeObj;incrementOrInitEntry(this._preds[w],v);incrementOrInitEntry(this._sucs[v],w);this._in[w][e]=edgeObj;this._out[v][e]=edgeObj;this._edgeCount++;return this};Graph.prototype.edge=function(v,w,name){var e=arguments.length===1?edgeObjToId(this._isDirected,arguments[0]):edgeArgsToId(this._isDirected,v,w,name);return this._edgeLabels[e]};Graph.prototype.hasEdge=function(v,w,name){var e=arguments.length===1?edgeObjToId(this._isDirected,arguments[0]):edgeArgsToId(this._isDirected,v,w,name);return _.has(this._edgeLabels,e)};Graph.prototype.removeEdge=function(v,w,name){var e=arguments.length===1?edgeObjToId(this._isDirected,arguments[0]):edgeArgsToId(this._isDirected,v,w,name),edge=this._edgeObjs[e];if(edge){v=edge.v;w=edge.w;delete this._edgeLabels[e];delete this._edgeObjs[e];decrementOrRemoveEntry(this._preds[w],v);decrementOrRemoveEntry(this._sucs[v],w);delete this._in[w][e];delete this._out[v][e];this._edgeCount--}return this};Graph.prototype.inEdges=function(v,u){var inV=this._in[v];if(inV){var edges=_.values(inV);if(!u){return edges}return _.filter(edges,function(edge){return edge.v===u})}};Graph.prototype.outEdges=function(v,w){var outV=this._out[v];if(outV){var edges=_.values(outV);if(!w){return edges}return _.filter(edges,function(edge){return edge.w===w})}};Graph.prototype.nodeEdges=function(v,w){var inEdges=this.inEdges(v,w);if(inEdges){return inEdges.concat(this.outEdges(v,w))}};function incrementOrInitEntry(map,k){if(_.has(map,k)){map[k]++}else{map[k]=1}}function decrementOrRemoveEntry(map,k){if(!--map[k]){delete map[k]}}function edgeArgsToId(isDirected,v,w,name){if(!isDirected&&v>w){var tmp=v;v=w;w=tmp}return v+EDGE_KEY_DELIM+w+EDGE_KEY_DELIM+(_.isUndefined(name)?DEFAULT_EDGE_NAME:name)}function edgeArgsToObj(isDirected,v,w,name){if(!isDirected&&v>w){var tmp=v;v=w;w=tmp}var edgeObj={v:v,w:w};if(name){edgeObj.name=name}return edgeObj}function edgeObjToId(isDirected,edgeObj){return edgeArgsToId(isDirected,edgeObj.v,edgeObj.w,edgeObj.name)}},{lodash:28}],25:[function(require,module,exports){module.exports={Graph:require("./graph"),version:require("./version")}},{"./graph":24,"./version":27}],26:[function(require,module,exports){var _=require("lodash"),Graph=require("./graph");module.exports={write:write,read:read};function write(g){var json={options:{directed:g.isDirected(),multigraph:g.isMultigraph(),compound:g.isCompound()},nodes:writeNodes(g),edges:writeEdges(g)};if(!_.isUndefined(g.graph())){json.value=_.clone(g.graph())}return json}function writeNodes(g){return _.map(g.nodes(),function(v){var nodeValue=g.node(v),parent=g.parent(v),node={v:v};if(!_.isUndefined(nodeValue)){node.value=nodeValue}if(!_.isUndefined(parent)){node.parent=parent}return node})}function writeEdges(g){return _.map(g.edges(),function(e){var edgeValue=g.edge(e),edge={v:e.v,w:e.w};if(!_.isUndefined(e.name)){edge.name=e.name}if(!_.isUndefined(edgeValue)){edge.value=edgeValue}return edge})}function read(json){var g=new Graph(json.options).setGraph(json.value);_.each(json.nodes,function(entry){g.setNode(entry.v,entry.value);if(entry.parent){g.setParent(entry.v,entry.parent)}});_.each(json.edges,function(entry){g.setEdge({v:entry.v,w:entry.w,name:entry.name},entry.value)});return g}},{"./graph":24,lodash:28}],27:[function(require,module,exports){module.exports="0.8.1"},{}],28:[function(require,module,exports){(function(global){(function(){var undefined;var arrayPool=[],objectPool=[];var idCounter=0;var keyPrefix=+new Date+"";var largeArraySize=75;var maxPoolSize=40;var whitespace=" \f "+"\n\r\u2028\u2029"+" ᠎              ";var reEmptyStringLeading=/\b__p \+= '';/g,reEmptyStringMiddle=/\b(__p \+=) '' \+/g,reEmptyStringTrailing=/(__e\(.*?\)|\b__t\)) \+\n'';/g;var reEsTemplate=/\$\{([^\\}]*(?:\\.[^\\}]*)*)\}/g;var reFlags=/\w*$/;var reFuncName=/^\s*function[ \n\r\t]+\w/;var reInterpolate=/<%=([\s\S]+?)%>/g;var reLeadingSpacesAndZeros=RegExp("^["+whitespace+"]*0+(?=.$)");var reNoMatch=/($^)/;var reThis=/\bthis\b/;var reUnescapedString=/['\n\r\t\u2028\u2029\\]/g;var contextProps=["Array","Boolean","Date","Function","Math","Number","Object","RegExp","String","_","attachEvent","clearTimeout","isFinite","isNaN","parseInt","setTimeout"];var templateCounter=0;var argsClass="[object Arguments]",arrayClass="[object Array]",boolClass="[object Boolean]",dateClass="[object Date]",funcClass="[object Function]",numberClass="[object Number]",objectClass="[object Object]",regexpClass="[object RegExp]",stringClass="[object String]";var cloneableClasses={};cloneableClasses[funcClass]=false;cloneableClasses[argsClass]=cloneableClasses[arrayClass]=cloneableClasses[boolClass]=cloneableClasses[dateClass]=cloneableClasses[numberClass]=cloneableClasses[objectClass]=cloneableClasses[regexpClass]=cloneableClasses[stringClass]=true;var debounceOptions={leading:false,maxWait:0,trailing:false};var descriptor={configurable:false,enumerable:false,value:null,writable:false};var objectTypes={"boolean":false,"function":true,object:true,number:false,string:false,undefined:false};var stringEscapes={"\\":"\\","'":"'","\n":"n","\r":"r"," ":"t","\u2028":"u2028","\u2029":"u2029"};var root=objectTypes[typeof window]&&window||this;var freeExports=objectTypes[typeof exports]&&exports&&!exports.nodeType&&exports;var freeModule=objectTypes[typeof module]&&module&&!module.nodeType&&module;var moduleExports=freeModule&&freeModule.exports===freeExports&&freeExports;var freeGlobal=objectTypes[typeof global]&&global;if(freeGlobal&&(freeGlobal.global===freeGlobal||freeGlobal.window===freeGlobal)){root=freeGlobal}function baseIndexOf(array,value,fromIndex){var index=(fromIndex||0)-1,length=array?array.length:0;while(++index-1?0:-1:cache?0:-1}function cachePush(value){var cache=this.cache,type=typeof value;if(type=="boolean"||value==null){cache[value]=true}else{if(type!="number"&&type!="string"){type="object"}var key=type=="number"?value:keyPrefix+value,typeCache=cache[type]||(cache[type]={});if(type=="object"){(typeCache[key]||(typeCache[key]=[])).push(value)}else{typeCache[key]=true}}}function charAtCallback(value){return value.charCodeAt(0)}function compareAscending(a,b){var ac=a.criteria,bc=b.criteria,index=-1,length=ac.length;while(++indexother||typeof value=="undefined"){return 1}if(value/g,evaluate:/<%([\s\S]+?)%>/g,interpolate:reInterpolate,variable:"",imports:{_:lodash}};function baseBind(bindData){var func=bindData[0],partialArgs=bindData[2],thisArg=bindData[4];function bound(){if(partialArgs){var args=slice(partialArgs);push.apply(args,arguments)}if(this instanceof bound){var thisBinding=baseCreate(func.prototype),result=func.apply(thisBinding,args||arguments);return isObject(result)?result:thisBinding}return func.apply(thisArg,args||arguments)}setBindData(bound,bindData);return bound}function baseClone(value,isDeep,callback,stackA,stackB){if(callback){var result=callback(value);if(typeof result!="undefined"){return result}}var isObj=isObject(value);if(isObj){var className=toString.call(value);if(!cloneableClasses[className]){return value}var ctor=ctorByClass[className];switch(className){case boolClass:case dateClass:return new ctor(+value);case numberClass:case stringClass:return new ctor(value);case regexpClass:result=ctor(value.source,reFlags.exec(value));result.lastIndex=value.lastIndex;return result}}else{return value}var isArr=isArray(value);if(isDeep){var initedStack=!stackA;stackA||(stackA=getArray());stackB||(stackB=getArray());var length=stackA.length;while(length--){if(stackA[length]==value){return stackB[length]}}result=isArr?ctor(value.length):{}}else{result=isArr?slice(value):assign({},value)}if(isArr){if(hasOwnProperty.call(value,"index")){result.index=value.index}if(hasOwnProperty.call(value,"input")){result.input=value.input}}if(!isDeep){return result}stackA.push(value);stackB.push(result);(isArr?forEach:forOwn)(value,function(objValue,key){result[key]=baseClone(objValue,isDeep,callback,stackA,stackB)});if(initedStack){releaseArray(stackA);releaseArray(stackB)}return result}function baseCreate(prototype,properties){return isObject(prototype)?nativeCreate(prototype):{}}if(!nativeCreate){baseCreate=function(){function Object(){}return function(prototype){if(isObject(prototype)){Object.prototype=prototype;var result=new Object;Object.prototype=null}return result||context.Object()}}()}function baseCreateCallback(func,thisArg,argCount){if(typeof func!="function"){return identity}if(typeof thisArg=="undefined"||!("prototype"in func)){return func}var bindData=func.__bindData__;if(typeof bindData=="undefined"){if(support.funcNames){bindData=!func.name}bindData=bindData||!support.funcDecomp;if(!bindData){var source=fnToString.call(func);if(!support.funcNames){bindData=!reFuncName.test(source)}if(!bindData){bindData=reThis.test(source);setBindData(func,bindData)}}}if(bindData===false||bindData!==true&&bindData[1]&1){return func}switch(argCount){case 1:return function(value){return func.call(thisArg,value)};case 2:return function(a,b){return func.call(thisArg,a,b)};case 3:return function(value,index,collection){return func.call(thisArg,value,index,collection)};case 4:return function(accumulator,value,index,collection){return func.call(thisArg,accumulator,value,index,collection)}}return bind(func,thisArg)}function baseCreateWrapper(bindData){var func=bindData[0],bitmask=bindData[1],partialArgs=bindData[2],partialRightArgs=bindData[3],thisArg=bindData[4],arity=bindData[5];var isBind=bitmask&1,isBindKey=bitmask&2,isCurry=bitmask&4,isCurryBound=bitmask&8,key=func;function bound(){var thisBinding=isBind?thisArg:this;if(partialArgs){var args=slice(partialArgs);push.apply(args,arguments)}if(partialRightArgs||isCurry){args||(args=slice(arguments));if(partialRightArgs){push.apply(args,partialRightArgs)}if(isCurry&&args.length=largeArraySize&&indexOf===baseIndexOf,result=[];if(isLarge){var cache=createCache(values);if(cache){indexOf=cacheIndexOf;values=cache}else{isLarge=false}}while(++index-1}})}}stackA.pop();stackB.pop();if(initedStack){releaseArray(stackA);releaseArray(stackB)}return result}function baseMerge(object,source,callback,stackA,stackB){(isArray(source)?forEach:forOwn)(source,function(source,key){var found,isArr,result=source,value=object[key];if(source&&((isArr=isArray(source))||isPlainObject(source))){var stackLength=stackA.length;while(stackLength--){if(found=stackA[stackLength]==source){value=stackB[stackLength];break}}if(!found){var isShallow;if(callback){result=callback(value,source);if(isShallow=typeof result!="undefined"){value=result}}if(!isShallow){value=isArr?isArray(value)?value:[]:isPlainObject(value)?value:{}}stackA.push(source);stackB.push(value);if(!isShallow){baseMerge(value,source,callback,stackA,stackB)}}}else{if(callback){result=callback(value,source);if(typeof result=="undefined"){result=source}}if(typeof result!="undefined"){value=result}}object[key]=value})}function baseRandom(min,max){return min+floor(nativeRandom()*(max-min+1))}function baseUniq(array,isSorted,callback){var index=-1,indexOf=getIndexOf(),length=array?array.length:0,result=[];var isLarge=!isSorted&&length>=largeArraySize&&indexOf===baseIndexOf,seen=callback||isLarge?getArray():result;if(isLarge){var cache=createCache(seen);indexOf=cacheIndexOf;seen=cache}while(++index":">",'"':""","'":"'"};var htmlUnescapes=invert(htmlEscapes);var reEscapedHtml=RegExp("("+keys(htmlUnescapes).join("|")+")","g"),reUnescapedHtml=RegExp("["+keys(htmlEscapes).join("")+"]","g");var assign=function(object,source,guard){var index,iterable=object,result=iterable;if(!iterable)return result;var args=arguments,argsIndex=0,argsLength=typeof guard=="number"?2:args.length;if(argsLength>3&&typeof args[argsLength-2]=="function"){var callback=baseCreateCallback(args[--argsLength-1],args[argsLength--],2)}else if(argsLength>2&&typeof args[argsLength-1]=="function"){callback=args[--argsLength]}while(++argsIndex3&&typeof args[length-2]=="function"){var callback=baseCreateCallback(args[--length-1],args[length--],2)}else if(length>2&&typeof args[length-1]=="function"){callback=args[--length]}var sources=slice(arguments,1,length),index=-1,stackA=getArray(),stackB=getArray();while(++index-1}else if(typeof length=="number"){result=(isString(collection)?collection.indexOf(target,fromIndex):indexOf(collection,target,fromIndex))>-1}else{forOwn(collection,function(value){if(++index>=fromIndex){return!(result=value===target)}})}return result}var countBy=createAggregator(function(result,value,key){hasOwnProperty.call(result,key)?result[key]++:result[key]=1});function every(collection,callback,thisArg){var result=true;callback=lodash.createCallback(callback,thisArg,3);var index=-1,length=collection?collection.length:0;if(typeof length=="number"){while(++indexresult){result=value}}}else{callback=callback==null&&isString(collection)?charAtCallback:lodash.createCallback(callback,thisArg,3);forEach(collection,function(value,index,collection){var current=callback(value,index,collection);if(current>computed){computed=current;result=value}})}return result}function min(collection,callback,thisArg){var computed=Infinity,result=computed;if(typeof callback!="function"&&thisArg&&thisArg[callback]===collection){callback=null}if(callback==null&&isArray(collection)){var index=-1,length=collection.length;while(++index=largeArraySize&&createCache(argsIndex?args[argsIndex]:seen))}}var array=args[0],index=-1,length=array?array.length:0,result=[];outer:while(++index>>1;callback(array[mid])1?arguments:arguments[0],index=-1,length=array?max(pluck(array,"length")):0,result=Array(length<0?0:length);while(++index2?createWrapper(func,17,slice(arguments,2),null,thisArg):createWrapper(func,1,null,null,thisArg)}function bindAll(object){var funcs=arguments.length>1?baseFlatten(arguments,true,false,1):functions(object),index=-1,length=funcs.length;while(++index2?createWrapper(key,19,slice(arguments,2),null,object):createWrapper(key,3,null,null,object)}function compose(){var funcs=arguments,length=funcs.length;while(length--){if(!isFunction(funcs[length])){throw new TypeError}}return function(){var args=arguments,length=funcs.length;while(length--){args=[funcs[length].apply(this,args)]}return args[0]}}function curry(func,arity){arity=typeof arity=="number"?arity:+arity||func.length;return createWrapper(func,4,null,null,null,arity)}function debounce(func,wait,options){var args,maxTimeoutId,result,stamp,thisArg,timeoutId,trailingCall,lastCalled=0,maxWait=false,trailing=true;if(!isFunction(func)){throw new TypeError}wait=nativeMax(0,wait)||0;if(options===true){var leading=true;trailing=false}else if(isObject(options)){leading=options.leading;maxWait="maxWait"in options&&(nativeMax(wait,options.maxWait)||0);trailing="trailing"in options?options.trailing:trailing}var delayed=function(){var remaining=wait-(now()-stamp);if(remaining<=0){if(maxTimeoutId){clearTimeout(maxTimeoutId)}var isCalled=trailingCall;maxTimeoutId=timeoutId=trailingCall=undefined;if(isCalled){lastCalled=now();result=func.apply(thisArg,args);if(!timeoutId&&!maxTimeoutId){args=thisArg=null}}}else{timeoutId=setTimeout(delayed,remaining)}};var maxDelayed=function(){if(timeoutId){clearTimeout(timeoutId)}maxTimeoutId=timeoutId=trailingCall=undefined;if(trailing||maxWait!==wait){lastCalled=now();result=func.apply(thisArg,args);if(!timeoutId&&!maxTimeoutId){args=thisArg=null}}};return function(){args=arguments;stamp=now();thisArg=this;trailingCall=trailing&&(timeoutId||!leading);if(maxWait===false){var leadingCall=leading&&!timeoutId}else{if(!maxTimeoutId&&!leading){lastCalled=stamp}var remaining=maxWait-(stamp-lastCalled),isCalled=remaining<=0;if(isCalled){if(maxTimeoutId){maxTimeoutId=clearTimeout(maxTimeoutId)}lastCalled=stamp;result=func.apply(thisArg,args)}else if(!maxTimeoutId){maxTimeoutId=setTimeout(maxDelayed,remaining)}}if(isCalled&&timeoutId){timeoutId=clearTimeout(timeoutId)}else if(!timeoutId&&wait!==maxWait){timeoutId=setTimeout(delayed,wait)}if(leadingCall){isCalled=true;result=func.apply(thisArg,args)}if(isCalled&&!timeoutId&&!maxTimeoutId){args=thisArg=null}return result}}function defer(func){if(!isFunction(func)){throw new TypeError}var args=slice(arguments,1);return setTimeout(function(){func.apply(undefined,args)},1)}function delay(func,wait){if(!isFunction(func)){throw new TypeError}var args=slice(arguments,2);return setTimeout(function(){func.apply(undefined,args)},wait)}function memoize(func,resolver){if(!isFunction(func)){throw new TypeError}var memoized=function(){var cache=memoized.cache,key=resolver?resolver.apply(this,arguments):keyPrefix+arguments[0];return hasOwnProperty.call(cache,key)?cache[key]:cache[key]=func.apply(this,arguments)};memoized.cache={};return memoized}function once(func){var ran,result;if(!isFunction(func)){throw new TypeError}return function(){if(ran){return result}ran=true;result=func.apply(this,arguments);func=null;return result}}function partial(func){return createWrapper(func,16,slice(arguments,1))}function partialRight(func){return createWrapper(func,32,null,slice(arguments,1))}function throttle(func,wait,options){var leading=true,trailing=true;if(!isFunction(func)){throw new TypeError}if(options===false){leading=false}else if(isObject(options)){leading="leading"in options?options.leading:leading;trailing="trailing"in options?options.trailing:trailing}debounceOptions.leading=leading;debounceOptions.maxWait=wait;debounceOptions.trailing=trailing;return debounce(func,wait,debounceOptions)}function wrap(value,wrapper){return createWrapper(wrapper,16,[value])}function constant(value){return function(){return value}}function createCallback(func,thisArg,argCount){var type=typeof func;if(func==null||type=="function"){return baseCreateCallback(func,thisArg,argCount)}if(type!="object"){return property(func)}var props=keys(func),key=props[0],a=func[key];if(props.length==1&&a===a&&!isObject(a)){return function(object){var b=object[key];return a===b&&(a!==0||1/a==1/b)}}return function(object){var length=props.length,result=false;while(length--){if(!(result=baseIsEqual(object[props[length]],func[props[length]],null,true))){break}}return result}}function escape(string){return string==null?"":String(string).replace(reUnescapedHtml,escapeHtmlChar)}function identity(value){return value}function mixin(object,source,options){var chain=true,methodNames=source&&functions(source);if(!source||!options&&!methodNames.length){if(options==null){options=source}ctor=lodashWrapper;source=object;object=lodash;methodNames=functions(source)}if(options===false){chain=false}else if(isObject(options)&&"chain"in options){chain=options.chain}var ctor=object,isFunc=isFunction(ctor);forEach(methodNames,function(methodName){var func=object[methodName]=source[methodName];if(isFunc){ctor.prototype[methodName]=function(){var chainAll=this.__chain__,value=this.__wrapped__,args=[value];push.apply(args,arguments);var result=func.apply(object,args);if(chain||chainAll){if(value===result&&isObject(result)){return this}result=new ctor(result);result.__chain__=chainAll}return result}}})}function noConflict(){context._=oldDash;return this}function noop(){}var now=isNative(now=Date.now)&&now||function(){return(new Date).getTime()};var parseInt=nativeParseInt(whitespace+"08")==8?nativeParseInt:function(value,radix){return nativeParseInt(isString(value)?value.replace(reLeadingSpacesAndZeros,""):value,radix||0)};function property(key){return function(object){return object[key]}}function random(min,max,floating){var noMin=min==null,noMax=max==null;if(floating==null){if(typeof min=="boolean"&&noMax){floating=min;min=1}else if(!noMax&&typeof max=="boolean"){floating=max;noMax=true}}if(noMin&&noMax){max=1}min=+min||0;if(noMax){max=min;min=0}else{max=+max||0}if(floating||min%1||max%1){var rand=nativeRandom();return nativeMin(min+rand*(max-min+parseFloat("1e-"+((rand+"").length-1))),max)}return baseRandom(min,max)}function result(object,key){if(object){var value=object[key];return isFunction(value)?object[key]():value}}function template(text,data,options){var settings=lodash.templateSettings;text=String(text||"");options=defaults({},options,settings);var imports=defaults({},options.imports,settings.imports),importsKeys=keys(imports),importsValues=values(imports);var isEvaluating,index=0,interpolate=options.interpolate||reNoMatch,source="__p += '";var reDelimiters=RegExp((options.escape||reNoMatch).source+"|"+interpolate.source+"|"+(interpolate===reInterpolate?reEsTemplate:reNoMatch).source+"|"+(options.evaluate||reNoMatch).source+"|$","g");text.replace(reDelimiters,function(match,escapeValue,interpolateValue,esTemplateValue,evaluateValue,offset){interpolateValue||(interpolateValue=esTemplateValue);source+=text.slice(index,offset).replace(reUnescapedString,escapeStringChar);if(escapeValue){source+="' +\n__e("+escapeValue+") +\n'"}if(evaluateValue){isEvaluating=true;source+="';\n"+evaluateValue+";\n__p += '"}if(interpolateValue){source+="' +\n((__t = ("+interpolateValue+")) == null ? '' : __t) +\n'"}index=offset+match.length;return match});source+="';\n";var variable=options.variable,hasVariable=variable;if(!hasVariable){variable="obj";source="with ("+variable+") {\n"+source+"\n}\n"}source=(isEvaluating?source.replace(reEmptyStringLeading,""):source).replace(reEmptyStringMiddle,"$1").replace(reEmptyStringTrailing,"$1;");source="function("+variable+") {\n"+(hasVariable?"":variable+" || ("+variable+" = {});\n")+"var __t, __p = '', __e = _.escape"+(isEvaluating?", __j = Array.prototype.join;\n"+"function print() { __p += __j.call(arguments, '') }\n":";\n")+source+"return __p\n}";var sourceURL="\n/*\n//# sourceURL="+(options.sourceURL||"/lodash/template/source["+templateCounter++ +"]")+"\n*/";try{var result=Function(importsKeys,"return "+source+sourceURL).apply(undefined,importsValues)}catch(e){e.source=source;throw e}if(data){return result(data)}result.source=source;return result}function times(n,callback,thisArg){n=(n=+n)>-1?n:0;var index=-1,result=Array(n);callback=baseCreateCallback(callback,thisArg,1);while(++index + * div.stage-metadata > + * div.[dot-file | incoming-edge | outgoing-edge] + * + * Output DOM hierarchy: + * div#dag-viz-graph > + * svg > + * g#cluster_stage_[stageId] + * + * Note that the input metadata is populated by o.a.s.ui.UIUtils.showDagViz. + * Any changes in the input format here must be reflected there. + */ +function renderDagViz(forJob) { + + // If there is not a dot file to render, fail fast and report error + if (metadataContainer().empty()) { + graphContainer().append("div").text( + "No visualization information available for this " + (forJob ? "job" : "stage")); + return; + } + + var svg = graphContainer().append("svg"); + if (forJob) { + renderDagVizForJob(svg); + } else { + renderDagVizForStage(svg); + } + + // Find cached RDDs + metadataContainer().selectAll(".cached-rdd").each(function(v) { + var nodeId = VizConstants.nodePrefix + d3.select(this).text(); + graphContainer().selectAll("#" + nodeId).classed("cached", true); + }); + + // Set the appropriate SVG dimensions to ensure that all elements are displayed + var boundingBox = svg.node().getBBox(); + svg.style("width", (boundingBox.width + VizConstants.svgMarginX) + "px"); + svg.style("height", (boundingBox.height + VizConstants.svgMarginY) + "px"); + + // Add labels to clusters because dagre-d3 doesn't do this for us + svg.selectAll("g.cluster rect").each(function() { + var rect = d3.select(this); + var cluster = d3.select(this.parentNode); + // Shift the boxes up a little to make room for the labels + rect.attr("y", toFloat(rect.attr("y")) - 10); + rect.attr("height", toFloat(rect.attr("height")) + 10); + var labelX = toFloat(rect.attr("x")) + toFloat(rect.attr("width")) - 5; + var labelY = toFloat(rect.attr("y")) + 15; + var labelText = cluster.attr("name").replace(VizConstants.clusterPrefix, ""); + cluster.append("text") + .attr("x", labelX) + .attr("y", labelY) + .attr("text-anchor", "end") + .text(labelText); + }); + + // We have shifted a few elements upwards, so we should fix the SVG views + var startX = -VizConstants.svgMarginX; + var startY = -VizConstants.svgMarginY; + var endX = toFloat(svg.style("width")) + VizConstants.svgMarginX; + var endY = toFloat(svg.style("height")) + VizConstants.svgMarginY; + var newViewBox = startX + " " + startY + " " + endX + " " + endY; + svg.attr("viewBox", newViewBox); + + // Lastly, apply some custom style to the DAG + styleDagViz(forJob); +} + +/* Render the RDD DAG visualization for a stage. */ +function renderDagVizForStage(svgContainer) { + var metadata = metadataContainer().select(".stage-metadata"); + var dot = metadata.select(".dot-file").text(); + var containerId = VizConstants.graphPrefix + metadata.attr("stageId"); + var container = svgContainer.append("g").attr("id", containerId); + renderDot(dot, container); +} + +/* + * Render the RDD DAG visualization for a job. + * + * Due to limitations in dagre-d3, each stage is rendered independently so that + * we have more control on how to position them. Unfortunately, this means we + * cannot rely on dagre-d3 to render edges that cross stages and must render + * these manually on our own. + */ +function renderDagVizForJob(svgContainer) { + var crossStageEdges = []; + + metadataContainer().selectAll(".stage-metadata").each(function(d, i) { + var metadata = d3.select(this); + var dot = metadata.select(".dot-file").text(); + var stageId = metadata.attr("stageId"); + var containerId = VizConstants.graphPrefix + stageId; + // TODO: handle stage attempts + var stageLink = + "/stages/stage/?id=" + stageId.replace(VizConstants.stagePrefix, "") + "&attempt=0"; + var container = svgContainer + .append("a").attr("xlink:href", stageLink) + .append("g").attr("id", containerId); + // Now we need to shift the container for this stage so it doesn't overlap + // with existing ones. We do not need to do this for the first stage. + if (i > 0) { + // Take into account the position and width of the last stage's container + var existingStages = stageClusters(); + if (!existingStages.empty()) { + var lastStage = existingStages[0].pop(); + var lastStageId = d3.select(lastStage).attr("id"); + var lastStageWidth = toFloat(d3.select("#" + lastStageId + " rect").attr("width")); + var lastStagePosition = getAbsolutePosition(lastStageId); + var offset = lastStagePosition.x + lastStageWidth + VizConstants.stageSep; + container.attr("transform", "translate(" + offset + ", 0)"); + } + } + renderDot(dot, container); + // If there are any incoming edges into this graph, keep track of them to render + // them separately later. Note that we cannot draw them now because we need to + // put these edges in a separate container that is on top of all stage graphs. + metadata.selectAll(".incoming-edge").each(function(v) { + var edge = d3.select(this).text().split(","); // e.g. 3,4 => [3, 4] + crossStageEdges.push(edge); + }); + }); + + // Draw edges that cross stages + if (crossStageEdges.length > 0) { + var container = svgContainer.append("g").attr("id", "cross-stage-edges"); + for (var i = 0; i < crossStageEdges.length; i++) { + var fromRDDId = crossStageEdges[i][0]; + var toRDDId = crossStageEdges[i][1]; + connectRDDs(fromRDDId, toRDDId, container); + } + } +} + +/* Render the dot file as an SVG in the given container. */ +function renderDot(dot, container) { + var escaped_dot = dot + .replace(/</g, "<") + .replace(/>/g, ">") + .replace(/"/g, "\""); + var g = graphlibDot.read(escaped_dot); + var renderer = new dagreD3.render(); + renderer(container, g); +} + +/* Style the visualization we just rendered. */ +function styleDagViz(forJob) { + graphContainer().selectAll("svg g.cluster rect") + .style("fill", "white") + .style("stroke", VizConstants.rddOperationColor) + .style("stroke-width", "4px") + .style("stroke-opacity", "0.5"); + graphContainer().selectAll("svg g.cluster text") + .attr("fill", VizConstants.clusterLabelColor) + .attr("font-size", "11px"); + graphContainer().selectAll("svg path") + .style("stroke", VizConstants.edgeColor) + .style("stroke-width", VizConstants.edgeWidth); + stageClusters() + .select("rect") + .style("stroke", VizConstants.stageColor) + .style("strokeWidth", "6px"); + + // Put an arrow at the end of every edge + // We need to do this because we manually render some edges ourselves + // For these edges, we borrow the arrow marker generated by dagre-d3 + var dagreD3Marker = graphContainer().select("svg g.edgePaths marker").node(); + graphContainer().select("svg") + .append(function() { return dagreD3Marker.cloneNode(true); }) + .attr("id", "marker-arrow") + .select("path") + .attr("fill", VizConstants.edgeColor) + .attr("strokeWidth", "0px"); + graphContainer().selectAll("svg g > path").attr("marker-end", "url(#marker-arrow)"); + graphContainer().selectAll("svg g.edgePaths def").remove(); // We no longer need these + + // Apply any job or stage specific styles + if (forJob) { + styleDagVizForJob(); + } else { + styleDagVizForStage(); + } +} + +/* Apply job-page-specific style to the visualization. */ +function styleDagVizForJob() { + graphContainer().selectAll("svg g.node circle") + .style("fill", VizConstants.rddColor); + // TODO: add a legend to explain what a highlighted dot means + graphContainer().selectAll("svg g.cached circle") + .style("fill", VizConstants.rddCachedColor); + graphContainer().selectAll("svg g#cross-stage-edges path") + .style("fill", "none"); +} + +/* Apply stage-page-specific style to the visualization. */ +function styleDagVizForStage() { + graphContainer().selectAll("svg g.node rect") + .style("fill", "none") + .style("stroke", VizConstants.rddColor) + .style("stroke-width", "2px") + .attr("rx", "5") // round corners + .attr("ry", "5"); + // TODO: add a legend to explain what a highlighted RDD means + graphContainer().selectAll("svg g.cached rect") + .style("stroke", VizConstants.rddCachedColor); + graphContainer().selectAll("svg g.node g.label text tspan") + .style("fill", VizConstants.rddColor); +} + +/* + * (Job page only) Helper method to compute the absolute + * position of the group element identified by the given ID. + */ +function getAbsolutePosition(groupId) { + var obj = d3.select("#" + groupId).filter("g"); + var _x = 0, _y = 0; + while (!obj.empty()) { + var transformText = obj.attr("transform"); + var translate = d3.transform(transformText).translate + _x += translate[0]; + _y += translate[1]; + obj = d3.select(obj.node().parentNode).filter("g") + } + return { x: _x, y: _y }; +} + +/* (Job page only) Connect two RDD nodes with a curved edge. */ +function connectRDDs(fromRDDId, toRDDId, container) { + var fromNodeId = VizConstants.nodePrefix + fromRDDId; + var toNodeId = VizConstants.nodePrefix + toRDDId + var fromPos = getAbsolutePosition(fromNodeId); + var toPos = getAbsolutePosition(toNodeId); + + // On the job page, RDDs are rendered as dots (circles). When rendering the path, + // we need to account for the radii of these circles. Otherwise the arrow heads + // will bleed into the circle itself. + var delta = toFloat(graphContainer() + .select("g.node#" + toNodeId) + .select("circle") + .attr("r")); + if (fromPos.x < toPos.x) { + fromPos.x += delta; + toPos.x -= delta; + } else if (fromPos.x > toPos.x) { + fromPos.x -= delta; + toPos.x += delta; + } + + if (fromPos.y == toPos.y) { + // If they are on the same rank, curve the middle part of the edge + // upward a little to avoid interference with things in between + // e.g. _______ + // _____/ \_____ + var points = [ + [fromPos.x, fromPos.y], + [fromPos.x + (toPos.x - fromPos.x) * 0.2, fromPos.y], + [fromPos.x + (toPos.x - fromPos.x) * 0.3, fromPos.y - 20], + [fromPos.x + (toPos.x - fromPos.x) * 0.7, fromPos.y - 20], + [fromPos.x + (toPos.x - fromPos.x) * 0.8, toPos.y], + [toPos.x, toPos.y] + ]; + } else { + // Otherwise, draw a curved edge that flattens out on both ends + // e.g. _____ + // / + // | + // _____/ + var points = [ + [fromPos.x, fromPos.y], + [fromPos.x + (toPos.x - fromPos.x) * 0.4, fromPos.y], + [fromPos.x + (toPos.x - fromPos.x) * 0.6, toPos.y], + [toPos.x, toPos.y] + ]; + } + + var line = d3.svg.line().interpolate("basis"); + container.append("path").datum(points).attr("d", line); +} + +/* Helper d3 accessor to clusters that represent stages. */ +function stageClusters() { + return graphContainer().selectAll("g.cluster").filter(function() { + return d3.select(this).attr("id").indexOf(VizConstants.stageClusterPrefix) > -1; + }); +} + +/* Helper method to convert attributes to numeric values. */ +function toFloat(f) { + return parseFloat(f.replace(/px$/, "")); +} + diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index 4910744d1d790..669ad48937c05 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -145,7 +145,7 @@ pre { border: none; } -span.expand-additional-metrics { +span.expand-additional-metrics, span.expand-dag-viz { cursor: pointer; } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 4ef90546a2452..b98a54b418cd0 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -659,6 +659,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, null) } + /** + * Execute a block of code in a scope such that all new RDDs created in this body will + * be part of the same scope. For more detail, see {{org.apache.spark.rdd.RDDOperationScope}}. + * + * Note: Return statements are NOT allowed in the given body. + */ + private def withScope[U](body: => U): U = RDDOperationScope.withScope[U](this)(body) + // Methods for creating RDDs /** Distribute a local Scala collection to form an RDD. @@ -669,7 +677,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * @note avoid using `parallelize(Seq())` to create an empty `RDD`. Consider `emptyRDD` for an * RDD with no partitions, or `parallelize(Seq[T]())` for an RDD of `T` with empty partitions. */ - def parallelize[T: ClassTag](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = { + def parallelize[T: ClassTag]( + seq: Seq[T], + numSlices: Int = defaultParallelism): RDD[T] = withScope { assertNotStopped() new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]()) } @@ -678,14 +688,16 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * * This method is identical to `parallelize`. */ - def makeRDD[T: ClassTag](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = { + def makeRDD[T: ClassTag]( + seq: Seq[T], + numSlices: Int = defaultParallelism): RDD[T] = withScope { parallelize(seq, numSlices) } /** Distribute a local Scala collection to form an RDD, with one or more * location preferences (hostnames of Spark nodes) for each object. * Create a new partition for each collection item. */ - def makeRDD[T: ClassTag](seq: Seq[(T, Seq[String])]): RDD[T] = { + def makeRDD[T: ClassTag](seq: Seq[(T, Seq[String])]): RDD[T] = withScope { assertNotStopped() val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap new ParallelCollectionRDD[T](this, seq.map(_._1), seq.size, indexToPrefs) @@ -695,10 +707,12 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Read a text file from HDFS, a local file system (available on all nodes), or any * Hadoop-supported file system URI, and return it as an RDD of Strings. */ - def textFile(path: String, minPartitions: Int = defaultMinPartitions): RDD[String] = { + def textFile( + path: String, + minPartitions: Int = defaultMinPartitions): RDD[String] = withScope { assertNotStopped() hadoopFile(path, classOf[TextInputFormat], classOf[LongWritable], classOf[Text], - minPartitions).map(pair => pair._2.toString).setName(path) + minPartitions).map(pair => pair._2.toString) } /** @@ -728,8 +742,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * * @param minPartitions A suggestion value of the minimal splitting number for input data. */ - def wholeTextFiles(path: String, minPartitions: Int = defaultMinPartitions): - RDD[(String, String)] = { + def wholeTextFiles( + path: String, + minPartitions: Int = defaultMinPartitions): RDD[(String, String)] = withScope { assertNotStopped() val job = new NewHadoopJob(hadoopConfiguration) // Use setInputPaths so that wholeTextFiles aligns with hadoopFile/textFile in taking @@ -776,8 +791,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * @note Small files are preferred; very large files may cause bad performance. */ @Experimental - def binaryFiles(path: String, minPartitions: Int = defaultMinPartitions): - RDD[(String, PortableDataStream)] = { + def binaryFiles( + path: String, + minPartitions: Int = defaultMinPartitions): RDD[(String, PortableDataStream)] = withScope { assertNotStopped() val job = new NewHadoopJob(hadoopConfiguration) // Use setInputPaths so that binaryFiles aligns with hadoopFile/textFile in taking @@ -806,8 +822,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * @return An RDD of data with values, represented as byte arrays */ @Experimental - def binaryRecords(path: String, recordLength: Int, conf: Configuration = hadoopConfiguration) - : RDD[Array[Byte]] = { + def binaryRecords( + path: String, + recordLength: Int, + conf: Configuration = hadoopConfiguration): RDD[Array[Byte]] = withScope { assertNotStopped() conf.setInt(FixedLengthBinaryInputFormat.RECORD_LENGTH_PROPERTY, recordLength) val br = newAPIHadoopFile[LongWritable, BytesWritable, FixedLengthBinaryInputFormat](path, @@ -848,8 +866,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli inputFormatClass: Class[_ <: InputFormat[K, V]], keyClass: Class[K], valueClass: Class[V], - minPartitions: Int = defaultMinPartitions - ): RDD[(K, V)] = { + minPartitions: Int = defaultMinPartitions): RDD[(K, V)] = withScope { assertNotStopped() // Add necessary security credentials to the JobConf before broadcasting it. SparkHadoopUtil.get.addCredentials(conf) @@ -869,8 +886,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli inputFormatClass: Class[_ <: InputFormat[K, V]], keyClass: Class[K], valueClass: Class[V], - minPartitions: Int = defaultMinPartitions - ): RDD[(K, V)] = { + minPartitions: Int = defaultMinPartitions): RDD[(K, V)] = withScope { assertNotStopped() // A Hadoop configuration can be about 10 KB, which is pretty big, so broadcast it. val confBroadcast = broadcast(new SerializableWritable(hadoopConfiguration)) @@ -901,7 +917,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ def hadoopFile[K, V, F <: InputFormat[K, V]] (path: String, minPartitions: Int) - (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]): RDD[(K, V)] = { + (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]): RDD[(K, V)] = withScope { hadoopFile(path, fm.runtimeClass.asInstanceOf[Class[F]], km.runtimeClass.asInstanceOf[Class[K]], @@ -924,13 +940,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * copy them using a `map` function. */ def hadoopFile[K, V, F <: InputFormat[K, V]](path: String) - (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]): RDD[(K, V)] = + (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]): RDD[(K, V)] = withScope { hadoopFile[K, V, F](path, defaultMinPartitions) + } /** Get an RDD for a Hadoop file with an arbitrary new API InputFormat. */ def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]] (path: String) - (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]): RDD[(K, V)] = { + (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]): RDD[(K, V)] = withScope { newAPIHadoopFile( path, fm.runtimeClass.asInstanceOf[Class[F]], @@ -953,7 +970,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli fClass: Class[F], kClass: Class[K], vClass: Class[V], - conf: Configuration = hadoopConfiguration): RDD[(K, V)] = { + conf: Configuration = hadoopConfiguration): RDD[(K, V)] = withScope { assertNotStopped() // The call to new NewHadoopJob automatically adds security credentials to conf, // so we don't need to explicitly add them ourselves @@ -987,7 +1004,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli conf: Configuration = hadoopConfiguration, fClass: Class[F], kClass: Class[K], - vClass: Class[V]): RDD[(K, V)] = { + vClass: Class[V]): RDD[(K, V)] = withScope { assertNotStopped() // Add necessary security credentials to the JobConf. Required to access secure HDFS. val jconf = new JobConf(conf) @@ -1007,7 +1024,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli keyClass: Class[K], valueClass: Class[V], minPartitions: Int - ): RDD[(K, V)] = { + ): RDD[(K, V)] = withScope { assertNotStopped() val inputFormatClass = classOf[SequenceFileInputFormat[K, V]] hadoopFile(path, inputFormatClass, keyClass, valueClass, minPartitions) @@ -1021,7 +1038,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first * copy them using a `map` function. * */ - def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]): RDD[(K, V)] = { + def sequenceFile[K, V]( + path: String, + keyClass: Class[K], + valueClass: Class[V]): RDD[(K, V)] = withScope { assertNotStopped() sequenceFile(path, keyClass, valueClass, defaultMinPartitions) } @@ -1051,16 +1071,17 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli def sequenceFile[K, V] (path: String, minPartitions: Int = defaultMinPartitions) (implicit km: ClassTag[K], vm: ClassTag[V], - kcf: () => WritableConverter[K], vcf: () => WritableConverter[V]) - : RDD[(K, V)] = { - assertNotStopped() - val kc = kcf() - val vc = vcf() - val format = classOf[SequenceFileInputFormat[Writable, Writable]] - val writables = hadoopFile(path, format, + kcf: () => WritableConverter[K], vcf: () => WritableConverter[V]): RDD[(K, V)] = { + withScope { + assertNotStopped() + val kc = kcf() + val vc = vcf() + val format = classOf[SequenceFileInputFormat[Writable, Writable]] + val writables = hadoopFile(path, format, kc.writableClass(km).asInstanceOf[Class[Writable]], vc.writableClass(vm).asInstanceOf[Class[Writable]], minPartitions) - writables.map { case (k, v) => (kc.convert(k), vc.convert(v)) } + writables.map { case (k, v) => (kc.convert(k), vc.convert(v)) } + } } /** @@ -1073,21 +1094,18 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ def objectFile[T: ClassTag]( path: String, - minPartitions: Int = defaultMinPartitions - ): RDD[T] = { + minPartitions: Int = defaultMinPartitions): RDD[T] = withScope { assertNotStopped() sequenceFile(path, classOf[NullWritable], classOf[BytesWritable], minPartitions) .flatMap(x => Utils.deserialize[Array[T]](x._2.getBytes, Utils.getContextOrSparkClassLoader)) } - protected[spark] def checkpointFile[T: ClassTag]( - path: String - ): RDD[T] = { + protected[spark] def checkpointFile[T: ClassTag](path: String): RDD[T] = withScope { new CheckpointRDD[T](this, path) } /** Build the union of a list of RDDs. */ - def union[T: ClassTag](rdds: Seq[RDD[T]]): RDD[T] = { + def union[T: ClassTag](rdds: Seq[RDD[T]]): RDD[T] = withScope { val partitioners = rdds.flatMap(_.partitioner).toSet if (rdds.forall(_.partitioner.isDefined) && partitioners.size == 1) { new PartitionerAwareUnionRDD(this, rdds) @@ -1097,8 +1115,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } /** Build the union of a list of RDDs passed as variable-length arguments. */ - def union[T: ClassTag](first: RDD[T], rest: RDD[T]*): RDD[T] = + def union[T: ClassTag](first: RDD[T], rest: RDD[T]*): RDD[T] = withScope { union(Seq(first) ++ rest) + } /** Get an RDD that has no partitions or elements. */ def emptyRDD[T: ClassTag]: EmptyRDD[T] = new EmptyRDD[T](this) @@ -2060,10 +2079,10 @@ object SparkContext extends Logging { } private[spark] val SPARK_JOB_DESCRIPTION = "spark.job.description" - private[spark] val SPARK_JOB_GROUP_ID = "spark.jobGroup.id" - private[spark] val SPARK_JOB_INTERRUPT_ON_CANCEL = "spark.job.interruptOnCancel" + private[spark] val RDD_SCOPE_KEY = "spark.rdd.scope" + private[spark] val RDD_SCOPE_NO_OVERRIDE_KEY = "spark.rdd.scope.noOverride" /** * Executor id for the driver. In earlier versions of Spark, this was ``, but this was diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index 3406a7e97e368..ec185340c3a2d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -33,7 +33,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi /** * Returns a future for counting the number of elements in the RDD. */ - def countAsync(): FutureAction[Long] = { + def countAsync(): FutureAction[Long] = self.withScope { val totalCount = new AtomicLong self.context.submitJob( self, @@ -53,7 +53,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi /** * Returns a future for retrieving all elements of this RDD. */ - def collectAsync(): FutureAction[Seq[T]] = { + def collectAsync(): FutureAction[Seq[T]] = self.withScope { val results = new Array[Array[T]](self.partitions.length) self.context.submitJob[T, Array[T], Seq[T]](self, _.toArray, Range(0, self.partitions.length), (index, data) => results(index) = data, results.flatten.toSeq) @@ -62,7 +62,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi /** * Returns a future for retrieving the first num elements of the RDD. */ - def takeAsync(num: Int): FutureAction[Seq[T]] = { + def takeAsync(num: Int): FutureAction[Seq[T]] = self.withScope { val f = new ComplexFutureAction[Seq[T]] f.run { @@ -109,7 +109,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi /** * Applies a function f to all elements of this RDD. */ - def foreachAsync(f: T => Unit): FutureAction[Unit] = { + def foreachAsync(f: T => Unit): FutureAction[Unit] = self.withScope { val cleanF = self.context.clean(f) self.context.submitJob[T, Unit, Unit](self, _.foreach(cleanF), Range(0, self.partitions.length), (index, data) => Unit, Unit) @@ -118,7 +118,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi /** * Applies a function f to each partition of this RDD. */ - def foreachPartitionAsync(f: Iterator[T] => Unit): FutureAction[Unit] = { + def foreachPartitionAsync(f: Iterator[T] => Unit): FutureAction[Unit] = self.withScope { self.context.submitJob[T, Unit, Unit](self, f, Range(0, self.partitions.length), (index, data) => Unit, Unit) } diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index 843a893235e56..926bce6f15a2a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -30,7 +30,7 @@ import org.apache.spark.util.StatCounter */ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { /** Add up the elements in this RDD. */ - def sum(): Double = { + def sum(): Double = self.withScope { self.fold(0.0)(_ + _) } @@ -38,37 +38,49 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { * Return a [[org.apache.spark.util.StatCounter]] object that captures the mean, variance and * count of the RDD's elements in one operation. */ - def stats(): StatCounter = { + def stats(): StatCounter = self.withScope { self.mapPartitions(nums => Iterator(StatCounter(nums))).reduce((a, b) => a.merge(b)) } /** Compute the mean of this RDD's elements. */ - def mean(): Double = stats().mean + def mean(): Double = self.withScope { + stats().mean + } /** Compute the variance of this RDD's elements. */ - def variance(): Double = stats().variance + def variance(): Double = self.withScope { + stats().variance + } /** Compute the standard deviation of this RDD's elements. */ - def stdev(): Double = stats().stdev + def stdev(): Double = self.withScope { + stats().stdev + } /** * Compute the sample standard deviation of this RDD's elements (which corrects for bias in * estimating the standard deviation by dividing by N-1 instead of N). */ - def sampleStdev(): Double = stats().sampleStdev + def sampleStdev(): Double = self.withScope { + stats().sampleStdev + } /** * Compute the sample variance of this RDD's elements (which corrects for bias in * estimating the variance by dividing by N-1 instead of N). */ - def sampleVariance(): Double = stats().sampleVariance + def sampleVariance(): Double = self.withScope { + stats().sampleVariance + } /** * :: Experimental :: * Approximate operation to return the mean within a timeout. */ @Experimental - def meanApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = { + def meanApprox( + timeout: Long, + confidence: Double = 0.95): PartialResult[BoundedDouble] = self.withScope { val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns) val evaluator = new MeanEvaluator(self.partitions.length, confidence) self.context.runApproximateJob(self, processPartition, evaluator, timeout) @@ -79,7 +91,9 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { * Approximate operation to return the sum within a timeout. */ @Experimental - def sumApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = { + def sumApprox( + timeout: Long, + confidence: Double = 0.95): PartialResult[BoundedDouble] = self.withScope { val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns) val evaluator = new SumEvaluator(self.partitions.length, confidence) self.context.runApproximateJob(self, processPartition, evaluator, timeout) @@ -93,7 +107,7 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { * If the RDD contains infinity, NaN throws an exception * If the elements in RDD do not vary (max == min) always returns a single bucket. */ - def histogram(bucketCount: Int): Pair[Array[Double], Array[Long]] = { + def histogram(bucketCount: Int): Pair[Array[Double], Array[Long]] = self.withScope { // Scala's built-in range has issues. See #SI-8782 def customRange(min: Double, max: Double, steps: Int): IndexedSeq[Double] = { val span = max - min @@ -140,7 +154,9 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { * the maximum value of the last position and all NaN entries will be counted * in that bucket. */ - def histogram(buckets: Array[Double], evenBuckets: Boolean = false): Array[Long] = { + def histogram( + buckets: Array[Double], + evenBuckets: Boolean = false): Array[Long] = self.withScope { if (buckets.length < 2) { throw new IllegalArgumentException("buckets array must have at least two elements") } diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index f77abac42b623..2cefe63d44b20 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -99,7 +99,7 @@ private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSp */ @DeveloperApi class HadoopRDD[K, V]( - sc: SparkContext, + @transient sc: SparkContext, broadcastedConf: Broadcast[SerializableWritable[Configuration]], initLocalJobConfFuncOpt: Option[JobConf => Unit], inputFormatClass: Class[_ <: InputFormat[K, V]], @@ -108,6 +108,10 @@ class HadoopRDD[K, V]( minPartitions: Int) extends RDD[(K, V)](sc, Nil) with Logging { + if (initLocalJobConfFuncOpt.isDefined) { + sc.clean(initLocalJobConfFuncOpt.get) + } + def this( sc: SparkContext, conf: JobConf, diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala index 6afe50161dacd..d71bb63000904 100644 --- a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala @@ -57,7 +57,7 @@ class OrderedRDDFunctions[K : Ordering : ClassTag, */ // TODO: this currently doesn't work on P other than Tuple2! def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.length) - : RDD[(K, V)] = + : RDD[(K, V)] = self.withScope { val part = new RangePartitioner(numPartitions, self, ascending) new ShuffledRDD[K, V, V](self, part) @@ -71,7 +71,7 @@ class OrderedRDDFunctions[K : Ordering : ClassTag, * This is more efficient than calling `repartition` and then sorting within each partition * because it can push the sorting down into the shuffle machinery. */ - def repartitionAndSortWithinPartitions(partitioner: Partitioner): RDD[(K, V)] = { + def repartitionAndSortWithinPartitions(partitioner: Partitioner): RDD[(K, V)] = self.withScope { new ShuffledRDD[K, V, V](self, partitioner).setKeyOrdering(ordering) } @@ -81,7 +81,7 @@ class OrderedRDDFunctions[K : Ordering : ClassTag, * performed efficiently by only scanning the partitions that might contain matching elements. * Otherwise, a standard `filter` is applied to all partitions. */ - def filterByRange(lower: K, upper: K): RDD[P] = { + def filterByRange(lower: K, upper: K): RDD[P] = self.withScope { def inRange(k: K): Boolean = ordering.gteq(k, lower) && ordering.lteq(k, upper) diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 05351ba4ff76b..93d338fe0530c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -29,7 +29,7 @@ import scala.util.DynamicVariable import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus import org.apache.hadoop.conf.{Configurable, Configuration} -import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat} @@ -75,7 +75,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) mergeCombiners: (C, C) => C, partitioner: Partitioner, mapSideCombine: Boolean = true, - serializer: Serializer = null): RDD[(K, C)] = { + serializer: Serializer = null): RDD[(K, C)] = self.withScope { require(mergeCombiners != null, "mergeCombiners must be defined") // required as of Spark 0.9.0 if (keyClass.isArray) { if (mapSideCombine) { @@ -108,7 +108,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) def combineByKey[C](createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C, - numPartitions: Int): RDD[(K, C)] = { + numPartitions: Int): RDD[(K, C)] = self.withScope { combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numPartitions)) } @@ -122,7 +122,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * instead of creating a new U. */ def aggregateByKey[U: ClassTag](zeroValue: U, partitioner: Partitioner)(seqOp: (U, V) => U, - combOp: (U, U) => U): RDD[(K, U)] = { + combOp: (U, U) => U): RDD[(K, U)] = self.withScope { // Serialize the zero value to a byte array so that we can get a new clone of it on each key val zeroBuffer = SparkEnv.get.serializer.newInstance().serialize(zeroValue) val zeroArray = new Array[Byte](zeroBuffer.limit) @@ -144,7 +144,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * instead of creating a new U. */ def aggregateByKey[U: ClassTag](zeroValue: U, numPartitions: Int)(seqOp: (U, V) => U, - combOp: (U, U) => U): RDD[(K, U)] = { + combOp: (U, U) => U): RDD[(K, U)] = self.withScope { aggregateByKey(zeroValue, new HashPartitioner(numPartitions))(seqOp, combOp) } @@ -158,7 +158,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * instead of creating a new U. */ def aggregateByKey[U: ClassTag](zeroValue: U)(seqOp: (U, V) => U, - combOp: (U, U) => U): RDD[(K, U)] = { + combOp: (U, U) => U): RDD[(K, U)] = self.withScope { aggregateByKey(zeroValue, defaultPartitioner(self))(seqOp, combOp) } @@ -167,7 +167,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * may be added to the result an arbitrary number of times, and must not change the result * (e.g., Nil for list concatenation, 0 for addition, or 1 for multiplication.). */ - def foldByKey(zeroValue: V, partitioner: Partitioner)(func: (V, V) => V): RDD[(K, V)] = { + def foldByKey( + zeroValue: V, + partitioner: Partitioner)(func: (V, V) => V): RDD[(K, V)] = self.withScope { // Serialize the zero value to a byte array so that we can get a new clone of it on each key val zeroBuffer = SparkEnv.get.serializer.newInstance().serialize(zeroValue) val zeroArray = new Array[Byte](zeroBuffer.limit) @@ -185,7 +187,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * may be added to the result an arbitrary number of times, and must not change the result * (e.g., Nil for list concatenation, 0 for addition, or 1 for multiplication.). */ - def foldByKey(zeroValue: V, numPartitions: Int)(func: (V, V) => V): RDD[(K, V)] = { + def foldByKey(zeroValue: V, numPartitions: Int)(func: (V, V) => V): RDD[(K, V)] = self.withScope { foldByKey(zeroValue, new HashPartitioner(numPartitions))(func) } @@ -194,7 +196,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * may be added to the result an arbitrary number of times, and must not change the result * (e.g., Nil for list concatenation, 0 for addition, or 1 for multiplication.). */ - def foldByKey(zeroValue: V)(func: (V, V) => V): RDD[(K, V)] = { + def foldByKey(zeroValue: V)(func: (V, V) => V): RDD[(K, V)] = self.withScope { foldByKey(zeroValue, defaultPartitioner(self))(func) } @@ -213,7 +215,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) */ def sampleByKey(withReplacement: Boolean, fractions: Map[K, Double], - seed: Long = Utils.random.nextLong): RDD[(K, V)] = { + seed: Long = Utils.random.nextLong): RDD[(K, V)] = self.withScope { require(fractions.values.forall(v => v >= 0.0), "Negative sampling rates.") @@ -242,9 +244,10 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * @return RDD containing the sampled subset */ @Experimental - def sampleByKeyExact(withReplacement: Boolean, + def sampleByKeyExact( + withReplacement: Boolean, fractions: Map[K, Double], - seed: Long = Utils.random.nextLong): RDD[(K, V)] = { + seed: Long = Utils.random.nextLong): RDD[(K, V)] = self.withScope { require(fractions.values.forall(v => v >= 0.0), "Negative sampling rates.") @@ -261,7 +264,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * the merging locally on each mapper before sending results to a reducer, similarly to a * "combiner" in MapReduce. */ - def reduceByKey(partitioner: Partitioner, func: (V, V) => V): RDD[(K, V)] = { + def reduceByKey(partitioner: Partitioner, func: (V, V) => V): RDD[(K, V)] = self.withScope { combineByKey[V]((v: V) => v, func, func, partitioner) } @@ -270,7 +273,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * the merging locally on each mapper before sending results to a reducer, similarly to a * "combiner" in MapReduce. Output will be hash-partitioned with numPartitions partitions. */ - def reduceByKey(func: (V, V) => V, numPartitions: Int): RDD[(K, V)] = { + def reduceByKey(func: (V, V) => V, numPartitions: Int): RDD[(K, V)] = self.withScope { reduceByKey(new HashPartitioner(numPartitions), func) } @@ -280,7 +283,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * "combiner" in MapReduce. Output will be hash-partitioned with the existing partitioner/ * parallelism level. */ - def reduceByKey(func: (V, V) => V): RDD[(K, V)] = { + def reduceByKey(func: (V, V) => V): RDD[(K, V)] = self.withScope { reduceByKey(defaultPartitioner(self), func) } @@ -289,7 +292,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * immediately to the master as a Map. This will also perform the merging locally on each mapper * before sending results to a reducer, similarly to a "combiner" in MapReduce. */ - def reduceByKeyLocally(func: (V, V) => V): Map[K, V] = { + def reduceByKeyLocally(func: (V, V) => V): Map[K, V] = self.withScope { if (keyClass.isArray) { throw new SparkException("reduceByKeyLocally() does not support array keys") @@ -317,7 +320,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) /** Alias for reduceByKeyLocally */ @deprecated("Use reduceByKeyLocally", "1.0.0") - def reduceByKeyToDriver(func: (V, V) => V): Map[K, V] = reduceByKeyLocally(func) + def reduceByKeyToDriver(func: (V, V) => V): Map[K, V] = self.withScope { + reduceByKeyLocally(func) + } /** * Count the number of elements for each key, collecting the results to a local Map. @@ -327,7 +332,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * To handle very large results, consider using rdd.mapValues(_ => 1L).reduceByKey(_ + _), which * returns an RDD[T, Long] instead of a map. */ - def countByKey(): Map[K, Long] = self.mapValues(_ => 1L).reduceByKey(_ + _).collect().toMap + def countByKey(): Map[K, Long] = self.withScope { + self.mapValues(_ => 1L).reduceByKey(_ + _).collect().toMap + } /** * :: Experimental :: @@ -336,7 +343,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) */ @Experimental def countByKeyApprox(timeout: Long, confidence: Double = 0.95) - : PartialResult[Map[K, BoundedDouble]] = { + : PartialResult[Map[K, BoundedDouble]] = self.withScope { self.map(_._1).countByValueApprox(timeout, confidence) } @@ -360,7 +367,10 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * @param partitioner Partitioner to use for the resulting RDD. */ @Experimental - def countApproxDistinctByKey(p: Int, sp: Int, partitioner: Partitioner): RDD[(K, Long)] = { + def countApproxDistinctByKey( + p: Int, + sp: Int, + partitioner: Partitioner): RDD[(K, Long)] = self.withScope { require(p >= 4, s"p ($p) must be >= 4") require(sp <= 32, s"sp ($sp) must be <= 32") require(sp == 0 || p <= sp, s"p ($p) cannot be greater than sp ($sp)") @@ -392,7 +402,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * It must be greater than 0.000017. * @param partitioner partitioner of the resulting RDD */ - def countApproxDistinctByKey(relativeSD: Double, partitioner: Partitioner): RDD[(K, Long)] = { + def countApproxDistinctByKey( + relativeSD: Double, + partitioner: Partitioner): RDD[(K, Long)] = self.withScope { require(relativeSD > 0.000017, s"accuracy ($relativeSD) must be greater than 0.000017") val p = math.ceil(2.0 * math.log(1.054 / relativeSD) / math.log(2)).toInt assert(p <= 32) @@ -410,7 +422,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * It must be greater than 0.000017. * @param numPartitions number of partitions of the resulting RDD */ - def countApproxDistinctByKey(relativeSD: Double, numPartitions: Int): RDD[(K, Long)] = { + def countApproxDistinctByKey( + relativeSD: Double, + numPartitions: Int): RDD[(K, Long)] = self.withScope { countApproxDistinctByKey(relativeSD, new HashPartitioner(numPartitions)) } @@ -424,7 +438,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * @param relativeSD Relative accuracy. Smaller values create counters that require more space. * It must be greater than 0.000017. */ - def countApproxDistinctByKey(relativeSD: Double = 0.05): RDD[(K, Long)] = { + def countApproxDistinctByKey(relativeSD: Double = 0.05): RDD[(K, Long)] = self.withScope { countApproxDistinctByKey(relativeSD, defaultPartitioner(self)) } @@ -441,7 +455,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Note: As currently implemented, groupByKey must be able to hold all the key-value pairs for any * key in memory. If a key has too many values, it can result in an [[OutOfMemoryError]]. */ - def groupByKey(partitioner: Partitioner): RDD[(K, Iterable[V])] = { + def groupByKey(partitioner: Partitioner): RDD[(K, Iterable[V])] = self.withScope { // groupByKey shouldn't use map side combine because map side combine does not // reduce the amount of data shuffled and requires all map side data be inserted // into a hash table, leading to more objects in the old gen. @@ -465,14 +479,14 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Note: As currently implemented, groupByKey must be able to hold all the key-value pairs for any * key in memory. If a key has too many values, it can result in an [[OutOfMemoryError]]. */ - def groupByKey(numPartitions: Int): RDD[(K, Iterable[V])] = { + def groupByKey(numPartitions: Int): RDD[(K, Iterable[V])] = self.withScope { groupByKey(new HashPartitioner(numPartitions)) } /** * Return a copy of the RDD partitioned using the specified partitioner. */ - def partitionBy(partitioner: Partitioner): RDD[(K, V)] = { + def partitionBy(partitioner: Partitioner): RDD[(K, V)] = self.withScope { if (keyClass.isArray && partitioner.isInstanceOf[HashPartitioner]) { throw new SparkException("Default partitioner cannot partition array keys.") } @@ -488,7 +502,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and * (k, v2) is in `other`. Uses the given Partitioner to partition the output RDD. */ - def join[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, W))] = { + def join[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, W))] = self.withScope { this.cogroup(other, partitioner).flatMapValues( pair => for (v <- pair._1.iterator; w <- pair._2.iterator) yield (v, w) ) @@ -500,7 +514,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * pair (k, (v, None)) if no elements in `other` have key k. Uses the given Partitioner to * partition the output RDD. */ - def leftOuterJoin[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, Option[W]))] = { + def leftOuterJoin[W]( + other: RDD[(K, W)], + partitioner: Partitioner): RDD[(K, (V, Option[W]))] = self.withScope { this.cogroup(other, partitioner).flatMapValues { pair => if (pair._2.isEmpty) { pair._1.iterator.map(v => (v, None)) @@ -517,7 +533,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * partition the output RDD. */ def rightOuterJoin[W](other: RDD[(K, W)], partitioner: Partitioner) - : RDD[(K, (Option[V], W))] = { + : RDD[(K, (Option[V], W))] = self.withScope { this.cogroup(other, partitioner).flatMapValues { pair => if (pair._1.isEmpty) { pair._2.iterator.map(w => (None, w)) @@ -536,7 +552,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * in `this` have key k. Uses the given Partitioner to partition the output RDD. */ def fullOuterJoin[W](other: RDD[(K, W)], partitioner: Partitioner) - : RDD[(K, (Option[V], Option[W]))] = { + : RDD[(K, (Option[V], Option[W]))] = self.withScope { this.cogroup(other, partitioner).flatMapValues { case (vs, Seq()) => vs.iterator.map(v => (Some(v), None)) case (Seq(), ws) => ws.iterator.map(w => (None, Some(w))) @@ -549,7 +565,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * existing partitioner/parallelism level. */ def combineByKey[C](createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C) - : RDD[(K, C)] = { + : RDD[(K, C)] = self.withScope { combineByKey(createCombiner, mergeValue, mergeCombiners, defaultPartitioner(self)) } @@ -563,7 +579,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. */ - def groupByKey(): RDD[(K, Iterable[V])] = { + def groupByKey(): RDD[(K, Iterable[V])] = self.withScope { groupByKey(defaultPartitioner(self)) } @@ -572,7 +588,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and * (k, v2) is in `other`. Performs a hash join across the cluster. */ - def join[W](other: RDD[(K, W)]): RDD[(K, (V, W))] = { + def join[W](other: RDD[(K, W)]): RDD[(K, (V, W))] = self.withScope { join(other, defaultPartitioner(self, other)) } @@ -581,7 +597,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and * (k, v2) is in `other`. Performs a hash join across the cluster. */ - def join[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (V, W))] = { + def join[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (V, W))] = self.withScope { join(other, new HashPartitioner(numPartitions)) } @@ -591,7 +607,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output * using the existing partitioner/parallelism level. */ - def leftOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (V, Option[W]))] = { + def leftOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (V, Option[W]))] = self.withScope { leftOuterJoin(other, defaultPartitioner(self, other)) } @@ -601,7 +617,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output * into `numPartitions` partitions. */ - def leftOuterJoin[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (V, Option[W]))] = { + def leftOuterJoin[W]( + other: RDD[(K, W)], + numPartitions: Int): RDD[(K, (V, Option[W]))] = self.withScope { leftOuterJoin(other, new HashPartitioner(numPartitions)) } @@ -611,7 +629,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting * RDD using the existing partitioner/parallelism level. */ - def rightOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (Option[V], W))] = { + def rightOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (Option[V], W))] = self.withScope { rightOuterJoin(other, defaultPartitioner(self, other)) } @@ -621,7 +639,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting * RDD into the given number of partitions. */ - def rightOuterJoin[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (Option[V], W))] = { + def rightOuterJoin[W]( + other: RDD[(K, W)], + numPartitions: Int): RDD[(K, (Option[V], W))] = self.withScope { rightOuterJoin(other, new HashPartitioner(numPartitions)) } @@ -634,7 +654,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * in `this` have key k. Hash-partitions the resulting RDD using the existing partitioner/ * parallelism level. */ - def fullOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (Option[V], Option[W]))] = { + def fullOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (Option[V], Option[W]))] = self.withScope { fullOuterJoin(other, defaultPartitioner(self, other)) } @@ -646,7 +666,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * (k, (Some(v), Some(w))) for v in `this`, or the pair (k, (None, Some(w))) if no elements * in `this` have key k. Hash-partitions the resulting RDD into the given number of partitions. */ - def fullOuterJoin[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (Option[V], Option[W]))] = { + def fullOuterJoin[W]( + other: RDD[(K, W)], + numPartitions: Int): RDD[(K, (Option[V], Option[W]))] = self.withScope { fullOuterJoin(other, new HashPartitioner(numPartitions)) } @@ -656,7 +678,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Warning: this doesn't return a multimap (so if you have multiple values to the same key, only * one value per key is preserved in the map returned) */ - def collectAsMap(): Map[K, V] = { + def collectAsMap(): Map[K, V] = self.withScope { val data = self.collect() val map = new mutable.HashMap[K, V] map.sizeHint(data.length) @@ -668,7 +690,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Pass each value in the key-value pair RDD through a map function without changing the keys; * this also retains the original RDD's partitioning. */ - def mapValues[U](f: V => U): RDD[(K, U)] = { + def mapValues[U](f: V => U): RDD[(K, U)] = self.withScope { val cleanF = self.context.clean(f) new MapPartitionsRDD[(K, U), (K, V)](self, (context, pid, iter) => iter.map { case (k, v) => (k, cleanF(v)) }, @@ -679,7 +701,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Pass each value in the key-value pair RDD through a flatMap function without changing the * keys; this also retains the original RDD's partitioning. */ - def flatMapValues[U](f: V => TraversableOnce[U]): RDD[(K, U)] = { + def flatMapValues[U](f: V => TraversableOnce[U]): RDD[(K, U)] = self.withScope { val cleanF = self.context.clean(f) new MapPartitionsRDD[(K, U), (K, V)](self, (context, pid, iter) => iter.flatMap { case (k, v) => @@ -697,7 +719,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) other2: RDD[(K, W2)], other3: RDD[(K, W3)], partitioner: Partitioner) - : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = { + : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = self.withScope { if (partitioner.isInstanceOf[HashPartitioner] && keyClass.isArray) { throw new SparkException("Default partitioner cannot partition array keys.") } @@ -715,7 +737,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * list of values for that key in `this` as well as `other`. */ def cogroup[W](other: RDD[(K, W)], partitioner: Partitioner) - : RDD[(K, (Iterable[V], Iterable[W]))] = { + : RDD[(K, (Iterable[V], Iterable[W]))] = self.withScope { if (partitioner.isInstanceOf[HashPartitioner] && keyClass.isArray) { throw new SparkException("Default partitioner cannot partition array keys.") } @@ -730,7 +752,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * tuple with the list of values for that key in `this`, `other1` and `other2`. */ def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)], partitioner: Partitioner) - : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2]))] = { + : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2]))] = self.withScope { if (partitioner.isInstanceOf[HashPartitioner] && keyClass.isArray) { throw new SparkException("Default partitioner cannot partition array keys.") } @@ -748,7 +770,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * for that key in `this`, `other1`, `other2` and `other3`. */ def cogroup[W1, W2, W3](other1: RDD[(K, W1)], other2: RDD[(K, W2)], other3: RDD[(K, W3)]) - : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = { + : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = self.withScope { cogroup(other1, other2, other3, defaultPartitioner(self, other1, other2, other3)) } @@ -756,7 +778,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the * list of values for that key in `this` as well as `other`. */ - def cogroup[W](other: RDD[(K, W)]): RDD[(K, (Iterable[V], Iterable[W]))] = { + def cogroup[W](other: RDD[(K, W)]): RDD[(K, (Iterable[V], Iterable[W]))] = self.withScope { cogroup(other, defaultPartitioner(self, other)) } @@ -765,7 +787,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * tuple with the list of values for that key in `this`, `other1` and `other2`. */ def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)]) - : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2]))] = { + : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2]))] = self.withScope { cogroup(other1, other2, defaultPartitioner(self, other1, other2)) } @@ -773,7 +795,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the * list of values for that key in `this` as well as `other`. */ - def cogroup[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (Iterable[V], Iterable[W]))] = { + def cogroup[W]( + other: RDD[(K, W)], + numPartitions: Int): RDD[(K, (Iterable[V], Iterable[W]))] = self.withScope { cogroup(other, new HashPartitioner(numPartitions)) } @@ -782,7 +806,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * tuple with the list of values for that key in `this`, `other1` and `other2`. */ def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)], numPartitions: Int) - : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2]))] = { + : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2]))] = self.withScope { cogroup(other1, other2, new HashPartitioner(numPartitions)) } @@ -795,24 +819,24 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) other2: RDD[(K, W2)], other3: RDD[(K, W3)], numPartitions: Int) - : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = { + : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = self.withScope { cogroup(other1, other2, other3, new HashPartitioner(numPartitions)) } /** Alias for cogroup. */ - def groupWith[W](other: RDD[(K, W)]): RDD[(K, (Iterable[V], Iterable[W]))] = { + def groupWith[W](other: RDD[(K, W)]): RDD[(K, (Iterable[V], Iterable[W]))] = self.withScope { cogroup(other, defaultPartitioner(self, other)) } /** Alias for cogroup. */ def groupWith[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)]) - : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2]))] = { + : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2]))] = self.withScope { cogroup(other1, other2, defaultPartitioner(self, other1, other2)) } /** Alias for cogroup. */ def groupWith[W1, W2, W3](other1: RDD[(K, W1)], other2: RDD[(K, W2)], other3: RDD[(K, W3)]) - : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = { + : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = self.withScope { cogroup(other1, other2, other3, defaultPartitioner(self, other1, other2, other3)) } @@ -822,22 +846,27 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting * RDD will be <= us. */ - def subtractByKey[W: ClassTag](other: RDD[(K, W)]): RDD[(K, V)] = + def subtractByKey[W: ClassTag](other: RDD[(K, W)]): RDD[(K, V)] = self.withScope { subtractByKey(other, self.partitioner.getOrElse(new HashPartitioner(self.partitions.length))) + } /** Return an RDD with the pairs from `this` whose keys are not in `other`. */ - def subtractByKey[W: ClassTag](other: RDD[(K, W)], numPartitions: Int): RDD[(K, V)] = + def subtractByKey[W: ClassTag]( + other: RDD[(K, W)], + numPartitions: Int): RDD[(K, V)] = self.withScope { subtractByKey(other, new HashPartitioner(numPartitions)) + } /** Return an RDD with the pairs from `this` whose keys are not in `other`. */ - def subtractByKey[W: ClassTag](other: RDD[(K, W)], p: Partitioner): RDD[(K, V)] = + def subtractByKey[W: ClassTag](other: RDD[(K, W)], p: Partitioner): RDD[(K, V)] = self.withScope { new SubtractedRDD[K, V, W](self, other, p) + } /** * Return the list of values in the RDD for key `key`. This operation is done efficiently if the * RDD has a known partitioner by only searching the partition that the key maps to. */ - def lookup(key: K): Seq[V] = { + def lookup(key: K): Seq[V] = self.withScope { self.partitioner match { case Some(p) => val index = p.getPartition(key) @@ -859,7 +888,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class * supporting the key and value types K and V in this RDD. */ - def saveAsHadoopFile[F <: OutputFormat[K, V]](path: String)(implicit fm: ClassTag[F]) { + def saveAsHadoopFile[F <: OutputFormat[K, V]]( + path: String)(implicit fm: ClassTag[F]): Unit = self.withScope { saveAsHadoopFile(path, keyClass, valueClass, fm.runtimeClass.asInstanceOf[Class[F]]) } @@ -869,7 +899,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * supplied codec. */ def saveAsHadoopFile[F <: OutputFormat[K, V]]( - path: String, codec: Class[_ <: CompressionCodec]) (implicit fm: ClassTag[F]) { + path: String, + codec: Class[_ <: CompressionCodec])(implicit fm: ClassTag[F]): Unit = self.withScope { val runtimeClass = fm.runtimeClass saveAsHadoopFile(path, keyClass, valueClass, runtimeClass.asInstanceOf[Class[F]], codec) } @@ -878,7 +909,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat` * (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD. */ - def saveAsNewAPIHadoopFile[F <: NewOutputFormat[K, V]](path: String)(implicit fm: ClassTag[F]) { + def saveAsNewAPIHadoopFile[F <: NewOutputFormat[K, V]]( + path: String)(implicit fm: ClassTag[F]): Unit = self.withScope { saveAsNewAPIHadoopFile(path, keyClass, valueClass, fm.runtimeClass.asInstanceOf[Class[F]]) } @@ -891,8 +923,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) keyClass: Class[_], valueClass: Class[_], outputFormatClass: Class[_ <: NewOutputFormat[_, _]], - conf: Configuration = self.context.hadoopConfiguration) - { + conf: Configuration = self.context.hadoopConfiguration): Unit = self.withScope { // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). val hadoopConf = conf val job = new NewAPIHadoopJob(hadoopConf) @@ -912,7 +943,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) keyClass: Class[_], valueClass: Class[_], outputFormatClass: Class[_ <: OutputFormat[_, _]], - codec: Class[_ <: CompressionCodec]) { + codec: Class[_ <: CompressionCodec]): Unit = self.withScope { saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass, new JobConf(self.context.hadoopConfiguration), Some(codec)) } @@ -927,7 +958,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) valueClass: Class[_], outputFormatClass: Class[_ <: OutputFormat[_, _]], conf: JobConf = new JobConf(self.context.hadoopConfiguration), - codec: Option[Class[_ <: CompressionCodec]] = None) { + codec: Option[Class[_ <: CompressionCodec]] = None): Unit = self.withScope { // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). val hadoopConf = conf hadoopConf.setOutputKeyClass(keyClass) @@ -960,7 +991,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * output paths required (e.g. a table name to write to) in the same way as it would be * configured for a Hadoop MapReduce job. */ - def saveAsNewAPIHadoopDataset(conf: Configuration) { + def saveAsNewAPIHadoopDataset(conf: Configuration): Unit = self.withScope { // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). val hadoopConf = conf val job = new NewAPIHadoopJob(hadoopConf) @@ -1027,7 +1058,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * (e.g. a table name to write to) in the same way as it would be configured for a Hadoop * MapReduce job. */ - def saveAsHadoopDataset(conf: JobConf) { + def saveAsHadoopDataset(conf: JobConf): Unit = self.withScope { // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). val hadoopConf = conf val wrappedConf = new SerializableWritable(hadoopConf) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 31c07c73fe07b..7f7c7ed144eb3 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -25,7 +25,7 @@ import scala.language.implicitConversions import scala.reflect.{classTag, ClassTag} import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus -import org.apache.hadoop.io.{Writable, BytesWritable, NullWritable, Text} +import org.apache.hadoop.io.{BytesWritable, NullWritable, Text} import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.mapred.TextOutputFormat @@ -277,12 +277,20 @@ abstract class RDD[T: ClassTag]( if (isCheckpointed) firstParent[T].iterator(split, context) else compute(split, context) } + /** + * Execute a block of code in a scope such that all new RDDs created in this body will + * be part of the same scope. For more detail, see {{org.apache.spark.rdd.RDDOperationScope}}. + * + * Note: Return statements are NOT allowed in the given body. + */ + private[spark] def withScope[U](body: => U): U = RDDOperationScope.withScope[U](sc)(body) + // Transformations (return a new RDD) /** * Return a new RDD by applying a function to all elements of this RDD. */ - def map[U: ClassTag](f: T => U): RDD[U] = { + def map[U: ClassTag](f: T => U): RDD[U] = withScope { val cleanF = sc.clean(f) new MapPartitionsRDD[U, T](this, (context, pid, iter) => iter.map(cleanF)) } @@ -291,7 +299,7 @@ abstract class RDD[T: ClassTag]( * Return a new RDD by first applying a function to all elements of this * RDD, and then flattening the results. */ - def flatMap[U: ClassTag](f: T => TraversableOnce[U]): RDD[U] = { + def flatMap[U: ClassTag](f: T => TraversableOnce[U]): RDD[U] = withScope { val cleanF = sc.clean(f) new MapPartitionsRDD[U, T](this, (context, pid, iter) => iter.flatMap(cleanF)) } @@ -299,7 +307,7 @@ abstract class RDD[T: ClassTag]( /** * Return a new RDD containing only the elements that satisfy a predicate. */ - def filter(f: T => Boolean): RDD[T] = { + def filter(f: T => Boolean): RDD[T] = withScope { val cleanF = sc.clean(f) new MapPartitionsRDD[T, T]( this, @@ -310,13 +318,16 @@ abstract class RDD[T: ClassTag]( /** * Return a new RDD containing the distinct elements in this RDD. */ - def distinct(numPartitions: Int)(implicit ord: Ordering[T] = null): RDD[T] = + def distinct(numPartitions: Int)(implicit ord: Ordering[T] = null): RDD[T] = withScope { map(x => (x, null)).reduceByKey((x, y) => x, numPartitions).map(_._1) + } /** * Return a new RDD containing the distinct elements in this RDD. */ - def distinct(): RDD[T] = distinct(partitions.length) + def distinct(): RDD[T] = withScope { + distinct(partitions.length) + } /** * Return a new RDD that has exactly numPartitions partitions. @@ -327,7 +338,7 @@ abstract class RDD[T: ClassTag]( * If you are decreasing the number of partitions in this RDD, consider using `coalesce`, * which can avoid performing a shuffle. */ - def repartition(numPartitions: Int)(implicit ord: Ordering[T] = null): RDD[T] = { + def repartition(numPartitions: Int)(implicit ord: Ordering[T] = null): RDD[T] = withScope { coalesce(numPartitions, shuffle = true) } @@ -352,7 +363,7 @@ abstract class RDD[T: ClassTag]( * data distributed using a hash partitioner. */ def coalesce(numPartitions: Int, shuffle: Boolean = false)(implicit ord: Ordering[T] = null) - : RDD[T] = { + : RDD[T] = withScope { if (shuffle) { /** Distributes elements evenly across output partitions, starting from a random partition. */ val distributePartition = (index: Int, items: Iterator[T]) => { @@ -377,16 +388,17 @@ abstract class RDD[T: ClassTag]( /** * Return a sampled subset of this RDD. - * + * * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size * without replacement: probability that each element is chosen; fraction must be [0, 1] * with replacement: expected number of times each element is chosen; fraction must be >= 0 * @param seed seed for the random number generator */ - def sample(withReplacement: Boolean, + def sample( + withReplacement: Boolean, fraction: Double, - seed: Long = Utils.random.nextLong): RDD[T] = { + seed: Long = Utils.random.nextLong): RDD[T] = withScope { require(fraction >= 0.0, "Negative fraction value: " + fraction) if (withReplacement) { new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), true, seed) @@ -403,7 +415,9 @@ abstract class RDD[T: ClassTag]( * * @return split RDDs in an array */ - def randomSplit(weights: Array[Double], seed: Long = Utils.random.nextLong): Array[RDD[T]] = { + def randomSplit( + weights: Array[Double], + seed: Long = Utils.random.nextLong): Array[RDD[T]] = withScope { val sum = weights.sum val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) normalizedCumWeights.sliding(2).map { x => @@ -435,7 +449,9 @@ abstract class RDD[T: ClassTag]( * @param seed seed for the random number generator * @return sample of specified size in an array */ - def takeSample(withReplacement: Boolean, + // TODO: rewrite this without return statements so we can wrap it in a scope + def takeSample( + withReplacement: Boolean, num: Int, seed: Long = Utils.random.nextLong): Array[T] = { val numStDev = 10.0 @@ -483,7 +499,7 @@ abstract class RDD[T: ClassTag]( * Return the union of this RDD and another one. Any identical elements will appear multiple * times (use `.distinct()` to eliminate them). */ - def union(other: RDD[T]): RDD[T] = { + def union(other: RDD[T]): RDD[T] = withScope { if (partitioner.isDefined && other.partitioner == partitioner) { new PartitionerAwareUnionRDD(sc, Array(this, other)) } else { @@ -495,7 +511,9 @@ abstract class RDD[T: ClassTag]( * Return the union of this RDD and another one. Any identical elements will appear multiple * times (use `.distinct()` to eliminate them). */ - def ++(other: RDD[T]): RDD[T] = this.union(other) + def ++(other: RDD[T]): RDD[T] = withScope { + this.union(other) + } /** * Return this RDD sorted by the given key function. @@ -504,10 +522,11 @@ abstract class RDD[T: ClassTag]( f: (T) => K, ascending: Boolean = true, numPartitions: Int = this.partitions.length) - (implicit ord: Ordering[K], ctag: ClassTag[K]): RDD[T] = + (implicit ord: Ordering[K], ctag: ClassTag[K]): RDD[T] = withScope { this.keyBy[K](f) .sortByKey(ascending, numPartitions) .values + } /** * Return the intersection of this RDD and another one. The output will not contain any duplicate @@ -515,7 +534,7 @@ abstract class RDD[T: ClassTag]( * * Note that this method performs a shuffle internally. */ - def intersection(other: RDD[T]): RDD[T] = { + def intersection(other: RDD[T]): RDD[T] = withScope { this.map(v => (v, null)).cogroup(other.map(v => (v, null))) .filter { case (_, (leftGroup, rightGroup)) => leftGroup.nonEmpty && rightGroup.nonEmpty } .keys @@ -529,8 +548,9 @@ abstract class RDD[T: ClassTag]( * * @param partitioner Partitioner to use for the resulting RDD */ - def intersection(other: RDD[T], partitioner: Partitioner)(implicit ord: Ordering[T] = null) - : RDD[T] = { + def intersection( + other: RDD[T], + partitioner: Partitioner)(implicit ord: Ordering[T] = null): RDD[T] = withScope { this.map(v => (v, null)).cogroup(other.map(v => (v, null)), partitioner) .filter { case (_, (leftGroup, rightGroup)) => leftGroup.nonEmpty && rightGroup.nonEmpty } .keys @@ -544,16 +564,14 @@ abstract class RDD[T: ClassTag]( * * @param numPartitions How many partitions to use in the resulting RDD */ - def intersection(other: RDD[T], numPartitions: Int): RDD[T] = { - this.map(v => (v, null)).cogroup(other.map(v => (v, null)), new HashPartitioner(numPartitions)) - .filter { case (_, (leftGroup, rightGroup)) => leftGroup.nonEmpty && rightGroup.nonEmpty } - .keys + def intersection(other: RDD[T], numPartitions: Int): RDD[T] = withScope { + intersection(other, new HashPartitioner(numPartitions)) } /** * Return an RDD created by coalescing all elements within each partition into an array. */ - def glom(): RDD[Array[T]] = { + def glom(): RDD[Array[T]] = withScope { new MapPartitionsRDD[Array[T], T](this, (context, pid, iter) => Iterator(iter.toArray)) } @@ -561,7 +579,9 @@ abstract class RDD[T: ClassTag]( * Return the Cartesian product of this RDD and another one, that is, the RDD of all pairs of * elements (a, b) where a is in `this` and b is in `other`. */ - def cartesian[U: ClassTag](other: RDD[U]): RDD[(T, U)] = new CartesianRDD(sc, this, other) + def cartesian[U: ClassTag](other: RDD[U]): RDD[(T, U)] = withScope { + new CartesianRDD(sc, this, other) + } /** * Return an RDD of grouped items. Each group consists of a key and a sequence of elements @@ -572,8 +592,9 @@ abstract class RDD[T: ClassTag]( * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. */ - def groupBy[K](f: T => K)(implicit kt: ClassTag[K]): RDD[(K, Iterable[T])] = + def groupBy[K](f: T => K)(implicit kt: ClassTag[K]): RDD[(K, Iterable[T])] = withScope { groupBy[K](f, defaultPartitioner(this)) + } /** * Return an RDD of grouped elements. Each group consists of a key and a sequence of elements @@ -584,8 +605,11 @@ abstract class RDD[T: ClassTag]( * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. */ - def groupBy[K](f: T => K, numPartitions: Int)(implicit kt: ClassTag[K]): RDD[(K, Iterable[T])] = + def groupBy[K]( + f: T => K, + numPartitions: Int)(implicit kt: ClassTag[K]): RDD[(K, Iterable[T])] = withScope { groupBy(f, new HashPartitioner(numPartitions)) + } /** * Return an RDD of grouped items. Each group consists of a key and a sequence of elements @@ -597,7 +621,7 @@ abstract class RDD[T: ClassTag]( * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. */ def groupBy[K](f: T => K, p: Partitioner)(implicit kt: ClassTag[K], ord: Ordering[K] = null) - : RDD[(K, Iterable[T])] = { + : RDD[(K, Iterable[T])] = withScope { val cleanF = sc.clean(f) this.map(t => (cleanF(t), t)).groupByKey(p) } @@ -605,13 +629,16 @@ abstract class RDD[T: ClassTag]( /** * Return an RDD created by piping elements to a forked external process. */ - def pipe(command: String): RDD[String] = new PipedRDD(this, command) + def pipe(command: String): RDD[String] = withScope { + new PipedRDD(this, command) + } /** * Return an RDD created by piping elements to a forked external process. */ - def pipe(command: String, env: Map[String, String]): RDD[String] = + def pipe(command: String, env: Map[String, String]): RDD[String] = withScope { new PipedRDD(this, command, env) + } /** * Return an RDD created by piping elements to a forked external process. @@ -619,7 +646,7 @@ abstract class RDD[T: ClassTag]( * * @param command command to run in forked process. * @param env environment variables to set. - * @param printPipeContext Before piping elements, this function is called as an oppotunity + * @param printPipeContext Before piping elements, this function is called as an opportunity * to pipe context data. Print line function (like out.println) will be * passed as printPipeContext's parameter. * @param printRDDElement Use this function to customize how to pipe elements. This function @@ -637,7 +664,7 @@ abstract class RDD[T: ClassTag]( env: Map[String, String] = Map(), printPipeContext: (String => Unit) => Unit = null, printRDDElement: (T, String => Unit) => Unit = null, - separateWorkingDir: Boolean = false): RDD[String] = { + separateWorkingDir: Boolean = false): RDD[String] = withScope { new PipedRDD(this, command, env, if (printPipeContext ne null) sc.clean(printPipeContext) else null, if (printRDDElement ne null) sc.clean(printRDDElement) else null, @@ -651,7 +678,7 @@ abstract class RDD[T: ClassTag]( * should be `false` unless this is a pair RDD and the input function doesn't modify the keys. */ def mapPartitions[U: ClassTag]( - f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = { + f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = withScope { val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(iter) new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning) } @@ -664,7 +691,8 @@ abstract class RDD[T: ClassTag]( * should be `false` unless this is a pair RDD and the input function doesn't modify the keys. */ def mapPartitionsWithIndex[U: ClassTag]( - f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = { + f: (Int, Iterator[T]) => Iterator[U], + preservesPartitioning: Boolean = false): RDD[U] = withScope { val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(index, iter) new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning) } @@ -681,7 +709,7 @@ abstract class RDD[T: ClassTag]( @deprecated("use TaskContext.get", "1.2.0") def mapPartitionsWithContext[U: ClassTag]( f: (TaskContext, Iterator[T]) => Iterator[U], - preservesPartitioning: Boolean = false): RDD[U] = { + preservesPartitioning: Boolean = false): RDD[U] = withScope { val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(context, iter) new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning) } @@ -692,7 +720,8 @@ abstract class RDD[T: ClassTag]( */ @deprecated("use mapPartitionsWithIndex", "0.7.0") def mapPartitionsWithSplit[U: ClassTag]( - f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = { + f: (Int, Iterator[T]) => Iterator[U], + preservesPartitioning: Boolean = false): RDD[U] = withScope { mapPartitionsWithIndex(f, preservesPartitioning) } @@ -704,7 +733,7 @@ abstract class RDD[T: ClassTag]( @deprecated("use mapPartitionsWithIndex", "1.0.0") def mapWith[A, U: ClassTag] (constructA: Int => A, preservesPartitioning: Boolean = false) - (f: (T, A) => U): RDD[U] = { + (f: (T, A) => U): RDD[U] = withScope { mapPartitionsWithIndex((index, iter) => { val a = constructA(index) iter.map(t => f(t, a)) @@ -719,7 +748,7 @@ abstract class RDD[T: ClassTag]( @deprecated("use mapPartitionsWithIndex and flatMap", "1.0.0") def flatMapWith[A, U: ClassTag] (constructA: Int => A, preservesPartitioning: Boolean = false) - (f: (T, A) => Seq[U]): RDD[U] = { + (f: (T, A) => Seq[U]): RDD[U] = withScope { mapPartitionsWithIndex((index, iter) => { val a = constructA(index) iter.flatMap(t => f(t, a)) @@ -732,11 +761,11 @@ abstract class RDD[T: ClassTag]( * partition with the index of that partition. */ @deprecated("use mapPartitionsWithIndex and foreach", "1.0.0") - def foreachWith[A](constructA: Int => A)(f: (T, A) => Unit) { + def foreachWith[A](constructA: Int => A)(f: (T, A) => Unit): Unit = withScope { mapPartitionsWithIndex { (index, iter) => val a = constructA(index) iter.map(t => {f(t, a); t}) - }.foreach(_ => {}) + } } /** @@ -745,7 +774,7 @@ abstract class RDD[T: ClassTag]( * partition with the index of that partition. */ @deprecated("use mapPartitionsWithIndex and filter", "1.0.0") - def filterWith[A](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = { + def filterWith[A](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = withScope { mapPartitionsWithIndex((index, iter) => { val a = constructA(index) iter.filter(t => p(t, a)) @@ -758,7 +787,7 @@ abstract class RDD[T: ClassTag]( * partitions* and the *same number of elements in each partition* (e.g. one was made through * a map on the other). */ - def zip[U: ClassTag](other: RDD[U]): RDD[(T, U)] = { + def zip[U: ClassTag](other: RDD[U]): RDD[(T, U)] = withScope { zipPartitions(other, preservesPartitioning = false) { (thisIter, otherIter) => new Iterator[(T, U)] { def hasNext: Boolean = (thisIter.hasNext, otherIter.hasNext) match { @@ -780,33 +809,39 @@ abstract class RDD[T: ClassTag]( */ def zipPartitions[B: ClassTag, V: ClassTag] (rdd2: RDD[B], preservesPartitioning: Boolean) - (f: (Iterator[T], Iterator[B]) => Iterator[V]): RDD[V] = + (f: (Iterator[T], Iterator[B]) => Iterator[V]): RDD[V] = withScope { new ZippedPartitionsRDD2(sc, sc.clean(f), this, rdd2, preservesPartitioning) + } def zipPartitions[B: ClassTag, V: ClassTag] (rdd2: RDD[B]) - (f: (Iterator[T], Iterator[B]) => Iterator[V]): RDD[V] = - new ZippedPartitionsRDD2(sc, sc.clean(f), this, rdd2, false) + (f: (Iterator[T], Iterator[B]) => Iterator[V]): RDD[V] = withScope { + zipPartitions(rdd2, preservesPartitioning = false)(f) + } def zipPartitions[B: ClassTag, C: ClassTag, V: ClassTag] (rdd2: RDD[B], rdd3: RDD[C], preservesPartitioning: Boolean) - (f: (Iterator[T], Iterator[B], Iterator[C]) => Iterator[V]): RDD[V] = + (f: (Iterator[T], Iterator[B], Iterator[C]) => Iterator[V]): RDD[V] = withScope { new ZippedPartitionsRDD3(sc, sc.clean(f), this, rdd2, rdd3, preservesPartitioning) + } def zipPartitions[B: ClassTag, C: ClassTag, V: ClassTag] (rdd2: RDD[B], rdd3: RDD[C]) - (f: (Iterator[T], Iterator[B], Iterator[C]) => Iterator[V]): RDD[V] = - new ZippedPartitionsRDD3(sc, sc.clean(f), this, rdd2, rdd3, false) + (f: (Iterator[T], Iterator[B], Iterator[C]) => Iterator[V]): RDD[V] = withScope { + zipPartitions(rdd2, rdd3, preservesPartitioning = false)(f) + } def zipPartitions[B: ClassTag, C: ClassTag, D: ClassTag, V: ClassTag] (rdd2: RDD[B], rdd3: RDD[C], rdd4: RDD[D], preservesPartitioning: Boolean) - (f: (Iterator[T], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V]): RDD[V] = + (f: (Iterator[T], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V]): RDD[V] = withScope { new ZippedPartitionsRDD4(sc, sc.clean(f), this, rdd2, rdd3, rdd4, preservesPartitioning) + } def zipPartitions[B: ClassTag, C: ClassTag, D: ClassTag, V: ClassTag] (rdd2: RDD[B], rdd3: RDD[C], rdd4: RDD[D]) - (f: (Iterator[T], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V]): RDD[V] = - new ZippedPartitionsRDD4(sc, sc.clean(f), this, rdd2, rdd3, rdd4, false) + (f: (Iterator[T], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V]): RDD[V] = withScope { + zipPartitions(rdd2, rdd3, rdd4, preservesPartitioning = false)(f) + } // Actions (launch a job to return a value to the user program) @@ -814,7 +849,7 @@ abstract class RDD[T: ClassTag]( /** * Applies a function f to all elements of this RDD. */ - def foreach(f: T => Unit) { + def foreach(f: T => Unit): Unit = withScope { val cleanF = sc.clean(f) sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF)) } @@ -822,7 +857,7 @@ abstract class RDD[T: ClassTag]( /** * Applies a function f to each partition of this RDD. */ - def foreachPartition(f: Iterator[T] => Unit) { + def foreachPartition(f: Iterator[T] => Unit): Unit = withScope { val cleanF = sc.clean(f) sc.runJob(this, (iter: Iterator[T]) => cleanF(iter)) } @@ -830,7 +865,7 @@ abstract class RDD[T: ClassTag]( /** * Return an array that contains all of the elements in this RDD. */ - def collect(): Array[T] = { + def collect(): Array[T] = withScope { val results = sc.runJob(this, (iter: Iterator[T]) => iter.toArray) Array.concat(results: _*) } @@ -840,7 +875,7 @@ abstract class RDD[T: ClassTag]( * * The iterator will consume as much memory as the largest partition in this RDD. */ - def toLocalIterator: Iterator[T] = { + def toLocalIterator: Iterator[T] = withScope { def collectPartition(p: Int): Array[T] = { sc.runJob(this, (iter: Iterator[T]) => iter.toArray, Seq(p), allowLocal = false).head } @@ -851,12 +886,14 @@ abstract class RDD[T: ClassTag]( * Return an array that contains all of the elements in this RDD. */ @deprecated("use collect", "1.0.0") - def toArray(): Array[T] = collect() + def toArray(): Array[T] = withScope { + collect() + } /** * Return an RDD that contains all matching values by applying `f`. */ - def collect[U: ClassTag](f: PartialFunction[T, U]): RDD[U] = { + def collect[U: ClassTag](f: PartialFunction[T, U]): RDD[U] = withScope { filter(f.isDefinedAt).map(f) } @@ -866,19 +903,23 @@ abstract class RDD[T: ClassTag]( * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting * RDD will be <= us. */ - def subtract(other: RDD[T]): RDD[T] = + def subtract(other: RDD[T]): RDD[T] = withScope { subtract(other, partitioner.getOrElse(new HashPartitioner(partitions.length))) + } /** * Return an RDD with the elements from `this` that are not in `other`. */ - def subtract(other: RDD[T], numPartitions: Int): RDD[T] = + def subtract(other: RDD[T], numPartitions: Int): RDD[T] = withScope { subtract(other, new HashPartitioner(numPartitions)) + } /** * Return an RDD with the elements from `this` that are not in `other`. */ - def subtract(other: RDD[T], p: Partitioner)(implicit ord: Ordering[T] = null): RDD[T] = { + def subtract( + other: RDD[T], + p: Partitioner)(implicit ord: Ordering[T] = null): RDD[T] = withScope { if (partitioner == Some(p)) { // Our partitioner knows how to handle T (which, since we have a partitioner, is // really (K, V)) so make a new Partitioner that will de-tuple our fake tuples @@ -900,7 +941,7 @@ abstract class RDD[T: ClassTag]( * Reduces the elements of this RDD using the specified commutative and * associative binary operator. */ - def reduce(f: (T, T) => T): T = { + def reduce(f: (T, T) => T): T = withScope { val cleanF = sc.clean(f) val reducePartition: Iterator[T] => Option[T] = iter => { if (iter.hasNext) { @@ -929,7 +970,7 @@ abstract class RDD[T: ClassTag]( * @param depth suggested depth of the tree (default: 2) * @see [[org.apache.spark.rdd.RDD#reduce]] */ - def treeReduce(f: (T, T) => T, depth: Int = 2): T = { + def treeReduce(f: (T, T) => T, depth: Int = 2): T = withScope { require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.") val cleanF = context.clean(f) val reducePartition: Iterator[T] => Option[T] = iter => { @@ -961,7 +1002,7 @@ abstract class RDD[T: ClassTag]( * modify t1 and return it as its result value to avoid object allocation; however, it should not * modify t2. */ - def fold(zeroValue: T)(op: (T, T) => T): T = { + def fold(zeroValue: T)(op: (T, T) => T): T = withScope { // Clone the zero value since we will also be serializing it as part of tasks var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance()) val cleanOp = sc.clean(op) @@ -979,7 +1020,7 @@ abstract class RDD[T: ClassTag]( * allowed to modify and return their first argument instead of creating a new U to avoid memory * allocation. */ - def aggregate[U: ClassTag](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = { + def aggregate[U: ClassTag](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = withScope { // Clone the zero value since we will also be serializing it as part of tasks var jobResult = Utils.clone(zeroValue, sc.env.serializer.newInstance()) val cleanSeqOp = sc.clean(seqOp) @@ -999,26 +1040,29 @@ abstract class RDD[T: ClassTag]( def treeAggregate[U: ClassTag](zeroValue: U)( seqOp: (U, T) => U, combOp: (U, U) => U, - depth: Int = 2): U = { + depth: Int = 2): U = withScope { require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.") if (partitions.length == 0) { - return Utils.clone(zeroValue, context.env.closureSerializer.newInstance()) - } - val cleanSeqOp = context.clean(seqOp) - val cleanCombOp = context.clean(combOp) - val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp) - var partiallyAggregated = mapPartitions(it => Iterator(aggregatePartition(it))) - var numPartitions = partiallyAggregated.partitions.length - val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2) - // If creating an extra level doesn't help reduce the wall-clock time, we stop tree aggregation. - while (numPartitions > scale + numPartitions / scale) { - numPartitions /= scale - val curNumPartitions = numPartitions - partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex { (i, iter) => - iter.map((i % curNumPartitions, _)) - }.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values + Utils.clone(zeroValue, context.env.closureSerializer.newInstance()) + } else { + val cleanSeqOp = context.clean(seqOp) + val cleanCombOp = context.clean(combOp) + val aggregatePartition = + (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp) + var partiallyAggregated = mapPartitions(it => Iterator(aggregatePartition(it))) + var numPartitions = partiallyAggregated.partitions.length + val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2) + // If creating an extra level doesn't help reduce + // the wall-clock time, we stop tree aggregation. + while (numPartitions > scale + numPartitions / scale) { + numPartitions /= scale + val curNumPartitions = numPartitions + partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex { + (i, iter) => iter.map((i % curNumPartitions, _)) + }.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values + } + partiallyAggregated.reduce(cleanCombOp) } - partiallyAggregated.reduce(cleanCombOp) } /** @@ -1032,7 +1076,9 @@ abstract class RDD[T: ClassTag]( * within a timeout, even if not all tasks have finished. */ @Experimental - def countApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = { + def countApprox( + timeout: Long, + confidence: Double = 0.95): PartialResult[BoundedDouble] = withScope { val countElements: (TaskContext, Iterator[T]) => Long = { (ctx, iter) => var result = 0L while (iter.hasNext) { @@ -1053,7 +1099,7 @@ abstract class RDD[T: ClassTag]( * To handle very large results, consider using rdd.map(x => (x, 1L)).reduceByKey(_ + _), which * returns an RDD[T, Long] instead of a map. */ - def countByValue()(implicit ord: Ordering[T] = null): Map[T, Long] = { + def countByValue()(implicit ord: Ordering[T] = null): Map[T, Long] = withScope { map(value => (value, null)).countByKey() } @@ -1064,8 +1110,7 @@ abstract class RDD[T: ClassTag]( @Experimental def countByValueApprox(timeout: Long, confidence: Double = 0.95) (implicit ord: Ordering[T] = null) - : PartialResult[Map[T, BoundedDouble]] = - { + : PartialResult[Map[T, BoundedDouble]] = withScope { if (elementClassTag.runtimeClass.isArray) { throw new SparkException("countByValueApprox() does not support arrays") } @@ -1098,7 +1143,7 @@ abstract class RDD[T: ClassTag]( * If `sp` equals 0, the sparse representation is skipped. */ @Experimental - def countApproxDistinct(p: Int, sp: Int): Long = { + def countApproxDistinct(p: Int, sp: Int): Long = withScope { require(p >= 4, s"p ($p) must be at least 4") require(sp <= 32, s"sp ($sp) cannot be greater than 32") require(sp == 0 || p <= sp, s"p ($p) cannot be greater than sp ($sp)") @@ -1124,7 +1169,7 @@ abstract class RDD[T: ClassTag]( * @param relativeSD Relative accuracy. Smaller values create counters that require more space. * It must be greater than 0.000017. */ - def countApproxDistinct(relativeSD: Double = 0.05): Long = { + def countApproxDistinct(relativeSD: Double = 0.05): Long = withScope { val p = math.ceil(2.0 * math.log(1.054 / relativeSD) / math.log(2)).toInt countApproxDistinct(p, 0) } @@ -1142,7 +1187,9 @@ abstract class RDD[T: ClassTag]( * and may even change if the RDD is reevaluated. If a fixed ordering is required to guarantee * the same index assignments, you should sort the RDD with sortByKey() or save it to a file. */ - def zipWithIndex(): RDD[(T, Long)] = new ZippedWithIndexRDD(this) + def zipWithIndex(): RDD[(T, Long)] = withScope { + new ZippedWithIndexRDD(this) + } /** * Zips this RDD with generated unique Long ids. Items in the kth partition will get ids k, n+k, @@ -1154,7 +1201,7 @@ abstract class RDD[T: ClassTag]( * and may even change if the RDD is reevaluated. If a fixed ordering is required to guarantee * the same index assignments, you should sort the RDD with sortByKey() or save it to a file. */ - def zipWithUniqueId(): RDD[(T, Long)] = { + def zipWithUniqueId(): RDD[(T, Long)] = withScope { val n = this.partitions.length.toLong this.mapPartitionsWithIndex { case (k, iter) => iter.zipWithIndex.map { case (item, i) => @@ -1171,48 +1218,50 @@ abstract class RDD[T: ClassTag]( * @note due to complications in the internal implementation, this method will raise * an exception if called on an RDD of `Nothing` or `Null`. */ - def take(num: Int): Array[T] = { + def take(num: Int): Array[T] = withScope { if (num == 0) { - return new Array[T](0) - } - - val buf = new ArrayBuffer[T] - val totalParts = this.partitions.length - var partsScanned = 0 - while (buf.size < num && partsScanned < totalParts) { - // The number of partitions to try in this iteration. It is ok for this number to be - // greater than totalParts because we actually cap it at totalParts in runJob. - var numPartsToTry = 1 - if (partsScanned > 0) { - // If we didn't find any rows after the previous iteration, quadruple and retry. Otherwise, - // interpolate the number of partitions we need to try, but overestimate it by 50%. - // We also cap the estimation in the end. - if (buf.size == 0) { - numPartsToTry = partsScanned * 4 - } else { - // the left side of max is >=1 whenever partsScanned >= 2 - numPartsToTry = Math.max((1.5 * num * partsScanned / buf.size).toInt - partsScanned, 1) - numPartsToTry = Math.min(numPartsToTry, partsScanned * 4) + new Array[T](0) + } else { + val buf = new ArrayBuffer[T] + val totalParts = this.partitions.length + var partsScanned = 0 + while (buf.size < num && partsScanned < totalParts) { + // The number of partitions to try in this iteration. It is ok for this number to be + // greater than totalParts because we actually cap it at totalParts in runJob. + var numPartsToTry = 1 + if (partsScanned > 0) { + // If we didn't find any rows after the previous iteration, quadruple and retry. + // Otherwise, interpolate the number of partitions we need to try, but overestimate + // it by 50%. We also cap the estimation in the end. + if (buf.size == 0) { + numPartsToTry = partsScanned * 4 + } else { + // the left side of max is >=1 whenever partsScanned >= 2 + numPartsToTry = Math.max((1.5 * num * partsScanned / buf.size).toInt - partsScanned, 1) + numPartsToTry = Math.min(numPartsToTry, partsScanned * 4) + } } - } - val left = num - buf.size - val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts) - val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p, allowLocal = true) + val left = num - buf.size + val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts) + val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p, allowLocal = true) - res.foreach(buf ++= _.take(num - buf.size)) - partsScanned += numPartsToTry - } + res.foreach(buf ++= _.take(num - buf.size)) + partsScanned += numPartsToTry + } - buf.toArray + buf.toArray + } } /** * Return the first element in this RDD. */ - def first(): T = take(1) match { - case Array(t) => t - case _ => throw new UnsupportedOperationException("empty collection") + def first(): T = withScope { + take(1) match { + case Array(t) => t + case _ => throw new UnsupportedOperationException("empty collection") + } } /** @@ -1230,7 +1279,9 @@ abstract class RDD[T: ClassTag]( * @param ord the implicit ordering for T * @return an array of top elements */ - def top(num: Int)(implicit ord: Ordering[T]): Array[T] = takeOrdered(num)(ord.reverse) + def top(num: Int)(implicit ord: Ordering[T]): Array[T] = withScope { + takeOrdered(num)(ord.reverse) + } /** * Returns the first k (smallest) elements from this RDD as defined by the specified @@ -1248,7 +1299,7 @@ abstract class RDD[T: ClassTag]( * @param ord the implicit ordering for T * @return an array of top elements */ - def takeOrdered(num: Int)(implicit ord: Ordering[T]): Array[T] = { + def takeOrdered(num: Int)(implicit ord: Ordering[T]): Array[T] = withScope { if (num == 0) { Array.empty } else { @@ -1273,13 +1324,17 @@ abstract class RDD[T: ClassTag]( * Returns the max of this RDD as defined by the implicit Ordering[T]. * @return the maximum element of the RDD * */ - def max()(implicit ord: Ordering[T]): T = this.reduce(ord.max) + def max()(implicit ord: Ordering[T]): T = withScope { + this.reduce(ord.max) + } /** * Returns the min of this RDD as defined by the implicit Ordering[T]. * @return the minimum element of the RDD * */ - def min()(implicit ord: Ordering[T]): T = this.reduce(ord.min) + def min()(implicit ord: Ordering[T]): T = withScope { + this.reduce(ord.min) + } /** * @note due to complications in the internal implementation, this method will raise an @@ -1289,12 +1344,14 @@ abstract class RDD[T: ClassTag]( * @return true if and only if the RDD contains no elements at all. Note that an RDD * may be empty even when it has at least 1 partition. */ - def isEmpty(): Boolean = partitions.length == 0 || take(1).length == 0 + def isEmpty(): Boolean = withScope { + partitions.length == 0 || take(1).length == 0 + } /** * Save this RDD as a text file, using string representations of elements. */ - def saveAsTextFile(path: String) { + def saveAsTextFile(path: String): Unit = withScope { // https://issues.apache.org/jira/browse/SPARK-2075 // // NullWritable is a `Comparable` in Hadoop 1.+, so the compiler cannot find an implicit @@ -1321,7 +1378,7 @@ abstract class RDD[T: ClassTag]( /** * Save this RDD as a compressed text file, using string representations of elements. */ - def saveAsTextFile(path: String, codec: Class[_ <: CompressionCodec]) { + def saveAsTextFile(path: String, codec: Class[_ <: CompressionCodec]): Unit = withScope { // https://issues.apache.org/jira/browse/SPARK-2075 val nullWritableClassTag = implicitly[ClassTag[NullWritable]] val textClassTag = implicitly[ClassTag[Text]] @@ -1339,7 +1396,7 @@ abstract class RDD[T: ClassTag]( /** * Save this RDD as a SequenceFile of serialized objects. */ - def saveAsObjectFile(path: String) { + def saveAsObjectFile(path: String): Unit = withScope { this.mapPartitions(iter => iter.grouped(10).map(_.toArray)) .map(x => (NullWritable.get(), new BytesWritable(Utils.serialize(x)))) .saveAsSequenceFile(path) @@ -1348,12 +1405,12 @@ abstract class RDD[T: ClassTag]( /** * Creates tuples of the elements in this RDD by applying `f`. */ - def keyBy[K](f: T => K): RDD[(K, T)] = { + def keyBy[K](f: T => K): RDD[(K, T)] = withScope { map(x => (f(x), x)) } /** A private method for tests, to look at the contents of each partition */ - private[spark] def collectPartitions(): Array[Array[T]] = { + private[spark] def collectPartitions(): Array[Array[T]] = withScope { sc.runJob(this, (iter: Iterator[T]) => iter.toArray) } @@ -1392,6 +1449,17 @@ abstract class RDD[T: ClassTag]( /** User code that created this RDD (e.g. `textFile`, `parallelize`). */ @transient private[spark] val creationSite = sc.getCallSite() + /** + * The scope associated with the operation that created this RDD. + * + * This is more flexible than the call site and can be defined hierarchically. For more + * detail, see the documentation of {{RDDOperationScope}}. This scope is not defined if the + * user instantiates this RDD himself without using any Spark operations. + */ + @transient private[spark] val scope: Option[RDDOperationScope] = { + Option(sc.getLocalProperty(SparkContext.RDD_SCOPE_KEY)).map(RDDOperationScope.fromJson) + } + private[spark] def getCreationSite: String = Option(creationSite).map(_.shortForm).getOrElse("") private[spark] def elementClassTag: ClassTag[T] = classTag[T] @@ -1470,7 +1538,7 @@ abstract class RDD[T: ClassTag]( /** A description of this RDD and its recursive dependencies for debugging. */ def toDebugString: String = { // Get a debug description of an rdd without its children - def debugSelf (rdd: RDD[_]): Seq[String] = { + def debugSelf(rdd: RDD[_]): Seq[String] = { import Utils.bytesToString val persistence = if (storageLevel != StorageLevel.NONE) storageLevel.description else "" @@ -1527,10 +1595,11 @@ abstract class RDD[T: ClassTag]( case (desc: String, _) => s"$nextPrefix$desc" } ++ debugChildren(rdd, nextPrefix) } - def debugString(rdd: RDD[_], - prefix: String = "", - isShuffle: Boolean = true, - isLastChild: Boolean = false): Seq[String] = { + def debugString( + rdd: RDD[_], + prefix: String = "", + isShuffle: Boolean = true, + isLastChild: Boolean = false): Seq[String] = { if (isShuffle) { shuffleDebugString(rdd, prefix, isLastChild) } else { diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala b/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala new file mode 100644 index 0000000000000..537b56b49f866 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import java.util.concurrent.atomic.AtomicInteger + +import com.fasterxml.jackson.annotation.{JsonIgnore, JsonInclude, JsonPropertyOrder} +import com.fasterxml.jackson.annotation.JsonInclude.Include +import com.fasterxml.jackson.databind.ObjectMapper +import com.fasterxml.jackson.module.scala.DefaultScalaModule + +import org.apache.spark.SparkContext + +/** + * A general, named code block representing an operation that instantiates RDDs. + * + * All RDDs instantiated in the corresponding code block will store a pointer to this object. + * Examples include, but will not be limited to, existing RDD operations, such as textFile, + * reduceByKey, and treeAggregate. + * + * An operation scope may be nested in other scopes. For instance, a SQL query may enclose + * scopes associated with the public RDD APIs it uses under the hood. + * + * There is no particular relationship between an operation scope and a stage or a job. + * A scope may live inside one stage (e.g. map) or span across multiple jobs (e.g. take). + */ +@JsonInclude(Include.NON_NULL) +@JsonPropertyOrder(Array("id", "name", "parent")) +private[spark] class RDDOperationScope( + val name: String, + val parent: Option[RDDOperationScope] = None) { + + val id: Int = RDDOperationScope.nextScopeId() + + def toJson: String = { + RDDOperationScope.jsonMapper.writeValueAsString(this) + } + + /** + * Return a list of scopes that this scope is a part of, including this scope itself. + * The result is ordered from the outermost scope (eldest ancestor) to this scope. + */ + @JsonIgnore + def getAllScopes: Seq[RDDOperationScope] = { + parent.map(_.getAllScopes).getOrElse(Seq.empty) ++ Seq(this) + } + + override def equals(other: Any): Boolean = { + other match { + case s: RDDOperationScope => + id == s.id && name == s.name && parent == s.parent + case _ => false + } + } + + override def toString: String = toJson +} + +/** + * A collection of utility methods to construct a hierarchical representation of RDD scopes. + * An RDD scope tracks the series of operations that created a given RDD. + */ +private[spark] object RDDOperationScope { + private val jsonMapper = new ObjectMapper().registerModule(DefaultScalaModule) + private val scopeCounter = new AtomicInteger(0) + + def fromJson(s: String): RDDOperationScope = { + jsonMapper.readValue(s, classOf[RDDOperationScope]) + } + + /** Return a globally unique operation scope ID. */ + def nextScopeId(): Int = scopeCounter.getAndIncrement + + /** + * Execute the given body such that all RDDs created in this body will have the same scope. + * The name of the scope will be the name of the method that immediately encloses this one. + * + * Note: Return statements are NOT allowed in body. + */ + private[spark] def withScope[T]( + sc: SparkContext, + allowNesting: Boolean = false)(body: => T): T = { + val callerMethodName = Thread.currentThread.getStackTrace()(3).getMethodName + withScope[T](sc, callerMethodName, allowNesting)(body) + } + + /** + * Execute the given body such that all RDDs created in this body will have the same scope. + * + * If nesting is allowed, this concatenates the previous scope with the new one in a way that + * signifies the hierarchy. Otherwise, if nesting is not allowed, then any children calls to + * this method executed in the body will have no effect. + * + * Note: Return statements are NOT allowed in body. + */ + private[spark] def withScope[T]( + sc: SparkContext, + name: String, + allowNesting: Boolean = false)(body: => T): T = { + // Save the old scope to restore it later + val scopeKey = SparkContext.RDD_SCOPE_KEY + val noOverrideKey = SparkContext.RDD_SCOPE_NO_OVERRIDE_KEY + val oldScopeJson = sc.getLocalProperty(scopeKey) + val oldScope = Option(oldScopeJson).map(RDDOperationScope.fromJson) + val oldNoOverride = sc.getLocalProperty(noOverrideKey) + try { + // Set the scope only if the higher level caller allows us to do so + if (sc.getLocalProperty(noOverrideKey) == null) { + sc.setLocalProperty(scopeKey, new RDDOperationScope(name, oldScope).toJson) + } + // Optionally disallow the child body to override our scope + if (!allowNesting) { + sc.setLocalProperty(noOverrideKey, "true") + } + body + } finally { + // Remember to restore any state that was modified before exiting + sc.setLocalProperty(scopeKey, oldScopeJson) + sc.setLocalProperty(noOverrideKey, oldNoOverride) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala index 059f8963691f0..3dfcf67f0eb66 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala @@ -85,7 +85,9 @@ class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag * byte arrays to BytesWritable, and Strings to Text. The `path` can be on any Hadoop-supported * file system. */ - def saveAsSequenceFile(path: String, codec: Option[Class[_ <: CompressionCodec]] = None) { + def saveAsSequenceFile( + path: String, + codec: Option[Class[_ <: CompressionCodec]] = None): Unit = self.withScope { def anyToWritable[U <% Writable](u: U): Writable = u // TODO We cannot force the return type of `anyToWritable` be same as keyWritableClass and diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala index cf3db0b94a0b3..e439d2a7e1229 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala @@ -33,6 +33,7 @@ class StageInfo( val name: String, val numTasks: Int, val rddInfos: Seq[RDDInfo], + val parentIds: Seq[Int], val details: String) { /** When this stage was submitted from the DAGScheduler to a TaskScheduler. */ var submissionTime: Option[Long] = None @@ -78,6 +79,7 @@ private[spark] object StageInfo { stage.name, numTasks.getOrElse(stage.numTasks), rddInfos, + stage.parents.map(_.id), stage.details) } } diff --git a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala index ad53a3edc7cc1..96062626b5045 100644 --- a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala +++ b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala @@ -18,7 +18,7 @@ package org.apache.spark.storage import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.{RDDOperationScope, RDD} import org.apache.spark.util.Utils @DeveloperApi @@ -26,7 +26,9 @@ class RDDInfo( val id: Int, val name: String, val numPartitions: Int, - var storageLevel: StorageLevel) + var storageLevel: StorageLevel, + val parentIds: Seq[Int], + val scope: Option[RDDOperationScope] = None) extends Ordered[RDDInfo] { var numCachedPartitions = 0 @@ -52,7 +54,8 @@ class RDDInfo( private[spark] object RDDInfo { def fromRdd(rdd: RDD[_]): RDDInfo = { - val rddName = Option(rdd.name).getOrElse(rdd.id.toString) - new RDDInfo(rdd.id, rddName, rdd.partitions.length, rdd.getStorageLevel) + val rddName = Option(rdd.name).getOrElse(Utils.getFormattedClassName(rdd)) + val parentIds = rdd.dependencies.map(_.rdd.id) + new RDDInfo(rdd.id, rddName, rdd.partitions.length, rdd.getStorageLevel, parentIds, rdd.scope) } } diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 06fce86bd38d2..a5271f0574e6c 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -25,6 +25,7 @@ import org.apache.spark.ui.env.{EnvironmentListener, EnvironmentTab} import org.apache.spark.ui.exec.{ExecutorsListener, ExecutorsTab} import org.apache.spark.ui.jobs.{JobsTab, JobProgressListener, StagesTab} import org.apache.spark.ui.storage.{StorageListener, StorageTab} +import org.apache.spark.ui.scope.RDDOperationGraphListener /** * Top level user interface for a Spark application. @@ -38,6 +39,7 @@ private[spark] class SparkUI private ( val executorsListener: ExecutorsListener, val jobProgressListener: JobProgressListener, val storageListener: StorageListener, + val operationGraphListener: RDDOperationGraphListener, var appName: String, val basePath: String) extends WebUI(securityManager, SparkUI.getUIPort(conf), conf, basePath, "SparkUI") @@ -93,6 +95,9 @@ private[spark] abstract class SparkUITab(parent: SparkUI, prefix: String) private[spark] object SparkUI { val DEFAULT_PORT = 4040 val STATIC_RESOURCE_DIR = "org/apache/spark/ui/static" + val DEFAULT_POOL_NAME = "default" + val DEFAULT_RETAINED_STAGES = 1000 + val DEFAULT_RETAINED_JOBS = 1000 def getUIPort(conf: SparkConf): Int = { conf.getInt("spark.ui.port", SparkUI.DEFAULT_PORT) @@ -144,13 +149,16 @@ private[spark] object SparkUI { val storageStatusListener = new StorageStatusListener val executorsListener = new ExecutorsListener(storageStatusListener) val storageListener = new StorageListener(storageStatusListener) + val operationGraphListener = new RDDOperationGraphListener(conf) listenerBus.addListener(environmentListener) listenerBus.addListener(storageStatusListener) listenerBus.addListener(executorsListener) listenerBus.addListener(storageListener) + listenerBus.addListener(operationGraphListener) new SparkUI(sc, conf, securityManager, environmentListener, storageStatusListener, - executorsListener, _jobProgressListener, storageListener, appName, basePath) + executorsListener, _jobProgressListener, storageListener, operationGraphListener, + appName, basePath) } } diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 395af2ea30b9d..2f3fb181e4026 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -23,6 +23,7 @@ import java.util.{Locale, Date} import scala.xml.{Node, Text} import org.apache.spark.Logging +import org.apache.spark.ui.scope.RDDOperationGraph /** Utility functions for generating XML pages with spark content. */ private[spark] object UIUtils extends Logging { @@ -172,13 +173,21 @@ private[spark] object UIUtils extends Logging { } + def vizHeaderNodes: Seq[Node] = { + + + + + } + /** Returns a spark page with correctly formatted headers */ def headerSparkPage( title: String, content: => Seq[Node], activeTab: SparkUITab, refreshInterval: Option[Int] = None, - helpText: Option[String] = None): Seq[Node] = { + helpText: Option[String] = None, + showVisualization: Boolean = false): Seq[Node] = { val appName = activeTab.appName val shortAppName = if (appName.length < 36) appName else appName.take(32) + "..." @@ -196,6 +205,7 @@ private[spark] object UIUtils extends Logging { {commonHeaderNodes} + {if (showVisualization) vizHeaderNodes else Seq.empty} {appName} - {title} @@ -320,4 +330,47 @@ private[spark] object UIUtils extends Logging {
} + + /** Return a "DAG visualization" DOM element that expands into a visualization for a stage. */ + def showDagVizForStage(stageId: Int, graph: Option[RDDOperationGraph]): Seq[Node] = { + showDagViz(graph.toSeq, forJob = false) + } + + /** Return a "DAG visualization" DOM element that expands into a visualization for a job. */ + def showDagVizForJob(jobId: Int, graphs: Seq[RDDOperationGraph]): Seq[Node] = { + showDagViz(graphs, forJob = true) + } + + /** + * Return a "DAG visualization" DOM element that expands into a visualization on the UI. + * + * This populates metadata necessary for generating the visualization on the front-end in + * a format that is expected by spark-dag-viz.js. Any changes in the format here must be + * reflected there. + */ + private def showDagViz(graphs: Seq[RDDOperationGraph], forJob: Boolean): Seq[Node] = { +
+ + + DAG visualization + +
+
+ { + graphs.map { g => + + } + } +
+
+ } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index a7ea12b1655fe..f6abf27db49dd 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -179,7 +179,7 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { - Event Timeline + Event timeline ++