diff --git a/.gitignore b/.gitignore index 4f177c82ae5e0..a4ec12ca6b53f 100644 --- a/.gitignore +++ b/.gitignore @@ -19,6 +19,7 @@ conf/spark-env.sh conf/streaming-env.sh conf/log4j.properties conf/spark-defaults.conf +conf/hive-site.xml docs/_site docs/api target/ @@ -50,6 +51,7 @@ unit-tests.log rat-results.txt scalastyle.txt conf/*.conf +scalastyle-output.xml # For Hive metastore_db/ diff --git a/LICENSE b/LICENSE index 383f079df8c8b..65e1f480d9b14 100644 --- a/LICENSE +++ b/LICENSE @@ -442,7 +442,7 @@ Written by Pavel Binko, Dino Ferrero Merlino, Wolfgang Hoschek, Tony Johnson, An ======================================================================== -Fo SnapTree: +For SnapTree: ======================================================================== SNAPTREE LICENSE @@ -482,6 +482,24 @@ OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +======================================================================== +For Timsort (core/src/main/java/org/apache/spark/util/collection/Sorter.java): +======================================================================== +Copyright (C) 2008 The Android Open Source Project + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + + ======================================================================== BSD-style licenses ======================================================================== diff --git a/README.md b/README.md index f6e7f51091314..f87e07aa5cc90 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,13 @@ # Apache Spark -Lightning-Fast Cluster Computing - +Spark is a fast and general cluster computing system for Big Data. It provides +high-level APIs in Scala, Java, and Python, and an optimized engine that +supports general computation graphs for data analysis. It also supports a +rich set of higher-level tools including Spark SQL for SQL and structured +data processing, MLLib for machine learning, GraphX for graph processing, +and Spark Streaming. 
+ + ## Online Documentation diff --git a/assembly/pom.xml b/assembly/pom.xml index 4f6aade133db7..567a8dd2a0d94 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -39,6 +39,7 @@ spark /usr/share/spark root + 744 @@ -276,7 +277,7 @@ ${deb.user} ${deb.user} ${deb.install.path}/bin - 744 + ${deb.bin.filemode} diff --git a/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala b/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala index 70a99b33d753c..ef0bb2ac13f08 100644 --- a/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala +++ b/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala @@ -72,6 +72,7 @@ object Bagel extends Logging { var verts = vertices var msgs = messages var noActivity = false + var lastRDD: RDD[(K, (V, Array[M]))] = null do { logInfo("Starting superstep " + superstep + ".") val startTime = System.currentTimeMillis @@ -83,6 +84,10 @@ object Bagel extends Logging { val superstep_ = superstep // Create a read-only copy of superstep for capture in closure val (processed, numMsgs, numActiveVerts) = comp[K, V, M, C](sc, grouped, compute(_, _, aggregated, superstep_), storageLevel) + if (lastRDD != null) { + lastRDD.unpersist(false) + } + lastRDD = processed val timeTaken = System.currentTimeMillis - startTime logInfo("Superstep %d took %d s".format(superstep, timeTaken / 1000)) diff --git a/core/pom.xml b/core/pom.xml index 4ed920a750fff..1054cec4d77bb 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -114,6 +114,10 @@ org.xerial.snappy snappy-java + + net.jpountz.lz4 + lz4 + com.twitter chill_${scala.binary.version} diff --git a/core/src/main/java/org/apache/spark/util/collection/Sorter.java b/core/src/main/java/org/apache/spark/util/collection/Sorter.java new file mode 100644 index 0000000000000..64ad18c0e463a --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/Sorter.java @@ -0,0 +1,915 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection; + +import java.util.Comparator; + +/** + * A port of the Android Timsort class, which utilizes a "stable, adaptive, iterative mergesort." + * See the method comment on sort() for more details. + * + * This has been kept in Java with the original style in order to match very closely with the + * Anroid source code, and thus be easy to verify correctness. + * + * The purpose of the port is to generalize the interface to the sort to accept input data formats + * besides simple arrays where every element is sorted individually. For instance, the AppendOnlyMap + * uses this to sort an Array with alternating elements of the form [key, value, key, value]. + * This generalization comes with minimal overhead -- see SortDataFormat for more information. 
+ */ +class Sorter { + + /** + * This is the minimum sized sequence that will be merged. Shorter + * sequences will be lengthened by calling binarySort. If the entire + * array is less than this length, no merges will be performed. + * + * This constant should be a power of two. It was 64 in Tim Peter's C + * implementation, but 32 was empirically determined to work better in + * this implementation. In the unlikely event that you set this constant + * to be a number that's not a power of two, you'll need to change the + * minRunLength computation. + * + * If you decrease this constant, you must change the stackLen + * computation in the TimSort constructor, or you risk an + * ArrayOutOfBounds exception. See listsort.txt for a discussion + * of the minimum stack length required as a function of the length + * of the array being sorted and the minimum merge sequence length. + */ + private static final int MIN_MERGE = 32; + + private final SortDataFormat s; + + public Sorter(SortDataFormat sortDataFormat) { + this.s = sortDataFormat; + } + + /** + * A stable, adaptive, iterative mergesort that requires far fewer than + * n lg(n) comparisons when running on partially sorted arrays, while + * offering performance comparable to a traditional mergesort when run + * on random arrays. Like all proper mergesorts, this sort is stable and + * runs O(n log n) time (worst case). In the worst case, this sort requires + * temporary storage space for n/2 object references; in the best case, + * it requires only a small constant amount of space. + * + * This implementation was adapted from Tim Peters's list sort for + * Python, which is described in detail here: + * + * http://svn.python.org/projects/python/trunk/Objects/listsort.txt + * + * Tim's C code may be found here: + * + * http://svn.python.org/projects/python/trunk/Objects/listobject.c + * + * The underlying techniques are described in this paper (and may have + * even earlier origins): + * + * "Optimistic Sorting and Information Theoretic Complexity" + * Peter McIlroy + * SODA (Fourth Annual ACM-SIAM Symposium on Discrete Algorithms), + * pp 467-474, Austin, Texas, 25-27 January 1993. + * + * While the API to this class consists solely of static methods, it is + * (privately) instantiable; a TimSort instance holds the state of an ongoing + * sort, assuming the input array is large enough to warrant the full-blown + * TimSort. Small arrays are sorted in place, using a binary insertion sort. + * + * @author Josh Bloch + */ + void sort(Buffer a, int lo, int hi, Comparator c) { + assert c != null; + + int nRemaining = hi - lo; + if (nRemaining < 2) + return; // Arrays of size 0 and 1 are always sorted + + // If array is small, do a "mini-TimSort" with no merges + if (nRemaining < MIN_MERGE) { + int initRunLen = countRunAndMakeAscending(a, lo, hi, c); + binarySort(a, lo, hi, lo + initRunLen, c); + return; + } + + /** + * March over the array once, left to right, finding natural runs, + * extending short natural runs to minRun elements, and merging runs + * to maintain stack invariant. + */ + SortState sortState = new SortState(a, c, hi - lo); + int minRun = minRunLength(nRemaining); + do { + // Identify next run + int runLen = countRunAndMakeAscending(a, lo, hi, c); + + // If run is short, extend to min(minRun, nRemaining) + if (runLen < minRun) { + int force = nRemaining <= minRun ? 
nRemaining : minRun; + binarySort(a, lo, lo + force, lo + runLen, c); + runLen = force; + } + + // Push run onto pending-run stack, and maybe merge + sortState.pushRun(lo, runLen); + sortState.mergeCollapse(); + + // Advance to find next run + lo += runLen; + nRemaining -= runLen; + } while (nRemaining != 0); + + // Merge all remaining runs to complete sort + assert lo == hi; + sortState.mergeForceCollapse(); + assert sortState.stackSize == 1; + } + + /** + * Sorts the specified portion of the specified array using a binary + * insertion sort. This is the best method for sorting small numbers + * of elements. It requires O(n log n) compares, but O(n^2) data + * movement (worst case). + * + * If the initial part of the specified range is already sorted, + * this method can take advantage of it: the method assumes that the + * elements from index {@code lo}, inclusive, to {@code start}, + * exclusive are already sorted. + * + * @param a the array in which a range is to be sorted + * @param lo the index of the first element in the range to be sorted + * @param hi the index after the last element in the range to be sorted + * @param start the index of the first element in the range that is + * not already known to be sorted ({@code lo <= start <= hi}) + * @param c comparator to used for the sort + */ + @SuppressWarnings("fallthrough") + private void binarySort(Buffer a, int lo, int hi, int start, Comparator c) { + assert lo <= start && start <= hi; + if (start == lo) + start++; + + Buffer pivotStore = s.allocate(1); + for ( ; start < hi; start++) { + s.copyElement(a, start, pivotStore, 0); + K pivot = s.getKey(pivotStore, 0); + + // Set left (and right) to the index where a[start] (pivot) belongs + int left = lo; + int right = start; + assert left <= right; + /* + * Invariants: + * pivot >= all in [lo, left). + * pivot < all in [right, start). + */ + while (left < right) { + int mid = (left + right) >>> 1; + if (c.compare(pivot, s.getKey(a, mid)) < 0) + right = mid; + else + left = mid + 1; + } + assert left == right; + + /* + * The invariants still hold: pivot >= all in [lo, left) and + * pivot < all in [left, start), so pivot belongs at left. Note + * that if there are elements equal to pivot, left points to the + * first slot after them -- that's why this sort is stable. + * Slide elements over to make room for pivot. + */ + int n = start - left; // The number of elements to move + // Switch is just an optimization for arraycopy in default case + switch (n) { + case 2: s.copyElement(a, left + 1, a, left + 2); + case 1: s.copyElement(a, left, a, left + 1); + break; + default: s.copyRange(a, left, a, left + 1, n); + } + s.copyElement(pivotStore, 0, a, left); + } + } + + /** + * Returns the length of the run beginning at the specified position in + * the specified array and reverses the run if it is descending (ensuring + * that the run will always be ascending when the method returns). + * + * A run is the longest ascending sequence with: + * + * a[lo] <= a[lo + 1] <= a[lo + 2] <= ... + * + * or the longest descending sequence with: + * + * a[lo] > a[lo + 1] > a[lo + 2] > ... + * + * For its intended use in a stable mergesort, the strictness of the + * definition of "descending" is needed so that the call can safely + * reverse a descending sequence without violating stability. + * + * @param a the array in which a run is to be counted and possibly reversed + * @param lo index of the first element in the run + * @param hi index after the last element that may be contained in the run. 
+ It is required that {@code lo < hi}. + * @param c the comparator to used for the sort + * @return the length of the run beginning at the specified position in + * the specified array + */ + private int countRunAndMakeAscending(Buffer a, int lo, int hi, Comparator c) { + assert lo < hi; + int runHi = lo + 1; + if (runHi == hi) + return 1; + + // Find end of run, and reverse range if descending + if (c.compare(s.getKey(a, runHi++), s.getKey(a, lo)) < 0) { // Descending + while (runHi < hi && c.compare(s.getKey(a, runHi), s.getKey(a, runHi - 1)) < 0) + runHi++; + reverseRange(a, lo, runHi); + } else { // Ascending + while (runHi < hi && c.compare(s.getKey(a, runHi), s.getKey(a, runHi - 1)) >= 0) + runHi++; + } + + return runHi - lo; + } + + /** + * Reverse the specified range of the specified array. + * + * @param a the array in which a range is to be reversed + * @param lo the index of the first element in the range to be reversed + * @param hi the index after the last element in the range to be reversed + */ + private void reverseRange(Buffer a, int lo, int hi) { + hi--; + while (lo < hi) { + s.swap(a, lo, hi); + lo++; + hi--; + } + } + + /** + * Returns the minimum acceptable run length for an array of the specified + * length. Natural runs shorter than this will be extended with + * {@link #binarySort}. + * + * Roughly speaking, the computation is: + * + * If n < MIN_MERGE, return n (it's too small to bother with fancy stuff). + * Else if n is an exact power of 2, return MIN_MERGE/2. + * Else return an int k, MIN_MERGE/2 <= k <= MIN_MERGE, such that n/k + * is close to, but strictly less than, an exact power of 2. + * + * For the rationale, see listsort.txt. + * + * @param n the length of the array to be sorted + * @return the length of the minimum run to be merged + */ + private int minRunLength(int n) { + assert n >= 0; + int r = 0; // Becomes 1 if any 1 bits are shifted off + while (n >= MIN_MERGE) { + r |= (n & 1); + n >>= 1; + } + return n + r; + } + + private class SortState { + + /** + * The Buffer being sorted. + */ + private final Buffer a; + + /** + * Length of the sort Buffer. + */ + private final int aLength; + + /** + * The comparator for this sort. + */ + private final Comparator c; + + /** + * When we get into galloping mode, we stay there until both runs win less + * often than MIN_GALLOP consecutive times. + */ + private static final int MIN_GALLOP = 7; + + /** + * This controls when we get *into* galloping mode. It is initialized + * to MIN_GALLOP. The mergeLo and mergeHi methods nudge it higher for + * random data, and lower for highly structured data. + */ + private int minGallop = MIN_GALLOP; + + /** + * Maximum initial size of tmp array, which is used for merging. The array + * can grow to accommodate demand. + * + * Unlike Tim's original C version, we do not allocate this much storage + * when sorting smaller arrays. This change was required for performance. + */ + private static final int INITIAL_TMP_STORAGE_LENGTH = 256; + + /** + * Temp storage for merges. + */ + private Buffer tmp; // Actual runtime type will be Object[], regardless of T + + /** + * Length of the temp storage. + */ + private int tmpLength = 0; + + /** + * A stack of pending runs yet to be merged. Run i starts at + * address base[i] and extends for len[i] elements. 
It's always + * true (so long as the indices are in bounds) that: + * + * runBase[i] + runLen[i] == runBase[i + 1] + * + * so we could cut the storage for this, but it's a minor amount, + * and keeping all the info explicit simplifies the code. + */ + private int stackSize = 0; // Number of pending runs on stack + private final int[] runBase; + private final int[] runLen; + + /** + * Creates a TimSort instance to maintain the state of an ongoing sort. + * + * @param a the array to be sorted + * @param c the comparator to determine the order of the sort + */ + private SortState(Buffer a, Comparator c, int len) { + this.aLength = len; + this.a = a; + this.c = c; + + // Allocate temp storage (which may be increased later if necessary) + tmpLength = len < 2 * INITIAL_TMP_STORAGE_LENGTH ? len >>> 1 : INITIAL_TMP_STORAGE_LENGTH; + tmp = s.allocate(tmpLength); + + /* + * Allocate runs-to-be-merged stack (which cannot be expanded). The + * stack length requirements are described in listsort.txt. The C + * version always uses the same stack length (85), but this was + * measured to be too expensive when sorting "mid-sized" arrays (e.g., + * 100 elements) in Java. Therefore, we use smaller (but sufficiently + * large) stack lengths for smaller arrays. The "magic numbers" in the + * computation below must be changed if MIN_MERGE is decreased. See + * the MIN_MERGE declaration above for more information. + */ + int stackLen = (len < 120 ? 5 : + len < 1542 ? 10 : + len < 119151 ? 19 : 40); + runBase = new int[stackLen]; + runLen = new int[stackLen]; + } + + /** + * Pushes the specified run onto the pending-run stack. + * + * @param runBase index of the first element in the run + * @param runLen the number of elements in the run + */ + private void pushRun(int runBase, int runLen) { + this.runBase[stackSize] = runBase; + this.runLen[stackSize] = runLen; + stackSize++; + } + + /** + * Examines the stack of runs waiting to be merged and merges adjacent runs + * until the stack invariants are reestablished: + * + * 1. runLen[i - 3] > runLen[i - 2] + runLen[i - 1] + * 2. runLen[i - 2] > runLen[i - 1] + * + * This method is called each time a new run is pushed onto the stack, + * so the invariants are guaranteed to hold for i < stackSize upon + * entry to the method. + */ + private void mergeCollapse() { + while (stackSize > 1) { + int n = stackSize - 2; + if (n > 0 && runLen[n-1] <= runLen[n] + runLen[n+1]) { + if (runLen[n - 1] < runLen[n + 1]) + n--; + mergeAt(n); + } else if (runLen[n] <= runLen[n + 1]) { + mergeAt(n); + } else { + break; // Invariant is established + } + } + } + + /** + * Merges all runs on the stack until only one remains. This method is + * called once, to complete the sort. + */ + private void mergeForceCollapse() { + while (stackSize > 1) { + int n = stackSize - 2; + if (n > 0 && runLen[n - 1] < runLen[n + 1]) + n--; + mergeAt(n); + } + } + + /** + * Merges the two runs at stack indices i and i+1. Run i must be + * the penultimate or antepenultimate run on the stack. In other words, + * i must be equal to stackSize-2 or stackSize-3. 
+ * + * @param i stack index of the first of the two runs to merge + */ + private void mergeAt(int i) { + assert stackSize >= 2; + assert i >= 0; + assert i == stackSize - 2 || i == stackSize - 3; + + int base1 = runBase[i]; + int len1 = runLen[i]; + int base2 = runBase[i + 1]; + int len2 = runLen[i + 1]; + assert len1 > 0 && len2 > 0; + assert base1 + len1 == base2; + + /* + * Record the length of the combined runs; if i is the 3rd-last + * run now, also slide over the last run (which isn't involved + * in this merge). The current run (i+1) goes away in any case. + */ + runLen[i] = len1 + len2; + if (i == stackSize - 3) { + runBase[i + 1] = runBase[i + 2]; + runLen[i + 1] = runLen[i + 2]; + } + stackSize--; + + /* + * Find where the first element of run2 goes in run1. Prior elements + * in run1 can be ignored (because they're already in place). + */ + int k = gallopRight(s.getKey(a, base2), a, base1, len1, 0, c); + assert k >= 0; + base1 += k; + len1 -= k; + if (len1 == 0) + return; + + /* + * Find where the last element of run1 goes in run2. Subsequent elements + * in run2 can be ignored (because they're already in place). + */ + len2 = gallopLeft(s.getKey(a, base1 + len1 - 1), a, base2, len2, len2 - 1, c); + assert len2 >= 0; + if (len2 == 0) + return; + + // Merge remaining runs, using tmp array with min(len1, len2) elements + if (len1 <= len2) + mergeLo(base1, len1, base2, len2); + else + mergeHi(base1, len1, base2, len2); + } + + /** + * Locates the position at which to insert the specified key into the + * specified sorted range; if the range contains an element equal to key, + * returns the index of the leftmost equal element. + * + * @param key the key whose insertion point to search for + * @param a the array in which to search + * @param base the index of the first element in the range + * @param len the length of the range; must be > 0 + * @param hint the index at which to begin the search, 0 <= hint < n. + * The closer hint is to the result, the faster this method will run. + * @param c the comparator used to order the range, and to search + * @return the int k, 0 <= k <= n such that a[b + k - 1] < key <= a[b + k], + * pretending that a[b - 1] is minus infinity and a[b + n] is infinity. + * In other words, key belongs at index b + k; or in other words, + * the first k elements of a should precede key, and the last n - k + * should follow it. 
+ */ + private int gallopLeft(K key, Buffer a, int base, int len, int hint, Comparator c) { + assert len > 0 && hint >= 0 && hint < len; + int lastOfs = 0; + int ofs = 1; + if (c.compare(key, s.getKey(a, base + hint)) > 0) { + // Gallop right until a[base+hint+lastOfs] < key <= a[base+hint+ofs] + int maxOfs = len - hint; + while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs)) > 0) { + lastOfs = ofs; + ofs = (ofs << 1) + 1; + if (ofs <= 0) // int overflow + ofs = maxOfs; + } + if (ofs > maxOfs) + ofs = maxOfs; + + // Make offsets relative to base + lastOfs += hint; + ofs += hint; + } else { // key <= a[base + hint] + // Gallop left until a[base+hint-ofs] < key <= a[base+hint-lastOfs] + final int maxOfs = hint + 1; + while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs)) <= 0) { + lastOfs = ofs; + ofs = (ofs << 1) + 1; + if (ofs <= 0) // int overflow + ofs = maxOfs; + } + if (ofs > maxOfs) + ofs = maxOfs; + + // Make offsets relative to base + int tmp = lastOfs; + lastOfs = hint - ofs; + ofs = hint - tmp; + } + assert -1 <= lastOfs && lastOfs < ofs && ofs <= len; + + /* + * Now a[base+lastOfs] < key <= a[base+ofs], so key belongs somewhere + * to the right of lastOfs but no farther right than ofs. Do a binary + * search, with invariant a[base + lastOfs - 1] < key <= a[base + ofs]. + */ + lastOfs++; + while (lastOfs < ofs) { + int m = lastOfs + ((ofs - lastOfs) >>> 1); + + if (c.compare(key, s.getKey(a, base + m)) > 0) + lastOfs = m + 1; // a[base + m] < key + else + ofs = m; // key <= a[base + m] + } + assert lastOfs == ofs; // so a[base + ofs - 1] < key <= a[base + ofs] + return ofs; + } + + /** + * Like gallopLeft, except that if the range contains an element equal to + * key, gallopRight returns the index after the rightmost equal element. + * + * @param key the key whose insertion point to search for + * @param a the array in which to search + * @param base the index of the first element in the range + * @param len the length of the range; must be > 0 + * @param hint the index at which to begin the search, 0 <= hint < n. + * The closer hint is to the result, the faster this method will run. + * @param c the comparator used to order the range, and to search + * @return the int k, 0 <= k <= n such that a[b + k - 1] <= key < a[b + k] + */ + private int gallopRight(K key, Buffer a, int base, int len, int hint, Comparator c) { + assert len > 0 && hint >= 0 && hint < len; + + int ofs = 1; + int lastOfs = 0; + if (c.compare(key, s.getKey(a, base + hint)) < 0) { + // Gallop left until a[b+hint - ofs] <= key < a[b+hint - lastOfs] + int maxOfs = hint + 1; + while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs)) < 0) { + lastOfs = ofs; + ofs = (ofs << 1) + 1; + if (ofs <= 0) // int overflow + ofs = maxOfs; + } + if (ofs > maxOfs) + ofs = maxOfs; + + // Make offsets relative to b + int tmp = lastOfs; + lastOfs = hint - ofs; + ofs = hint - tmp; + } else { // a[b + hint] <= key + // Gallop right until a[b+hint + lastOfs] <= key < a[b+hint + ofs] + int maxOfs = len - hint; + while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs)) >= 0) { + lastOfs = ofs; + ofs = (ofs << 1) + 1; + if (ofs <= 0) // int overflow + ofs = maxOfs; + } + if (ofs > maxOfs) + ofs = maxOfs; + + // Make offsets relative to b + lastOfs += hint; + ofs += hint; + } + assert -1 <= lastOfs && lastOfs < ofs && ofs <= len; + + /* + * Now a[b + lastOfs] <= key < a[b + ofs], so key belongs somewhere to + * the right of lastOfs but no farther right than ofs. 
Do a binary + * search, with invariant a[b + lastOfs - 1] <= key < a[b + ofs]. + */ + lastOfs++; + while (lastOfs < ofs) { + int m = lastOfs + ((ofs - lastOfs) >>> 1); + + if (c.compare(key, s.getKey(a, base + m)) < 0) + ofs = m; // key < a[b + m] + else + lastOfs = m + 1; // a[b + m] <= key + } + assert lastOfs == ofs; // so a[b + ofs - 1] <= key < a[b + ofs] + return ofs; + } + + /** + * Merges two adjacent runs in place, in a stable fashion. The first + * element of the first run must be greater than the first element of the + * second run (a[base1] > a[base2]), and the last element of the first run + * (a[base1 + len1-1]) must be greater than all elements of the second run. + * + * For performance, this method should be called only when len1 <= len2; + * its twin, mergeHi should be called if len1 >= len2. (Either method + * may be called if len1 == len2.) + * + * @param base1 index of first element in first run to be merged + * @param len1 length of first run to be merged (must be > 0) + * @param base2 index of first element in second run to be merged + * (must be aBase + aLen) + * @param len2 length of second run to be merged (must be > 0) + */ + private void mergeLo(int base1, int len1, int base2, int len2) { + assert len1 > 0 && len2 > 0 && base1 + len1 == base2; + + // Copy first run into temp array + Buffer a = this.a; // For performance + Buffer tmp = ensureCapacity(len1); + s.copyRange(a, base1, tmp, 0, len1); + + int cursor1 = 0; // Indexes into tmp array + int cursor2 = base2; // Indexes int a + int dest = base1; // Indexes int a + + // Move first element of second run and deal with degenerate cases + s.copyElement(a, cursor2++, a, dest++); + if (--len2 == 0) { + s.copyRange(tmp, cursor1, a, dest, len1); + return; + } + if (len1 == 1) { + s.copyRange(a, cursor2, a, dest, len2); + s.copyElement(tmp, cursor1, a, dest + len2); // Last elt of run 1 to end of merge + return; + } + + Comparator c = this.c; // Use local variable for performance + int minGallop = this.minGallop; // " " " " " + outer: + while (true) { + int count1 = 0; // Number of times in a row that first run won + int count2 = 0; // Number of times in a row that second run won + + /* + * Do the straightforward thing until (if ever) one run starts + * winning consistently. + */ + do { + assert len1 > 1 && len2 > 0; + if (c.compare(s.getKey(a, cursor2), s.getKey(tmp, cursor1)) < 0) { + s.copyElement(a, cursor2++, a, dest++); + count2++; + count1 = 0; + if (--len2 == 0) + break outer; + } else { + s.copyElement(tmp, cursor1++, a, dest++); + count1++; + count2 = 0; + if (--len1 == 1) + break outer; + } + } while ((count1 | count2) < minGallop); + + /* + * One run is winning so consistently that galloping may be a + * huge win. So try that, and continue galloping until (if ever) + * neither run appears to be winning consistently anymore. 
+ */ + do { + assert len1 > 1 && len2 > 0; + count1 = gallopRight(s.getKey(a, cursor2), tmp, cursor1, len1, 0, c); + if (count1 != 0) { + s.copyRange(tmp, cursor1, a, dest, count1); + dest += count1; + cursor1 += count1; + len1 -= count1; + if (len1 <= 1) // len1 == 1 || len1 == 0 + break outer; + } + s.copyElement(a, cursor2++, a, dest++); + if (--len2 == 0) + break outer; + + count2 = gallopLeft(s.getKey(tmp, cursor1), a, cursor2, len2, 0, c); + if (count2 != 0) { + s.copyRange(a, cursor2, a, dest, count2); + dest += count2; + cursor2 += count2; + len2 -= count2; + if (len2 == 0) + break outer; + } + s.copyElement(tmp, cursor1++, a, dest++); + if (--len1 == 1) + break outer; + minGallop--; + } while (count1 >= MIN_GALLOP | count2 >= MIN_GALLOP); + if (minGallop < 0) + minGallop = 0; + minGallop += 2; // Penalize for leaving gallop mode + } // End of "outer" loop + this.minGallop = minGallop < 1 ? 1 : minGallop; // Write back to field + + if (len1 == 1) { + assert len2 > 0; + s.copyRange(a, cursor2, a, dest, len2); + s.copyElement(tmp, cursor1, a, dest + len2); // Last elt of run 1 to end of merge + } else if (len1 == 0) { + throw new IllegalArgumentException( + "Comparison method violates its general contract!"); + } else { + assert len2 == 0; + assert len1 > 1; + s.copyRange(tmp, cursor1, a, dest, len1); + } + } + + /** + * Like mergeLo, except that this method should be called only if + * len1 >= len2; mergeLo should be called if len1 <= len2. (Either method + * may be called if len1 == len2.) + * + * @param base1 index of first element in first run to be merged + * @param len1 length of first run to be merged (must be > 0) + * @param base2 index of first element in second run to be merged + * (must be aBase + aLen) + * @param len2 length of second run to be merged (must be > 0) + */ + private void mergeHi(int base1, int len1, int base2, int len2) { + assert len1 > 0 && len2 > 0 && base1 + len1 == base2; + + // Copy second run into temp array + Buffer a = this.a; // For performance + Buffer tmp = ensureCapacity(len2); + s.copyRange(a, base2, tmp, 0, len2); + + int cursor1 = base1 + len1 - 1; // Indexes into a + int cursor2 = len2 - 1; // Indexes into tmp array + int dest = base2 + len2 - 1; // Indexes into a + + // Move last element of first run and deal with degenerate cases + s.copyElement(a, cursor1--, a, dest--); + if (--len1 == 0) { + s.copyRange(tmp, 0, a, dest - (len2 - 1), len2); + return; + } + if (len2 == 1) { + dest -= len1; + cursor1 -= len1; + s.copyRange(a, cursor1 + 1, a, dest + 1, len1); + s.copyElement(tmp, cursor2, a, dest); + return; + } + + Comparator c = this.c; // Use local variable for performance + int minGallop = this.minGallop; // " " " " " + outer: + while (true) { + int count1 = 0; // Number of times in a row that first run won + int count2 = 0; // Number of times in a row that second run won + + /* + * Do the straightforward thing until (if ever) one run + * appears to win consistently. + */ + do { + assert len1 > 0 && len2 > 1; + if (c.compare(s.getKey(tmp, cursor2), s.getKey(a, cursor1)) < 0) { + s.copyElement(a, cursor1--, a, dest--); + count1++; + count2 = 0; + if (--len1 == 0) + break outer; + } else { + s.copyElement(tmp, cursor2--, a, dest--); + count2++; + count1 = 0; + if (--len2 == 1) + break outer; + } + } while ((count1 | count2) < minGallop); + + /* + * One run is winning so consistently that galloping may be a + * huge win. So try that, and continue galloping until (if ever) + * neither run appears to be winning consistently anymore. 
+ */ + do { + assert len1 > 0 && len2 > 1; + count1 = len1 - gallopRight(s.getKey(tmp, cursor2), a, base1, len1, len1 - 1, c); + if (count1 != 0) { + dest -= count1; + cursor1 -= count1; + len1 -= count1; + s.copyRange(a, cursor1 + 1, a, dest + 1, count1); + if (len1 == 0) + break outer; + } + s.copyElement(tmp, cursor2--, a, dest--); + if (--len2 == 1) + break outer; + + count2 = len2 - gallopLeft(s.getKey(a, cursor1), tmp, 0, len2, len2 - 1, c); + if (count2 != 0) { + dest -= count2; + cursor2 -= count2; + len2 -= count2; + s.copyRange(tmp, cursor2 + 1, a, dest + 1, count2); + if (len2 <= 1) // len2 == 1 || len2 == 0 + break outer; + } + s.copyElement(a, cursor1--, a, dest--); + if (--len1 == 0) + break outer; + minGallop--; + } while (count1 >= MIN_GALLOP | count2 >= MIN_GALLOP); + if (minGallop < 0) + minGallop = 0; + minGallop += 2; // Penalize for leaving gallop mode + } // End of "outer" loop + this.minGallop = minGallop < 1 ? 1 : minGallop; // Write back to field + + if (len2 == 1) { + assert len1 > 0; + dest -= len1; + cursor1 -= len1; + s.copyRange(a, cursor1 + 1, a, dest + 1, len1); + s.copyElement(tmp, cursor2, a, dest); // Move first elt of run2 to front of merge + } else if (len2 == 0) { + throw new IllegalArgumentException( + "Comparison method violates its general contract!"); + } else { + assert len1 == 0; + assert len2 > 0; + s.copyRange(tmp, 0, a, dest - (len2 - 1), len2); + } + } + + /** + * Ensures that the external array tmp has at least the specified + * number of elements, increasing its size if necessary. The size + * increases exponentially to ensure amortized linear time complexity. + * + * @param minCapacity the minimum required capacity of the tmp array + * @return tmp, whether or not it grew + */ + private Buffer ensureCapacity(int minCapacity) { + if (tmpLength < minCapacity) { + // Compute smallest power of 2 > minCapacity + int newSize = minCapacity; + newSize |= newSize >> 1; + newSize |= newSize >> 2; + newSize |= newSize >> 4; + newSize |= newSize >> 8; + newSize |= newSize >> 16; + newSize++; + + if (newSize < 0) // Not bloody likely! 
+ newSize = minCapacity; + else + newSize = Math.min(newSize, aLength >>> 1); + + tmp = s.allocate(newSize); + tmpLength = newSize; + } + return tmp; + } + } +} diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala index 59fdf659c9e11..ff0ca11749d42 100644 --- a/core/src/main/scala/org/apache/spark/Aggregator.scala +++ b/core/src/main/scala/org/apache/spark/Aggregator.scala @@ -55,10 +55,7 @@ case class Aggregator[K, V, C] ( combiners.iterator } else { val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners) - while (iter.hasNext) { - val (k, v) = iter.next() - combiners.insert(k, v) - } + combiners.insertAll(iter) // TODO: Make this non optional in a future release Option(context).foreach(c => c.taskMetrics.memoryBytesSpilled = combiners.memoryBytesSpilled) Option(context).foreach(c => c.taskMetrics.diskBytesSpilled = combiners.diskBytesSpilled) @@ -85,8 +82,8 @@ case class Aggregator[K, V, C] ( } else { val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners) while (iter.hasNext) { - val (k, c) = iter.next() - combiners.insert(k, c) + val pair = iter.next() + combiners.insert(pair._1, pair._2) } // TODO: Make this non optional in a future release Option(context).foreach(c => c.taskMetrics.memoryBytesSpilled = combiners.memoryBytesSpilled) diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala index 8f867686a0443..5ddda4d6953fa 100644 --- a/core/src/main/scala/org/apache/spark/CacheManager.scala +++ b/core/src/main/scala/org/apache/spark/CacheManager.scala @@ -17,9 +17,9 @@ package org.apache.spark -import scala.collection.mutable.{ArrayBuffer, HashSet} +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer -import org.apache.spark.executor.InputMetrics import org.apache.spark.rdd.RDD import org.apache.spark.storage._ @@ -30,7 +30,7 @@ import org.apache.spark.storage._ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { /** Keys of RDD partitions that are being computed/loaded. */ - private val loading = new HashSet[RDDBlockId]() + private val loading = new mutable.HashSet[RDDBlockId] /** Gets or computes an RDD partition. Used by RDD.iterator() when an RDD is cached. */ def getOrCompute[T]( @@ -118,21 +118,29 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { } /** - * Cache the values of a partition, keeping track of any updates in the storage statuses - * of other blocks along the way. + * Cache the values of a partition, keeping track of any updates in the storage statuses of + * other blocks along the way. + * + * The effective storage level refers to the level that actually specifies BlockManager put + * behavior, not the level originally specified by the user. This is mainly for forcing a + * MEMORY_AND_DISK partition to disk if there is not enough room to unroll the partition, + * while preserving the the original semantics of the RDD as specified by the application. */ private def putInBlockManager[T]( key: BlockId, values: Iterator[T], - storageLevel: StorageLevel, - updatedBlocks: ArrayBuffer[(BlockId, BlockStatus)]): Iterator[T] = { - - if (!storageLevel.useMemory) { - /* This RDD is not to be cached in memory, so we can just pass the computed values - * as an iterator directly to the BlockManager, rather than first fully unrolling - * it in memory. 
The latter option potentially uses much more memory and risks OOM - * exceptions that can be avoided. */ - updatedBlocks ++= blockManager.put(key, values, storageLevel, tellMaster = true) + level: StorageLevel, + updatedBlocks: ArrayBuffer[(BlockId, BlockStatus)], + effectiveStorageLevel: Option[StorageLevel] = None): Iterator[T] = { + + val putLevel = effectiveStorageLevel.getOrElse(level) + if (!putLevel.useMemory) { + /* + * This RDD is not to be cached in memory, so we can just pass the computed values as an + * iterator directly to the BlockManager rather than first fully unrolling it in memory. + */ + updatedBlocks ++= + blockManager.putIterator(key, values, level, tellMaster = true, effectiveStorageLevel) blockManager.get(key) match { case Some(v) => v.data.asInstanceOf[Iterator[T]] case None => @@ -140,14 +148,36 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { throw new BlockException(key, s"Block manager failed to return cached value for $key!") } } else { - /* This RDD is to be cached in memory. In this case we cannot pass the computed values + /* + * This RDD is to be cached in memory. In this case we cannot pass the computed values * to the BlockManager as an iterator and expect to read it back later. This is because - * we may end up dropping a partition from memory store before getting it back, e.g. - * when the entirety of the RDD does not fit in memory. */ - val elements = new ArrayBuffer[Any] - elements ++= values - updatedBlocks ++= blockManager.put(key, elements, storageLevel, tellMaster = true) - elements.iterator.asInstanceOf[Iterator[T]] + * we may end up dropping a partition from memory store before getting it back. + * + * In addition, we must be careful to not unroll the entire partition in memory at once. + * Otherwise, we may cause an OOM exception if the JVM does not have enough space for this + * single partition. Instead, we unroll the values cautiously, potentially aborting and + * dropping the partition to disk if applicable. + */ + blockManager.memoryStore.unrollSafely(key, values, updatedBlocks) match { + case Left(arr) => + // We have successfully unrolled the entire partition, so cache it in memory + updatedBlocks ++= + blockManager.putArray(key, arr, level, tellMaster = true, effectiveStorageLevel) + arr.iterator.asInstanceOf[Iterator[T]] + case Right(it) => + // There is not enough space to cache this partition in memory + logWarning(s"Not enough space to cache partition $key in memory! 
" + + s"Free memory is ${blockManager.memoryStore.freeMemory} bytes.") + val returnValues = it.asInstanceOf[Iterator[T]] + if (putLevel.useDisk) { + logWarning(s"Persisting partition $key to disk instead.") + val diskOnlyLevel = StorageLevel(useDisk = true, useMemory = false, + useOffHeap = false, deserialized = false, putLevel.replication) + putInBlockManager[T](key, returnValues, level, updatedBlocks, Some(diskOnlyLevel)) + } else { + returnValues + } + } } } diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index 09a60571238ea..f010c03223ef4 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -19,6 +19,7 @@ package org.apache.spark import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.SortOrder.SortOrder import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.ShuffleHandle @@ -62,7 +63,8 @@ class ShuffleDependency[K, V, C]( val serializer: Option[Serializer] = None, val keyOrdering: Option[Ordering[K]] = None, val aggregator: Option[Aggregator[K, V, C]] = None, - val mapSideCombine: Boolean = false) + val mapSideCombine: Boolean = false, + val sortOrder: Option[SortOrder] = None) extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) { val shuffleId: Int = rdd.context.newShuffleId() diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index ec99648a8488a..52c018baa5f7b 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -134,8 +134,8 @@ class RangePartitioner[K : Ordering : ClassTag, V]( def getPartition(key: Any): Int = { val k = key.asInstanceOf[K] var partition = 0 - if (rangeBounds.length < 1000) { - // If we have less than 100 partitions naive search + if (rangeBounds.length <= 128) { + // If we have less than 128 partitions naive search while (partition < rangeBounds.length && ordering.gt(k, rangeBounds(partition))) { partition += 1 } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 8052499ab7526..3e6addeaf04a8 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1037,7 +1037,7 @@ class SparkContext(config: SparkConf) extends Logging { */ private[spark] def getCallSite(): CallSite = { Option(getLocalProperty("externalCallSite")) match { - case Some(callSite) => CallSite(callSite, long = "") + case Some(callSite) => CallSite(callSite, longForm = "") case None => Utils.getCallSite } } @@ -1059,11 +1059,12 @@ class SparkContext(config: SparkConf) extends Logging { } val callSite = getCallSite val cleanedFunc = clean(func) - logInfo("Starting job: " + callSite.short) + logInfo("Starting job: " + callSite.shortForm) val start = System.nanoTime dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal, resultHandler, localProperties.get) - logInfo("Job finished: " + callSite.short + ", took " + (System.nanoTime - start) / 1e9 + " s") + logInfo( + "Job finished: " + callSite.shortForm + ", took " + (System.nanoTime - start) / 1e9 + " s") rdd.doCheckpoint() } @@ -1144,11 +1145,12 @@ class SparkContext(config: SparkConf) extends Logging { evaluator: ApproximateEvaluator[U, R], timeout: Long): PartialResult[R] = { val callSite = 
getCallSite - logInfo("Starting job: " + callSite.short) + logInfo("Starting job: " + callSite.shortForm) val start = System.nanoTime val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout, localProperties.get) - logInfo("Job finished: " + callSite.short + ", took " + (System.nanoTime - start) / 1e9 + " s") + logInfo( + "Job finished: " + callSite.shortForm + ", took " + (System.nanoTime - start) / 1e9 + " s") result } diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 8f70744d804d9..6ee731b22c03c 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -67,7 +67,7 @@ class SparkEnv ( val metricsSystem: MetricsSystem, val conf: SparkConf) extends Logging { - // A mapping of thread ID to amount of memory used for shuffle in bytes + // A mapping of thread ID to amount of memory, in bytes, used for shuffle aggregations // All accesses should be manually synchronized val shuffleMemoryMap = mutable.HashMap[Long, Long]() diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index df42d679b4699..8f0c5e78416c2 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -20,6 +20,7 @@ package org.apache.spark import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.Utils /** * :: DeveloperApi :: @@ -88,10 +89,7 @@ case class ExceptionFailure( stackTrace: Array[StackTraceElement], metrics: Option[TaskMetrics]) extends TaskFailedReason { - override def toErrorString: String = { - val stackTraceString = if (stackTrace == null) "null" else stackTrace.mkString("\n") - s"$className ($description}\n$stackTraceString" - } + override def toErrorString: String = Utils.exceptionString(className, description, stackTrace) } /** diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 1e0493c4855e0..8a5f8088a05ca 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -34,7 +34,7 @@ import org.apache.spark._ import org.apache.spark.SparkContext.{DoubleAccumulatorParam, IntAccumulatorParam} import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.broadcast.Broadcast -import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.{EmptyRDD, RDD} /** * A Java-friendly version of [[org.apache.spark.SparkContext]] that returns @@ -112,6 +112,9 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork def startTime: java.lang.Long = sc.startTime + /** The version of Spark on which this application is running. */ + def version: String = sc.version + /** Default level of parallelism to use when not given by user (e.g. parallelize and makeRDD). */ def defaultParallelism: java.lang.Integer = sc.defaultParallelism @@ -132,6 +135,13 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork sc.parallelize(JavaConversions.asScalaBuffer(list), numSlices) } + /** Get an RDD that has no partitions or elements. 
*/ + def emptyRDD[T]: JavaRDD[T] = { + implicit val ctag: ClassTag[T] = fakeClassTag + JavaRDD.fromRDD(new EmptyRDD[T](sc)) + } + + /** Distribute a local Scala collection to form an RDD. */ def parallelize[T](list: java.util.List[T]): JavaRDD[T] = parallelize(list, sc.defaultParallelism) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 462e09466bfa6..d87783efd2d01 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -37,8 +37,8 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils -private[spark] class PythonRDD[T: ClassTag]( - parent: RDD[T], +private[spark] class PythonRDD( + parent: RDD[_], command: Array[Byte], envVars: JMap[String, String], pythonIncludes: JList[String], @@ -57,7 +57,10 @@ private[spark] class PythonRDD[T: ClassTag]( override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { val startTime = System.currentTimeMillis val env = SparkEnv.get - val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap) + val localdir = env.blockManager.diskBlockManager.localDirs.map( + f => f.getPath()).mkString(",") + val worker: Socket = env.createPythonWorker(pythonExec, + envVars.toMap + ("SPARK_LOCAL_DIR" -> localdir)) // Start a thread to feed the process input from our parent's iterator val writerThread = new WriterThread(env, worker, split, context) diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala index 76956f6a345d1..15fd30e65761d 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala @@ -106,23 +106,23 @@ abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable { * Actually get the broadcasted value. Concrete implementations of Broadcast class must * define their own way to get the value. */ - private[spark] def getValue(): T + protected def getValue(): T /** * Actually unpersist the broadcasted value on the executors. Concrete implementations of * Broadcast class must define their own logic to unpersist their own data. */ - private[spark] def doUnpersist(blocking: Boolean) + protected def doUnpersist(blocking: Boolean) /** * Actually destroy all data and metadata related to this broadcast variable. * Implementation of Broadcast class must define their own logic to destroy their own * state. */ - private[spark] def doDestroy(blocking: Boolean) + protected def doDestroy(blocking: Boolean) /** Check if this broadcast is valid. If not valid, exception is thrown. 
*/ - private[spark] def assertValid() { + protected def assertValid() { if (!_isValid) { throw new SparkException("Attempted to use %s after it has been destroyed!".format(toString)) } diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala index c88be6aba6901..8f8a0b11f9f2e 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala @@ -39,7 +39,7 @@ private[spark] class BroadcastManager( synchronized { if (!initialized) { val broadcastFactoryClass = - conf.get("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory") + conf.get("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory") broadcastFactory = Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory] diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 4f6cabaff2b99..487456467b23b 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -40,9 +40,9 @@ private[spark] class HttpBroadcast[T: ClassTag]( @transient var value_ : T, isLocal: Boolean, id: Long) extends Broadcast[T](id) with Logging with Serializable { - def getValue = value_ + override protected def getValue() = value_ - val blockId = BroadcastBlockId(id) + private val blockId = BroadcastBlockId(id) /* * Broadcasted data is also stored in the BlockManager of the driver. The BlockManagerMaster @@ -60,14 +60,14 @@ private[spark] class HttpBroadcast[T: ClassTag]( /** * Remove all persisted state associated with this HTTP broadcast on the executors. */ - def doUnpersist(blocking: Boolean) { + override protected def doUnpersist(blocking: Boolean) { HttpBroadcast.unpersist(id, removeFromDriver = false, blocking) } /** * Remove all persisted state associated with this HTTP broadcast on the executors and driver. 
*/ - def doDestroy(blocking: Boolean) { + override protected def doDestroy(blocking: Boolean) { HttpBroadcast.unpersist(id, removeFromDriver = true, blocking) } @@ -102,7 +102,7 @@ private[spark] class HttpBroadcast[T: ClassTag]( } } -private[spark] object HttpBroadcast extends Logging { +private[broadcast] object HttpBroadcast extends Logging { private var initialized = false private var broadcastDir: File = null private var compress: Boolean = false @@ -160,7 +160,7 @@ private[spark] object HttpBroadcast extends Logging { def getFile(id: Long) = new File(broadcastDir, BroadcastBlockId(id).name) - def write(id: Long, value: Any) { + private def write(id: Long, value: Any) { val file = getFile(id) val out: OutputStream = { if (compress) { @@ -176,7 +176,7 @@ private[spark] object HttpBroadcast extends Logging { files += file } - def read[T: ClassTag](id: Long): T = { + private def read[T: ClassTag](id: Long): T = { logDebug("broadcast read server: " + serverUri + " id: broadcast-" + id) val url = serverUri + "/" + BroadcastBlockId(id).name diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala index d5a031e2bbb59..c7ef02d572a19 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala @@ -27,21 +27,21 @@ import org.apache.spark.{SecurityManager, SparkConf} * [[org.apache.spark.broadcast.HttpBroadcast]] for more details about this mechanism. */ class HttpBroadcastFactory extends BroadcastFactory { - def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { + override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { HttpBroadcast.initialize(isDriver, conf, securityMgr) } - def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) = + override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) = new HttpBroadcast[T](value_, isLocal, id) - def stop() { HttpBroadcast.stop() } + override def stop() { HttpBroadcast.stop() } /** * Remove all persisted state associated with the HTTP broadcast with the given ID. 
* @param removeFromDriver Whether to remove state from the driver * @param blocking Whether to block until unbroadcasted */ - def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) { + override def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) { HttpBroadcast.unpersist(id, removeFromDriver, blocking) } } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 734de37ba115d..86731b684f441 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -20,7 +20,6 @@ package org.apache.spark.broadcast import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream} import scala.reflect.ClassTag -import scala.math import scala.util.Random import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException} @@ -49,19 +48,19 @@ private[spark] class TorrentBroadcast[T: ClassTag]( @transient var value_ : T, isLocal: Boolean, id: Long) extends Broadcast[T](id) with Logging with Serializable { - def getValue = value_ + override protected def getValue() = value_ - val broadcastId = BroadcastBlockId(id) + private val broadcastId = BroadcastBlockId(id) TorrentBroadcast.synchronized { SparkEnv.get.blockManager.putSingle( broadcastId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false) } - @transient var arrayOfBlocks: Array[TorrentBlock] = null - @transient var totalBlocks = -1 - @transient var totalBytes = -1 - @transient var hasBlocks = 0 + @transient private var arrayOfBlocks: Array[TorrentBlock] = null + @transient private var totalBlocks = -1 + @transient private var totalBytes = -1 + @transient private var hasBlocks = 0 if (!isLocal) { sendBroadcast() @@ -70,7 +69,7 @@ private[spark] class TorrentBroadcast[T: ClassTag]( /** * Remove all persisted state associated with this Torrent broadcast on the executors. */ - def doUnpersist(blocking: Boolean) { + override protected def doUnpersist(blocking: Boolean) { TorrentBroadcast.unpersist(id, removeFromDriver = false, blocking) } @@ -78,11 +77,11 @@ private[spark] class TorrentBroadcast[T: ClassTag]( * Remove all persisted state associated with this Torrent broadcast on the executors * and driver. */ - def doDestroy(blocking: Boolean) { + override protected def doDestroy(blocking: Boolean) { TorrentBroadcast.unpersist(id, removeFromDriver = true, blocking) } - def sendBroadcast() { + private def sendBroadcast() { val tInfo = TorrentBroadcast.blockifyObject(value_) totalBlocks = tInfo.totalBlocks totalBytes = tInfo.totalBytes @@ -159,7 +158,7 @@ private[spark] class TorrentBroadcast[T: ClassTag]( hasBlocks = 0 } - def receiveBroadcast(): Boolean = { + private def receiveBroadcast(): Boolean = { // Receive meta-info about the size of broadcast data, // the number of chunks it is divided into, etc. val metaId = BroadcastBlockId(id, "meta") @@ -211,7 +210,7 @@ private[spark] class TorrentBroadcast[T: ClassTag]( } -private[spark] object TorrentBroadcast extends Logging { +private[broadcast] object TorrentBroadcast extends Logging { private lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024 private var initialized = false private var conf: SparkConf = null @@ -272,17 +271,19 @@ private[spark] object TorrentBroadcast extends Logging { * Remove all persisted blocks associated with this torrent broadcast on the executors. 
* If removeFromDriver is true, also remove these persisted blocks on the driver. */ - def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = synchronized { - SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking) + def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = { + synchronized { + SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking) + } } } -private[spark] case class TorrentBlock( +private[broadcast] case class TorrentBlock( blockID: Int, byteArray: Array[Byte]) extends Serializable -private[spark] case class TorrentInfo( +private[broadcast] case class TorrentInfo( @transient arrayOfBlocks: Array[TorrentBlock], totalBlocks: Int, totalBytes: Int) diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala index 1de8396a0e17f..ad0f701d7a98f 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala @@ -28,21 +28,21 @@ import org.apache.spark.{SecurityManager, SparkConf} */ class TorrentBroadcastFactory extends BroadcastFactory { - def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { + override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { TorrentBroadcast.initialize(isDriver, conf) } - def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) = + override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) = new TorrentBroadcast[T](value_, isLocal, id) - def stop() { TorrentBroadcast.stop() } + override def stop() { TorrentBroadcast.stop() } /** * Remove all persisted state associated with the torrent broadcast with the given ID. * @param removeFromDriver Whether to remove state from the driver. * @param blocking Whether to block until unbroadcasted */ - def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) { + override def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) { TorrentBroadcast.unpersist(id, removeFromDriver, blocking) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index b050dccb6d57f..3b5642b6caa36 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -27,25 +27,39 @@ import org.apache.spark.executor.ExecutorURLClassLoader import org.apache.spark.util.Utils /** - * Scala code behind the spark-submit script. The script handles setting up the classpath with - * relevant Spark dependencies and provides a layer over the different cluster managers and deploy - * modes that Spark supports. + * Main gateway of launching a Spark application. + * + * This program handles setting up the classpath with relevant Spark dependencies and provides + * a layer over the different cluster managers and deploy modes that Spark supports. 
*/ object SparkSubmit { + + // Cluster managers private val YARN = 1 private val STANDALONE = 2 private val MESOS = 4 private val LOCAL = 8 private val ALL_CLUSTER_MGRS = YARN | STANDALONE | MESOS | LOCAL - private var clusterManager: Int = LOCAL + // Deploy modes + private val CLIENT = 1 + private val CLUSTER = 2 + private val ALL_DEPLOY_MODES = CLIENT | CLUSTER - /** - * Special primary resource names that represent shells rather than application jars. - */ + // Special primary resource names that represent shells rather than application jars. private val SPARK_SHELL = "spark-shell" private val PYSPARK_SHELL = "pyspark-shell" + // Exposed for testing + private[spark] var exitFn: () => Unit = () => System.exit(-1) + private[spark] var printStream: PrintStream = System.err + private[spark] def printWarning(str: String) = printStream.println("Warning: " + str) + private[spark] def printErrorAndExit(str: String) = { + printStream.println("Error: " + str) + printStream.println("Run with --help for usage help or --verbose for debug output") + exitFn() + } + def main(args: Array[String]) { val appArgs = new SparkSubmitArguments(args) if (appArgs.verbose) { @@ -55,88 +69,80 @@ object SparkSubmit { launch(childArgs, classpath, sysProps, mainClass, appArgs.verbose) } - // Exposed for testing - private[spark] var printStream: PrintStream = System.err - private[spark] var exitFn: () => Unit = () => System.exit(-1) - - private[spark] def printErrorAndExit(str: String) = { - printStream.println("Error: " + str) - printStream.println("Run with --help for usage help or --verbose for debug output") - exitFn() - } - private[spark] def printWarning(str: String) = printStream.println("Warning: " + str) - /** - * @return a tuple containing the arguments for the child, a list of classpath - * entries for the child, a list of system properties, a list of env vars - * and the main class for the child + * @return a tuple containing + * (1) the arguments for the child process, + * (2) a list of classpath entries for the child, + * (3) a list of system properties and env vars, and + * (4) the main class for the child */ private[spark] def createLaunchEnv(args: SparkSubmitArguments) : (ArrayBuffer[String], ArrayBuffer[String], Map[String, String], String) = { - if (args.master.startsWith("local")) { - clusterManager = LOCAL - } else if (args.master.startsWith("yarn")) { - clusterManager = YARN - } else if (args.master.startsWith("spark")) { - clusterManager = STANDALONE - } else if (args.master.startsWith("mesos")) { - clusterManager = MESOS - } else { - printErrorAndExit("Master must start with yarn, mesos, spark, or local") - } - // Because "yarn-cluster" and "yarn-client" encapsulate both the master - // and deploy mode, we have some logic to infer the master and deploy mode - // from each other if only one is specified, or exit early if they are at odds. 
- if (args.deployMode == null && - (args.master == "yarn-standalone" || args.master == "yarn-cluster")) { - args.deployMode = "cluster" - } - if (args.deployMode == "cluster" && args.master == "yarn-client") { - printErrorAndExit("Deploy mode \"cluster\" and master \"yarn-client\" are not compatible") - } - if (args.deployMode == "client" && - (args.master == "yarn-standalone" || args.master == "yarn-cluster")) { - printErrorAndExit("Deploy mode \"client\" and master \"" + args.master - + "\" are not compatible") - } - if (args.deployMode == "cluster" && args.master.startsWith("yarn")) { - args.master = "yarn-cluster" - } - if (args.deployMode != "cluster" && args.master.startsWith("yarn")) { - args.master = "yarn-client" - } - - val deployOnCluster = Option(args.deployMode).getOrElse("client") == "cluster" - - val childClasspath = new ArrayBuffer[String]() + // Values to return val childArgs = new ArrayBuffer[String]() + val childClasspath = new ArrayBuffer[String]() val sysProps = new HashMap[String, String]() var childMainClass = "" - val isPython = args.isPython - val isYarnCluster = clusterManager == YARN && deployOnCluster + // Set the cluster manager + val clusterManager: Int = args.master match { + case m if m.startsWith("yarn") => YARN + case m if m.startsWith("spark") => STANDALONE + case m if m.startsWith("mesos") => MESOS + case m if m.startsWith("local") => LOCAL + case _ => printErrorAndExit("Master must start with yarn, spark, mesos, or local"); -1 + } - // For mesos, only client mode is supported - if (clusterManager == MESOS && deployOnCluster) { - printErrorAndExit("Cluster deploy mode is currently not supported for Mesos clusters.") + // Set the deploy mode; default is client mode + var deployMode: Int = args.deployMode match { + case "client" | null => CLIENT + case "cluster" => CLUSTER + case _ => printErrorAndExit("Deploy mode must be either client or cluster"); -1 } - // For standalone, only client mode is supported - if (clusterManager == STANDALONE && deployOnCluster) { - printErrorAndExit("Cluster deploy mode is currently not supported for standalone clusters.") + // Because "yarn-cluster" and "yarn-client" encapsulate both the master + // and deploy mode, we have some logic to infer the master and deploy mode + // from each other if only one is specified, or exit early if they are at odds. + if (clusterManager == YARN) { + if (args.master == "yarn-standalone") { + printWarning("\"yarn-standalone\" is deprecated. Use \"yarn-cluster\" instead.") + args.master = "yarn-cluster" + } + (args.master, args.deployMode) match { + case ("yarn-cluster", null) => + deployMode = CLUSTER + case ("yarn-cluster", "client") => + printErrorAndExit("Client deploy mode is not compatible with master \"yarn-cluster\"") + case ("yarn-client", "cluster") => + printErrorAndExit("Cluster deploy mode is not compatible with master \"yarn-client\"") + case (_, mode) => + args.master = "yarn-" + Option(mode).getOrElse("client") + } + + // Make sure YARN is included in our build if we're trying to use it + if (!Utils.classIsLoadable("org.apache.spark.deploy.yarn.Client") && !Utils.isTesting) { + printErrorAndExit( + "Could not load YARN classes. 
" + + "This copy of Spark may not have been compiled with YARN support.") + } } - // For shells, only client mode is applicable - if (isShell(args.primaryResource) && deployOnCluster) { - printErrorAndExit("Cluster deploy mode is not applicable to Spark shells.") + // The following modes are not supported or applicable + (clusterManager, deployMode) match { + case (MESOS, CLUSTER) => + printErrorAndExit("Cluster deploy mode is currently not supported for Mesos clusters.") + case (STANDALONE, CLUSTER) => + printErrorAndExit("Cluster deploy mode is currently not supported for Standalone clusters.") + case (_, CLUSTER) if args.isPython => + printErrorAndExit("Cluster deploy mode is currently not supported for python applications.") + case (_, CLUSTER) if isShell(args.primaryResource) => + printErrorAndExit("Cluster deploy mode is not applicable to Spark shells.") + case _ => } // If we're running a python app, set the main class to our specific python runner - if (isPython) { - if (deployOnCluster) { - printErrorAndExit("Cluster deploy mode is currently not supported for python.") - } + if (args.isPython) { if (args.primaryResource == PYSPARK_SHELL) { args.mainClass = "py4j.GatewayServer" args.childArgs = ArrayBuffer("--die-on-broken-pipe", "0") @@ -152,122 +158,120 @@ object SparkSubmit { sysProps("spark.submit.pyFiles") = PythonRunner.formatPaths(args.pyFiles).mkString(",") } - // If we're deploying into YARN, use yarn.Client as a wrapper around the user class - if (!deployOnCluster) { - childMainClass = args.mainClass - if (isUserJar(args.primaryResource)) { - childClasspath += args.primaryResource - } - } else if (clusterManager == YARN) { - childMainClass = "org.apache.spark.deploy.yarn.Client" - childArgs += ("--jar", args.primaryResource) - childArgs += ("--class", args.mainClass) - } - - // Make sure YARN is included in our build if we're trying to use it - if (clusterManager == YARN) { - if (!Utils.classIsLoadable("org.apache.spark.deploy.yarn.Client") && !Utils.isTesting) { - printErrorAndExit("Could not load YARN classes. 
" + - "This copy of Spark may not have been compiled with YARN support.") - } - } - // Special flag to avoid deprecation warnings at the client sysProps("SPARK_SUBMIT") = "true" // A list of rules to map each argument to system properties or command-line options in // each deploy mode; we iterate through these below val options = List[OptionAssigner]( - OptionAssigner(args.master, ALL_CLUSTER_MGRS, false, sysProp = "spark.master"), - OptionAssigner(args.name, ALL_CLUSTER_MGRS, false, sysProp = "spark.app.name"), - OptionAssigner(args.name, YARN, true, clOption = "--name", sysProp = "spark.app.name"), - OptionAssigner(args.driverExtraClassPath, STANDALONE | YARN, true, + + // All cluster managers + OptionAssigner(args.master, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.master"), + OptionAssigner(args.name, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.app.name"), + OptionAssigner(args.jars, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.jars"), + + // Standalone cluster only + OptionAssigner(args.driverMemory, STANDALONE, CLUSTER, clOption = "--memory"), + OptionAssigner(args.driverCores, STANDALONE, CLUSTER, clOption = "--cores"), + + // Yarn client only + OptionAssigner(args.queue, YARN, CLIENT, sysProp = "spark.yarn.queue"), + OptionAssigner(args.numExecutors, YARN, CLIENT, sysProp = "spark.executor.instances"), + OptionAssigner(args.executorCores, YARN, CLIENT, sysProp = "spark.executor.cores"), + OptionAssigner(args.files, YARN, CLIENT, sysProp = "spark.yarn.dist.files"), + OptionAssigner(args.archives, YARN, CLIENT, sysProp = "spark.yarn.dist.archives"), + + // Yarn cluster only + OptionAssigner(args.name, YARN, CLUSTER, clOption = "--name", sysProp = "spark.app.name"), + OptionAssigner(args.driverMemory, YARN, CLUSTER, clOption = "--driver-memory"), + OptionAssigner(args.queue, YARN, CLUSTER, clOption = "--queue"), + OptionAssigner(args.numExecutors, YARN, CLUSTER, clOption = "--num-executors"), + OptionAssigner(args.executorMemory, YARN, CLUSTER, clOption = "--executor-memory"), + OptionAssigner(args.executorCores, YARN, CLUSTER, clOption = "--executor-cores"), + OptionAssigner(args.files, YARN, CLUSTER, clOption = "--files"), + OptionAssigner(args.archives, YARN, CLUSTER, clOption = "--archives"), + OptionAssigner(args.jars, YARN, CLUSTER, clOption = "--addJars"), + + // Other options + OptionAssigner(args.driverExtraClassPath, STANDALONE | YARN, CLUSTER, sysProp = "spark.driver.extraClassPath"), - OptionAssigner(args.driverExtraJavaOptions, STANDALONE | YARN, true, + OptionAssigner(args.driverExtraJavaOptions, STANDALONE | YARN, CLUSTER, sysProp = "spark.driver.extraJavaOptions"), - OptionAssigner(args.driverExtraLibraryPath, STANDALONE | YARN, true, + OptionAssigner(args.driverExtraLibraryPath, STANDALONE | YARN, CLUSTER, sysProp = "spark.driver.extraLibraryPath"), - OptionAssigner(args.driverMemory, YARN, true, clOption = "--driver-memory"), - OptionAssigner(args.driverMemory, STANDALONE, true, clOption = "--memory"), - OptionAssigner(args.driverCores, STANDALONE, true, clOption = "--cores"), - OptionAssigner(args.queue, YARN, true, clOption = "--queue"), - OptionAssigner(args.queue, YARN, false, sysProp = "spark.yarn.queue"), - OptionAssigner(args.numExecutors, YARN, true, clOption = "--num-executors"), - OptionAssigner(args.numExecutors, YARN, false, sysProp = "spark.executor.instances"), - OptionAssigner(args.executorMemory, YARN, true, clOption = "--executor-memory"), - OptionAssigner(args.executorMemory, STANDALONE | MESOS | YARN, false, + OptionAssigner(args.executorMemory, 
STANDALONE | MESOS | YARN, CLIENT, sysProp = "spark.executor.memory"), - OptionAssigner(args.executorCores, YARN, true, clOption = "--executor-cores"), - OptionAssigner(args.executorCores, YARN, false, sysProp = "spark.executor.cores"), - OptionAssigner(args.totalExecutorCores, STANDALONE | MESOS, false, + OptionAssigner(args.totalExecutorCores, STANDALONE | MESOS, CLIENT, sysProp = "spark.cores.max"), - OptionAssigner(args.files, YARN, false, sysProp = "spark.yarn.dist.files"), - OptionAssigner(args.files, YARN, true, clOption = "--files"), - OptionAssigner(args.files, LOCAL | STANDALONE | MESOS, false, sysProp = "spark.files"), - OptionAssigner(args.files, LOCAL | STANDALONE | MESOS, true, sysProp = "spark.files"), - OptionAssigner(args.archives, YARN, false, sysProp = "spark.yarn.dist.archives"), - OptionAssigner(args.archives, YARN, true, clOption = "--archives"), - OptionAssigner(args.jars, YARN, true, clOption = "--addJars"), - OptionAssigner(args.jars, ALL_CLUSTER_MGRS, false, sysProp = "spark.jars") + OptionAssigner(args.files, LOCAL | STANDALONE | MESOS, ALL_DEPLOY_MODES, + sysProp = "spark.files") ) - // For client mode make any added jars immediately visible on the classpath - if (args.jars != null && !deployOnCluster) { - for (jar <- args.jars.split(",")) { - childClasspath += jar + // In client mode, launch the application main class directly + // In addition, add the main application jar and any added jars (if any) to the classpath + if (deployMode == CLIENT) { + childMainClass = args.mainClass + if (isUserJar(args.primaryResource)) { + childClasspath += args.primaryResource } + if (args.jars != null) { childClasspath ++= args.jars.split(",") } + if (args.childArgs != null) { childArgs ++= args.childArgs } } + // Map all arguments to command-line options or system properties for our chosen mode for (opt <- options) { - if (opt.value != null && deployOnCluster == opt.deployOnCluster && + if (opt.value != null && + (deployMode & opt.deployMode) != 0 && (clusterManager & opt.clusterManager) != 0) { - if (opt.clOption != null) { - childArgs += (opt.clOption, opt.value) - } - if (opt.sysProp != null) { - sysProps.put(opt.sysProp, opt.value) - } + if (opt.clOption != null) { childArgs += (opt.clOption, opt.value) } + if (opt.sysProp != null) { sysProps.put(opt.sysProp, opt.value) } } } // Add the application jar automatically so the user doesn't have to call sc.addJar // For YARN cluster mode, the jar is already distributed on each node as "app.jar" // For python files, the primary resource is already distributed as a regular file - if (!isYarnCluster && !isPython) { - var jars = sysProps.get("spark.jars").map(x => x.split(",").toSeq).getOrElse(Seq()) + val isYarnCluster = clusterManager == YARN && deployMode == CLUSTER + if (!isYarnCluster && !args.isPython) { + var jars = sysProps.get("spark.jars").map(x => x.split(",").toSeq).getOrElse(Seq.empty) if (isUserJar(args.primaryResource)) { jars = jars ++ Seq(args.primaryResource) } sysProps.put("spark.jars", jars.mkString(",")) } - // Standalone cluster specific configurations - if (deployOnCluster && clusterManager == STANDALONE) { + // In standalone-cluster mode, use Client as a wrapper around the user class + if (clusterManager == STANDALONE && deployMode == CLUSTER) { + childMainClass = "org.apache.spark.deploy.Client" if (args.supervise) { childArgs += "--supervise" } - childMainClass = "org.apache.spark.deploy.Client" childArgs += "launch" childArgs += (args.master, args.primaryResource, args.mainClass) + if (args.childArgs != 
null) { + childArgs ++= args.childArgs + } } - // Arguments to be passed to user program - if (args.childArgs != null) { - if (!deployOnCluster || clusterManager == STANDALONE) { - childArgs ++= args.childArgs - } else if (clusterManager == YARN) { - for (arg <- args.childArgs) { - childArgs += ("--arg", arg) - } + // In yarn-cluster mode, use yarn.Client as a wrapper around the user class + if (clusterManager == YARN && deployMode == CLUSTER) { + childMainClass = "org.apache.spark.deploy.yarn.Client" + childArgs += ("--jar", args.primaryResource) + childArgs += ("--class", args.mainClass) + if (args.childArgs != null) { + args.childArgs.foreach { arg => childArgs += ("--arg", arg) } } } // Read from default spark properties, if any for ((k, v) <- args.getDefaultSparkProperties) { - if (!sysProps.contains(k)) sysProps(k) = v + sysProps.getOrElseUpdate(k, v) } + // Spark properties included on command line take precedence + sysProps ++= args.sparkProperties + (childArgs, childClasspath, sysProps, childMainClass) } @@ -364,6 +368,6 @@ object SparkSubmit { private[spark] case class OptionAssigner( value: String, clusterManager: Int, - deployOnCluster: Boolean, + deployMode: Int, clOption: String = null, sysProp: String = null) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 57655aa4c32b1..3ab67a43a3b55 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -55,6 +55,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { var verbose: Boolean = false var isPython: Boolean = false var pyFiles: String = null + val sparkProperties: HashMap[String, String] = new HashMap[String, String]() parseOpts(args.toList) loadDefaults() @@ -177,6 +178,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { | executorCores $executorCores | totalExecutorCores $totalExecutorCores | propertiesFile $propertiesFile + | extraSparkProperties $sparkProperties | driverMemory $driverMemory | driverCores $driverCores | driverExtraClassPath $driverExtraClassPath @@ -290,6 +292,13 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { jars = Utils.resolveURIs(value) parse(tail) + case ("--conf" | "-c") :: value :: tail => + value.split("=", 2).toSeq match { + case Seq(k, v) => sparkProperties(k) = v + case _ => SparkSubmit.printErrorAndExit(s"Spark config without '=': $value") + } + parse(tail) + case ("--help" | "-h") :: tail => printUsageAndExit(0) @@ -349,6 +358,8 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { | on the PYTHONPATH for Python apps. | --files FILES Comma-separated list of files to be placed in the working | directory of each executor. + | + | --conf PROP=VALUE Arbitrary Spark configuration property. | --properties-file FILE Path to a file from which to load extra properties. If not | specified, this will look for conf/spark-defaults.conf. 
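As a side note on the new --conf option: the argument is split on the first '=' only, so property values may themselves contain '='. A small standalone sketch of that parse, with an illustrative argument value:

// Mirrors the split("=", 2) handling added to SparkSubmitArguments.
val arg = "spark.executor.extraJavaOptions=-XX:+UseG1GC -Dkey=value"
arg.split("=", 2).toSeq match {
  case Seq(k, v) => println(s"$k -> $v")   // everything after the first '=' stays in the value
  case _ => println(s"Spark config without '=': $arg")
}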
| diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index a8c9ac072449f..01e7065c17b69 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -169,7 +169,8 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis val ui: SparkUI = if (renderUI) { val conf = this.conf.clone() val appSecManager = new SecurityManager(conf) - new SparkUI(conf, appSecManager, replayBus, appId, "/history/" + appId) + new SparkUI(conf, appSecManager, replayBus, appId, + HistoryServer.UI_PATH_PREFIX + s"/$appId") // Do not call ui.bind() to avoid creating a new server for each application } else { null diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index a958c837c2ff6..d7a3e3f120e67 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -75,7 +75,7 @@ private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") { "Last Updated") private def appRow(info: ApplicationHistoryInfo): Seq[Node] = { - val uiAddress = "/history/" + info.id + val uiAddress = HistoryServer.UI_PATH_PREFIX + s"/${info.id}" val startTime = UIUtils.formatDate(info.startTime) val endTime = UIUtils.formatDate(info.endTime) val duration = UIUtils.formatDuration(info.endTime - info.startTime) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 56b38ddfc9313..cacb9da8c947b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -114,7 +114,7 @@ class HistoryServer( attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) val contextHandler = new ServletContextHandler - contextHandler.setContextPath("/history") + contextHandler.setContextPath(HistoryServer.UI_PATH_PREFIX) contextHandler.addServlet(new ServletHolder(loaderServlet), "/*") attachHandler(contextHandler) } @@ -172,6 +172,8 @@ class HistoryServer( object HistoryServer extends Logging { private val conf = new SparkConf + val UI_PATH_PREFIX = "/history" + def main(argStrings: Array[String]) { SignalLogger.register(log) initSecurity() diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index d9f8105992a10..21f8667819c44 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -17,6 +17,7 @@ package org.apache.spark.deploy.master +import java.net.URLEncoder import java.text.SimpleDateFormat import java.util.Date @@ -30,11 +31,11 @@ import akka.actor._ import akka.pattern.ask import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} import akka.serialization.SerializationExtension -import org.apache.hadoop.fs.FileSystem import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ +import org.apache.spark.deploy.history.HistoryServer 
import org.apache.spark.deploy.master.DriverState.DriverState import org.apache.spark.deploy.master.MasterMessages._ import org.apache.spark.deploy.master.ui.MasterWebUI @@ -57,6 +58,7 @@ private[spark] class Master( def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs val WORKER_TIMEOUT = conf.getLong("spark.worker.timeout", 60) * 1000 val RETAINED_APPLICATIONS = conf.getInt("spark.deploy.retainedApplications", 200) + val RETAINED_DRIVERS = conf.getInt("spark.deploy.retainedDrivers", 200) val REAPER_ITERATIONS = conf.getInt("spark.dead.worker.persistence", 15) val RECOVERY_DIR = conf.get("spark.deploy.recoveryDirectory", "") val RECOVERY_MODE = conf.get("spark.deploy.recoveryMode", "NONE") @@ -641,10 +643,7 @@ private[spark] class Master( waitingApps -= app // If application events are logged, use them to rebuild the UI - if (!rebuildSparkUI(app)) { - // Avoid broken links if the UI is not reconstructed - app.desc.appUiUrl = "" - } + rebuildSparkUI(app) for (exec <- app.executors.values) { exec.worker.removeExecutor(exec) @@ -666,29 +665,49 @@ private[spark] class Master( */ def rebuildSparkUI(app: ApplicationInfo): Boolean = { val appName = app.desc.name - val eventLogDir = app.desc.eventLogDir.getOrElse { return false } + val notFoundBasePath = HistoryServer.UI_PATH_PREFIX + "/not-found" + val eventLogDir = app.desc.eventLogDir.getOrElse { + // Event logging is not enabled for this application + app.desc.appUiUrl = notFoundBasePath + return false + } val fileSystem = Utils.getHadoopFileSystem(eventLogDir) val eventLogInfo = EventLoggingListener.parseLoggingInfo(eventLogDir, fileSystem) val eventLogPaths = eventLogInfo.logPaths val compressionCodec = eventLogInfo.compressionCodec - if (!eventLogPaths.isEmpty) { - try { - val replayBus = new ReplayListenerBus(eventLogPaths, fileSystem, compressionCodec) - val ui = new SparkUI( - new SparkConf, replayBus, appName + " (completed)", "/history/" + app.id) - replayBus.replay() - app.desc.appUiUrl = ui.basePath - appIdToUI(app.id) = ui - webUi.attachSparkUI(ui) - return true - } catch { - case e: Exception => - logError("Exception in replaying log for application %s (%s)".format(appName, app.id), e) - } - } else { - logWarning("Application %s (%s) has no valid logs: %s".format(appName, app.id, eventLogDir)) + + if (eventLogPaths.isEmpty) { + // Event logging is enabled for this application, but no event logs are found + val title = s"Application history not found (${app.id})" + var msg = s"No event logs found for application $appName in $eventLogDir." + logWarning(msg) + msg += " Did you specify the correct logging directory?" + msg = URLEncoder.encode(msg, "UTF-8") + app.desc.appUiUrl = notFoundBasePath + s"?msg=$msg&title=$title" + return false + } + + try { + val replayBus = new ReplayListenerBus(eventLogPaths, fileSystem, compressionCodec) + val ui = new SparkUI(new SparkConf, replayBus, appName + " (completed)", + HistoryServer.UI_PATH_PREFIX + s"/${app.id}") + replayBus.replay() + appIdToUI(app.id) = ui + webUi.attachSparkUI(ui) + // Application UI is successfully rebuilt, so link the Master UI to it + app.desc.appUiUrl = ui.basePath + true + } catch { + case e: Exception => + // Relay exception message to application UI page + val title = s"Application history load error (${app.id})" + val exception = URLEncoder.encode(Utils.exceptionString(e), "UTF-8") + var msg = s"Exception in replaying log for application $appName!" 
+ logError(msg, e) + msg = URLEncoder.encode(msg, "UTF-8") + app.desc.appUiUrl = notFoundBasePath + s"?msg=$msg&exception=$exception&title=$title" + false } - false } /** Generate a new app ID given a app's submission date */ @@ -741,11 +760,16 @@ private[spark] class Master( case Some(driver) => logInfo(s"Removing driver: $driverId") drivers -= driver + if (completedDrivers.size >= RETAINED_DRIVERS) { + val toRemove = math.max(RETAINED_DRIVERS / 10, 1) + completedDrivers.trimStart(toRemove) + } completedDrivers += driver persistenceEngine.removeDriver(driver) driver.state = finalState driver.exception = exception driver.worker.foreach(w => w.removeDriver(driver)) + schedule() case None => logWarning(s"Asked to remove unknown driver: $driverId") } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index 34fa1429c86de..4588c130ef439 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -28,7 +28,7 @@ import org.json4s.JValue import org.apache.spark.deploy.{ExecutorState, JsonProtocol} import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState} import org.apache.spark.deploy.master.ExecutorInfo -import org.apache.spark.ui.{WebUIPage, UIUtils} +import org.apache.spark.ui.{UIUtils, WebUIPage} import org.apache.spark.util.Utils private[spark] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/HistoryNotFoundPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/HistoryNotFoundPage.scala new file mode 100644 index 0000000000000..d8daff3e7fb9c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/HistoryNotFoundPage.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.master.ui + +import java.net.URLDecoder +import javax.servlet.http.HttpServletRequest + +import scala.xml.Node + +import org.apache.spark.ui.{UIUtils, WebUIPage} + +private[spark] class HistoryNotFoundPage(parent: MasterWebUI) + extends WebUIPage("history/not-found") { + + /** + * Render a page that conveys failure in loading application history. + * + * This accepts 3 HTTP parameters: + * msg = message to display to the user + * title = title of the page + * exception = detailed description of the exception in loading application history (if any) + * + * Parameters "msg" and "exception" are assumed to be UTF-8 encoded. 
+ */ + def render(request: HttpServletRequest): Seq[Node] = { + val titleParam = request.getParameter("title") + val msgParam = request.getParameter("msg") + val exceptionParam = request.getParameter("exception") + + // If no parameters are specified, assume the user did not enable event logging + val defaultTitle = "Event logging is not enabled" + val defaultContent = +
+      <div class="row-fluid">
+        <div class="span12">
+          No event logs were found for this application! To enable event logging,
+          set spark.eventLog.enabled to true and
+          spark.eventLog.dir to the directory to which your
+          event logs are written.
+        </div>
+      </div>
+
+    val title = Option(titleParam).getOrElse(defaultTitle)
+    val content = Option(msgParam)
+      .map { msg => URLDecoder.decode(msg, "UTF-8") }
+      .map { msg =>
+        <div class="row-fluid">
+          <div class="span12">{msg}</div>
+        </div> ++
+        Option(exceptionParam)
+          .map { e => URLDecoder.decode(e, "UTF-8") }
+          .map { e => <pre>{e}</pre>
} + .getOrElse(Seq.empty) + }.getOrElse(defaultContent) + + UIUtils.basicSparkPage(content, title) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index a18b39fc95d64..16aa0493370dd 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -21,7 +21,7 @@ import org.apache.spark.Logging import org.apache.spark.deploy.master.Master import org.apache.spark.ui.{SparkUI, WebUI} import org.apache.spark.ui.JettyUtils._ -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.AkkaUtils /** * Web UI server for the standalone master. @@ -38,6 +38,7 @@ class MasterWebUI(val master: Master, requestedPort: Int) /** Initialize all components of the server. */ def initialize() { attachPage(new ApplicationPage(this)) + attachPage(new HistoryNotFoundPage(this)) attachPage(new MasterPage(this)) attachHandler(createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR, "/static")) master.masterMetricsSystem.getServletHandlers.foreach(attachHandler) diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 8d31bd05fdbec..b455c9fcf4bd6 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -71,7 +71,7 @@ private[spark] class CoarseGrainedExecutorBackend( val ser = SparkEnv.get.closureSerializer.newInstance() val taskDesc = ser.deserialize[TaskDescription](data.value) logInfo("Got assigned task " + taskDesc.taskId) - executor.launchTask(this, taskDesc.taskId, taskDesc.serializedTask) + executor.launchTask(this, taskDesc.taskId, taskDesc.name, taskDesc.serializedTask) } case KillTask(taskId, _, interruptThread) => diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 4d3ba11633bf5..3b69bc4ca4142 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -107,8 +107,9 @@ private[spark] class Executor( // Maintains the list of running tasks. 
private val runningTasks = new ConcurrentHashMap[Long, TaskRunner] - def launchTask(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) { - val tr = new TaskRunner(context, taskId, serializedTask) + def launchTask( + context: ExecutorBackend, taskId: Long, taskName: String, serializedTask: ByteBuffer) { + val tr = new TaskRunner(context, taskId, taskName, serializedTask) runningTasks.put(taskId, tr) threadPool.execute(tr) } @@ -135,14 +136,15 @@ private[spark] class Executor( localDirs } - class TaskRunner(execBackend: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) + class TaskRunner( + execBackend: ExecutorBackend, taskId: Long, taskName: String, serializedTask: ByteBuffer) extends Runnable { @volatile private var killed = false @volatile private var task: Task[Any] = _ def kill(interruptThread: Boolean) { - logInfo("Executor is trying to kill task " + taskId) + logInfo(s"Executor is trying to kill $taskName (TID $taskId)") killed = true if (task != null) { task.kill(interruptThread) @@ -154,7 +156,7 @@ private[spark] class Executor( SparkEnv.set(env) Thread.currentThread.setContextClassLoader(replClassLoader) val ser = SparkEnv.get.closureSerializer.newInstance() - logInfo("Running task ID " + taskId) + logInfo(s"Running $taskName (TID $taskId)") execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER) var attemptedTask: Option[Task[Any]] = None var taskStart: Long = 0 @@ -207,25 +209,30 @@ private[spark] class Executor( val accumUpdates = Accumulators.values - val directResult = new DirectTaskResult(valueBytes, accumUpdates, - task.metrics.getOrElse(null)) + val directResult = new DirectTaskResult(valueBytes, accumUpdates, task.metrics.orNull) val serializedDirectResult = ser.serialize(directResult) - logInfo("Serialized size of result for " + taskId + " is " + serializedDirectResult.limit) - val serializedResult = { - if (serializedDirectResult.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) { - logInfo("Storing result for " + taskId + " in local BlockManager") + val resultSize = serializedDirectResult.limit + + // directSend = sending directly back to the driver + val (serializedResult, directSend) = { + if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) { val blockId = TaskResultBlockId(taskId) env.blockManager.putBytes( blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER) - ser.serialize(new IndirectTaskResult[Any](blockId)) + (ser.serialize(new IndirectTaskResult[Any](blockId)), false) } else { - logInfo("Sending result for " + taskId + " directly to driver") - serializedDirectResult + (serializedDirectResult, true) } } execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult) - logInfo("Finished task ID " + taskId) + + if (directSend) { + logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver") + } else { + logInfo( + s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)") + } } catch { case ffe: FetchFailedException => { val reason = ffe.toTaskEndReason @@ -233,7 +240,7 @@ private[spark] class Executor( } case _: TaskKilledException | _: InterruptedException if task.killed => { - logInfo("Executor killed task " + taskId) + logInfo(s"Executor killed $taskName (TID $taskId)") execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled)) } @@ -241,7 +248,7 @@ private[spark] class Executor( // Attempt to exit cleanly by informing the driver of our failure. 
// If anything goes wrong (or this was a fatal exception), we will delegate to // the default uncaught exception handler, which will terminate the Executor. - logError("Exception in task ID " + taskId, t) + logError(s"Exception in $taskName (TID $taskId)", t) val serviceTime = System.currentTimeMillis() - taskStart val metrics = attemptedTask.flatMap(t => t.metrics) @@ -249,7 +256,7 @@ private[spark] class Executor( m.executorRunTime = serviceTime m.jvmGCTime = gcTime - startGCTime } - val reason = ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace, metrics) + val reason = ExceptionFailure(t.getClass.getName, t.getMessage, t.getStackTrace, metrics) execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) // Don't forcibly exit unless the exception was inherently fatal, to avoid @@ -259,11 +266,13 @@ private[spark] class Executor( } } } finally { - // TODO: Unregister shuffle memory only for ResultTask + // Release memory used by this thread for shuffles val shuffleMemoryMap = env.shuffleMemoryMap shuffleMemoryMap.synchronized { shuffleMemoryMap.remove(Thread.currentThread().getId) } + // Release memory used by this thread for unrolling blocks + env.blockManager.memoryStore.releaseUnrollMemoryForThisThread() runningTasks.remove(taskId) } } diff --git a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala index 2232e6237bf26..a42c8b43bbf7f 100644 --- a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala @@ -64,7 +64,7 @@ private[spark] class MesosExecutorBackend if (executor == null) { logError("Received launchTask but executor was null") } else { - executor.launchTask(this, taskId, taskInfo.getData.asReadOnlyByteBuffer) + executor.launchTask(this, taskId, taskInfo.getName, taskInfo.getData.asReadOnlyByteBuffer) } } diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index ac73288442a74..21fe643b8d71f 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -75,7 +75,9 @@ class TaskMetrics extends Serializable { /** * If this task reads from shuffle output, metrics on getting shuffle data will be collected here */ - var shuffleReadMetrics: Option[ShuffleReadMetrics] = None + private var _shuffleReadMetrics: Option[ShuffleReadMetrics] = None + + def shuffleReadMetrics = _shuffleReadMetrics /** * If this task writes to shuffle output, metrics on the written shuffle data will be collected @@ -87,6 +89,21 @@ class TaskMetrics extends Serializable { * Storage statuses of any blocks that have been updated as a result of this task. */ var updatedBlocks: Option[Seq[(BlockId, BlockStatus)]] = None + + /** Adds the given ShuffleReadMetrics to any existing shuffle metrics for this task. 
*/ + def updateShuffleReadMetrics(newMetrics: ShuffleReadMetrics) = synchronized { + _shuffleReadMetrics match { + case Some(existingMetrics) => + existingMetrics.shuffleFinishTime = math.max( + existingMetrics.shuffleFinishTime, newMetrics.shuffleFinishTime) + existingMetrics.fetchWaitTime += newMetrics.fetchWaitTime + existingMetrics.localBlocksFetched += newMetrics.localBlocksFetched + existingMetrics.remoteBlocksFetched += newMetrics.remoteBlocksFetched + existingMetrics.remoteBytesRead += newMetrics.remoteBytesRead + case None => + _shuffleReadMetrics = Some(newMetrics) + } + } } private[spark] object TaskMetrics { @@ -131,7 +148,7 @@ class ShuffleReadMetrics extends Serializable { /** * Number of blocks fetched in this shuffle by this task (remote or local) */ - var totalBlocksFetched: Int = _ + def totalBlocksFetched: Int = remoteBlocksFetched + localBlocksFetched /** * Number of remote blocks fetched in this shuffle by this task diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index 4b0fe1ab82999..1b66218d86dd9 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -20,6 +20,7 @@ package org.apache.spark.io import java.io.{InputStream, OutputStream} import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} +import net.jpountz.lz4.{LZ4BlockInputStream, LZ4BlockOutputStream} import org.xerial.snappy.{SnappyInputStream, SnappyOutputStream} import org.apache.spark.SparkConf @@ -55,7 +56,28 @@ private[spark] object CompressionCodec { ctor.newInstance(conf).asInstanceOf[CompressionCodec] } - val DEFAULT_COMPRESSION_CODEC = classOf[LZFCompressionCodec].getName + val DEFAULT_COMPRESSION_CODEC = classOf[SnappyCompressionCodec].getName +} + + +/** + * :: DeveloperApi :: + * LZ4 implementation of [[org.apache.spark.io.CompressionCodec]]. + * Block size can be configured by `spark.io.compression.lz4.block.size`. + * + * Note: The wire protocol for this codec is not guaranteed to be compatible across versions + * of Spark. This is intended for use as an internal compression utility within a single Spark + * application. + */ +@DeveloperApi +class LZ4CompressionCodec(conf: SparkConf) extends CompressionCodec { + + override def compressedOutputStream(s: OutputStream): OutputStream = { + val blockSize = conf.getInt("spark.io.compression.lz4.block.size", 32768) + new LZ4BlockOutputStream(s, blockSize) + } + + override def compressedInputStream(s: InputStream): InputStream = new LZ4BlockInputStream(s) } @@ -81,7 +103,7 @@ class LZFCompressionCodec(conf: SparkConf) extends CompressionCodec { /** * :: DeveloperApi :: * Snappy implementation of [[org.apache.spark.io.CompressionCodec]]. - * Block size can be configured by spark.io.compression.snappy.block.size. + * Block size can be configured by `spark.io.compression.snappy.block.size`. * * Note: The wire protocol for this codec is not guaranteed to be compatible across versions * of Spark. 
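Whichever codec is in play, it is chosen through configuration rather than code. A minimal sketch of opting into the new LZ4 codec and tuning its block size; spark.io.compression.codec is Spark's existing codec selector (not added by this patch), and the values below are only examples:

import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.io.compression.codec", "org.apache.spark.io.LZ4CompressionCodec")
  .set("spark.io.compression.lz4.block.size", "32768")  // bytes; same as the default used above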
This is intended for use as an internal compression utility within a single Spark diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index 8a1cdb812962e..566e8a4aaa1d2 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -62,13 +62,15 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, conf.getInt("spark.core.connection.handler.threads.min", 20), conf.getInt("spark.core.connection.handler.threads.max", 60), conf.getInt("spark.core.connection.handler.threads.keepalive", 60), TimeUnit.SECONDS, - new LinkedBlockingDeque[Runnable]()) + new LinkedBlockingDeque[Runnable](), + Utils.namedThreadFactory("handle-message-executor")) private val handleReadWriteExecutor = new ThreadPoolExecutor( conf.getInt("spark.core.connection.io.threads.min", 4), conf.getInt("spark.core.connection.io.threads.max", 32), conf.getInt("spark.core.connection.io.threads.keepalive", 60), TimeUnit.SECONDS, - new LinkedBlockingDeque[Runnable]()) + new LinkedBlockingDeque[Runnable](), + Utils.namedThreadFactory("handle-read-write-executor")) // Use a different, yet smaller, thread pool - infrequently used with very short lived tasks : // which should be executed asap @@ -76,7 +78,8 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, conf.getInt("spark.core.connection.connect.threads.min", 1), conf.getInt("spark.core.connection.connect.threads.max", 8), conf.getInt("spark.core.connection.connect.threads.keepalive", 60), TimeUnit.SECONDS, - new LinkedBlockingDeque[Runnable]()) + new LinkedBlockingDeque[Runnable](), + Utils.namedThreadFactory("handle-connect-executor")) private val serverChannel = ServerSocketChannel.open() // used to track the SendingConnections waiting to do SASL negotiation diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 5951865e56c9d..6388ef82cc5db 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -25,7 +25,7 @@ import scala.language.existentials import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext} import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency} import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap} +import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap, CompactBuffer} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.ShuffleHandle @@ -66,14 +66,14 @@ private[spark] class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep] */ @DeveloperApi class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: Partitioner) - extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) { + extends RDD[(K, Array[Iterable[_]])](rdds.head.context, Nil) { // For example, `(k, a) cogroup (k, b)` produces k -> Seq(ArrayBuffer as, ArrayBuffer bs). // Each ArrayBuffer is represented as a CoGroup, and the resulting Seq as a CoGroupCombiner. // CoGroupValue is the intermediate state of each value before being merged in compute. 
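At the public API level this internal change surfaces as cogrouped values being Iterables, which the Array[Iterable[_]] combiner backs. A minimal usage sketch, assuming an active SparkContext named sc:

// Each key maps to one Iterable of values per input RDD.
val users  = sc.parallelize(Seq((1, "alice"), (2, "bob")))
val orders = sc.parallelize(Seq((1, 9.99), (1, 5.00)))
val cogrouped: org.apache.spark.rdd.RDD[(Int, (Iterable[String], Iterable[Double]))] =
  users.cogroup(orders)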
- private type CoGroup = ArrayBuffer[Any] + private type CoGroup = CompactBuffer[Any] private type CoGroupValue = (Any, Int) // Int is dependency number - private type CoGroupCombiner = Seq[CoGroup] + private type CoGroupCombiner = Array[CoGroup] private var serializer: Option[Serializer] = None @@ -114,7 +114,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: override val partitioner: Some[Partitioner] = Some(part) - override def compute(s: Partition, context: TaskContext): Iterator[(K, CoGroupCombiner)] = { + override def compute(s: Partition, context: TaskContext): Iterator[(K, Array[Iterable[_]])] = { val sparkConf = SparkEnv.get.conf val externalSorting = sparkConf.getBoolean("spark.shuffle.spill", true) val split = s.asInstanceOf[CoGroupPartition] @@ -150,18 +150,17 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: getCombiner(kv._1)(depNum) += kv._2 } } - new InterruptibleIterator(context, map.iterator) + new InterruptibleIterator(context, + map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]]) } else { val map = createExternalMap(numRdds) - rddIterators.foreach { case (it, depNum) => - while (it.hasNext) { - val kv = it.next() - map.insert(kv._1, new CoGroupValue(kv._2, depNum)) - } + for ((it, depNum) <- rddIterators) { + map.insertAll(it.map(pair => (pair._1, new CoGroupValue(pair._2, depNum)))) } context.taskMetrics.memoryBytesSpilled = map.memoryBytesSpilled context.taskMetrics.diskBytesSpilled = map.diskBytesSpilled - new InterruptibleIterator(context, map.iterator) + new InterruptibleIterator(context, + map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]]) } } @@ -170,17 +169,22 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: val createCombiner: (CoGroupValue => CoGroupCombiner) = value => { val newCombiner = Array.fill(numRdds)(new CoGroup) - value match { case (v, depNum) => newCombiner(depNum) += v } + newCombiner(value._2) += value._1 newCombiner } val mergeValue: (CoGroupCombiner, CoGroupValue) => CoGroupCombiner = (combiner, value) => { - value match { case (v, depNum) => combiner(depNum) += v } + combiner(value._2) += value._1 combiner } val mergeCombiners: (CoGroupCombiner, CoGroupCombiner) => CoGroupCombiner = (combiner1, combiner2) => { - combiner1.zip(combiner2).map { case (v1, v2) => v1 ++ v2 } + var depNum = 0 + while (depNum < numRdds) { + combiner1(depNum) ++= combiner2(depNum) + depNum += 1 + } + combiner1 } new ExternalAppendOnlyMap[K, CoGroupValue, CoGroupCombiner]( createCombiner, mergeValue, mergeCombiners) diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala index c45b759f007cc..e7221e3032c11 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala @@ -258,7 +258,7 @@ private[spark] class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanc val pgroup = PartitionGroup(nxt_replica) groupArr += pgroup addPartToPGroup(nxt_part, pgroup) - groupHash += (nxt_replica -> (ArrayBuffer(pgroup))) // list in case we have multiple + groupHash.put(nxt_replica, ArrayBuffer(pgroup)) // list in case we have multiple numCreated += 1 } } @@ -267,7 +267,7 @@ private[spark] class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanc var (nxt_replica, nxt_part) = rotIt.next() val pgroup = PartitionGroup(nxt_replica) groupArr += pgroup - groupHash.get(nxt_replica).get += pgroup + 
groupHash.getOrElseUpdate(nxt_replica, ArrayBuffer()) += pgroup var tries = 0 while (!addPartToPGroup(nxt_part, pgroup) && tries < targetLen) { // ensure at least one part nxt_part = rotIt.next()._2 diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 041028514399b..e521612ffc27c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -140,8 +140,8 @@ class HadoopRDD[K, V]( // Create a JobConf that will be cached and used across this RDD's getJobConf() calls in the // local process. The local cache is accessed through HadoopRDD.putCachedMetadata(). // The caching helps minimize GC, since a JobConf can contain ~10KB of temporary objects. - // synchronize to prevent ConcurrentModificationException (Spark-1097, Hadoop-10456) - conf.synchronized { + // Synchronize to prevent ConcurrentModificationException (Spark-1097, Hadoop-10456). + HadoopRDD.CONFIGURATION_INSTANTIATION_LOCK.synchronized { val newJobConf = new JobConf(conf) initLocalJobConfFuncOpt.map(f => f(newJobConf)) HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf) @@ -246,6 +246,9 @@ class HadoopRDD[K, V]( } private[spark] object HadoopRDD { + /** Constructing Configuration objects is not threadsafe, use this lock to serialize. */ + val CONFIGURATION_INSTANTIATION_LOCK = new Object() + /** * The three methods below are helpers for accessing the local map, a property of the SparkEnv of * the local process. diff --git a/core/src/main/scala/org/apache/spark/rdd/MappedValuesRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MappedValuesRDD.scala index 2bc47eb9fcd74..a60952eee5901 100644 --- a/core/src/main/scala/org/apache/spark/rdd/MappedValuesRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/MappedValuesRDD.scala @@ -28,6 +28,6 @@ class MappedValuesRDD[K, V, U](prev: RDD[_ <: Product2[K, V]], f: V => U) override val partitioner = firstParent[Product2[K, U]].partitioner override def compute(split: Partition, context: TaskContext): Iterator[(K, U)] = { - firstParent[Product2[K, V]].iterator(split, context).map { case Product2(k ,v) => (k, f(v)) } + firstParent[Product2[K, V]].iterator(split, context).map { pair => (pair._1, f(pair._2)) } } } diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala index f1f4b4324edfd..afd7075f686b9 100644 --- a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala @@ -57,14 +57,13 @@ class OrderedRDDFunctions[K : Ordering : ClassTag, */ def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size): RDD[P] = { val part = new RangePartitioner(numPartitions, self, ascending) - val shuffled = new ShuffledRDD[K, V, V, P](self, part).setKeyOrdering(ordering) - shuffled.mapPartitions(iter => { - val buf = iter.toArray - if (ascending) { - buf.sortWith((x, y) => ordering.lt(x._1, y._1)).iterator - } else { - buf.sortWith((x, y) => ordering.gt(x._1, y._1)).iterator - } - }, preservesPartitioning = true) + new ShuffledRDD[K, V, V, P](self, part) + .setKeyOrdering(ordering) + .setSortOrder(if (ascending) SortOrder.ASCENDING else SortOrder.DESCENDING) } } + +private[spark] object SortOrder extends Enumeration { + type SortOrder = Value + val ASCENDING, DESCENDING = Value +} diff --git 
a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index f5153d31e7b70..026a569e7ec53 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -123,7 +123,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) zeroBuffer.get(zeroArray) lazy val cachedSerializer = SparkEnv.get.closureSerializer.newInstance() - def createZero() = cachedSerializer.deserialize[U](ByteBuffer.wrap(zeroArray)) + val createZero = () => cachedSerializer.deserialize[U](ByteBuffer.wrap(zeroArray)) combineByKey[U]((v: V) => seqOp(createZero(), v), seqOp, combOp, partitioner) } @@ -169,7 +169,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) // When deserializing, use a lazy val to create just one instance of the serializer per task lazy val cachedSerializer = SparkEnv.get.closureSerializer.newInstance() - def createZero() = cachedSerializer.deserialize[V](ByteBuffer.wrap(zeroArray)) + val createZero = () => cachedSerializer.deserialize[V](ByteBuffer.wrap(zeroArray)) combineByKey[V]((v: V) => func(createZero(), v), func, func, partitioner) } @@ -247,22 +247,22 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) throw new SparkException("reduceByKeyLocally() does not support array keys") } - def reducePartition(iter: Iterator[(K, V)]): Iterator[JHashMap[K, V]] = { + val reducePartition = (iter: Iterator[(K, V)]) => { val map = new JHashMap[K, V] - iter.foreach { case (k, v) => - val old = map.get(k) - map.put(k, if (old == null) v else func(old, v)) + iter.foreach { pair => + val old = map.get(pair._1) + map.put(pair._1, if (old == null) pair._2 else func(old, pair._2)) } Iterator(map) - } + } : Iterator[JHashMap[K, V]] - def mergeMaps(m1: JHashMap[K, V], m2: JHashMap[K, V]): JHashMap[K, V] = { - m2.foreach { case (k, v) => - val old = m1.get(k) - m1.put(k, if (old == null) v else func(old, v)) + val mergeMaps = (m1: JHashMap[K, V], m2: JHashMap[K, V]) => { + m2.foreach { pair => + val old = m1.get(pair._1) + m1.put(pair._1, if (old == null) pair._2 else func(old, pair._2)) } m1 - } + } : JHashMap[K, V] self.mapPartitions(reducePartition).reduce(mergeMaps) } @@ -386,29 +386,29 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Group the values for each key in the RDD into a single sequence. Allows controlling the * partitioning of the resulting key-value pair RDD by passing a Partitioner. * - * Note: If you are grouping in order to perform an aggregation (such as a sum or average) over - * each key, using [[PairRDDFunctions.reduceByKey]] or [[PairRDDFunctions.combineByKey]] - * will provide much better performance. + * Note: This operation may be very expensive. If you are grouping in order to perform an + * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] + * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. */ def groupByKey(partitioner: Partitioner): RDD[(K, Iterable[V])] = { // groupByKey shouldn't use map side combine because map side combine does not // reduce the amount of data shuffled and requires all map side data be inserted // into a hash table, leading to more objects in the old gen. 
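The revised note steers users away from groupByKey when the goal is a per-key aggregate. A minimal sketch of the difference, assuming an active SparkContext named sc:

val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("a", 3)))

// Shuffles every value for each key before summing on the reduce side.
val viaGroup = pairs.groupByKey().mapValues(_.sum)

// Combines partial sums map-side first, shuffling far less data.
val viaReduce = pairs.reduceByKey(_ + _)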
- def createCombiner(v: V) = ArrayBuffer(v) - def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v - def mergeCombiners(c1: ArrayBuffer[V], c2: ArrayBuffer[V]) = c1 ++ c2 - val bufs = combineByKey[ArrayBuffer[V]]( - createCombiner _, mergeValue _, mergeCombiners _, partitioner, mapSideCombine=false) - bufs.mapValues(_.toIterable) + val createCombiner = (v: V) => CompactBuffer(v) + val mergeValue = (buf: CompactBuffer[V], v: V) => buf += v + val mergeCombiners = (c1: CompactBuffer[V], c2: CompactBuffer[V]) => c1 ++= c2 + val bufs = combineByKey[CompactBuffer[V]]( + createCombiner, mergeValue, mergeCombiners, partitioner, mapSideCombine=false) + bufs.asInstanceOf[RDD[(K, Iterable[V])]] } /** * Group the values for each key in the RDD into a single sequence. Hash-partitions the * resulting RDD with into `numPartitions` partitions. * - * Note: If you are grouping in order to perform an aggregation (such as a sum or average) over - * each key, using [[PairRDDFunctions.reduceByKey]] or [[PairRDDFunctions.combineByKey]] - * will provide much better performance. + * Note: This operation may be very expensive. If you are grouping in order to perform an + * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] + * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. */ def groupByKey(numPartitions: Int): RDD[(K, Iterable[V])] = { groupByKey(new HashPartitioner(numPartitions)) @@ -434,9 +434,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * (k, v2) is in `other`. Uses the given Partitioner to partition the output RDD. */ def join[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, W))] = { - this.cogroup(other, partitioner).flatMapValues { case (vs, ws) => - for (v <- vs; w <- ws) yield (v, w) - } + this.cogroup(other, partitioner).flatMapValues( pair => + for (v <- pair._1; w <- pair._2) yield (v, w) + ) } /** @@ -446,11 +446,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * partition the output RDD. */ def leftOuterJoin[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, Option[W]))] = { - this.cogroup(other, partitioner).flatMapValues { case (vs, ws) => - if (ws.isEmpty) { - vs.map(v => (v, None)) + this.cogroup(other, partitioner).flatMapValues { pair => + if (pair._2.isEmpty) { + pair._1.map(v => (v, None)) } else { - for (v <- vs; w <- ws) yield (v, Some(w)) + for (v <- pair._1; w <- pair._2) yield (v, Some(w)) } } } @@ -463,11 +463,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) */ def rightOuterJoin[W](other: RDD[(K, W)], partitioner: Partitioner) : RDD[(K, (Option[V], W))] = { - this.cogroup(other, partitioner).flatMapValues { case (vs, ws) => - if (vs.isEmpty) { - ws.map(w => (None, w)) + this.cogroup(other, partitioner).flatMapValues { pair => + if (pair._1.isEmpty) { + pair._2.map(w => (None, w)) } else { - for (v <- vs; w <- ws) yield (Some(v), w) + for (v <- pair._1; w <- pair._2) yield (Some(v), w) } } } @@ -495,9 +495,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Group the values for each key in the RDD into a single sequence. Hash-partitions the * resulting RDD with the existing partitioner/parallelism level. * - * Note: If you are grouping in order to perform an aggregation (such as a sum or average) over - * each key, using [[PairRDDFunctions.reduceByKey]] or [[PairRDDFunctions.combineByKey]] - * will provide much better performance, + * Note: This operation may be very expensive. 
If you are grouping in order to perform an + * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] + * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. */ def groupByKey(): RDD[(K, Iterable[V])] = { groupByKey(defaultPartitioner(self)) @@ -571,7 +571,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val data = self.collect() val map = new mutable.HashMap[K, V] map.sizeHint(data.length) - data.foreach { case (k, v) => map.put(k, v) } + data.foreach { pair => map.put(pair._1, pair._2) } map } @@ -607,11 +607,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) throw new SparkException("Default partitioner cannot partition array keys.") } val cg = new CoGroupedRDD[K](Seq(self, other1, other2, other3), partitioner) - cg.mapValues { case Seq(vs, w1s, w2s, w3s) => - (vs.asInstanceOf[Seq[V]], - w1s.asInstanceOf[Seq[W1]], - w2s.asInstanceOf[Seq[W2]], - w3s.asInstanceOf[Seq[W3]]) + cg.mapValues { case Array(vs, w1s, w2s, w3s) => + (vs.asInstanceOf[Iterable[V]], + w1s.asInstanceOf[Iterable[W1]], + w2s.asInstanceOf[Iterable[W2]], + w3s.asInstanceOf[Iterable[W3]]) } } @@ -625,8 +625,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) throw new SparkException("Default partitioner cannot partition array keys.") } val cg = new CoGroupedRDD[K](Seq(self, other), partitioner) - cg.mapValues { case Seq(vs, ws) => - (vs.asInstanceOf[Seq[V]], ws.asInstanceOf[Seq[W]]) + cg.mapValues { case Array(vs, w1s) => + (vs.asInstanceOf[Iterable[V]], w1s.asInstanceOf[Iterable[W]]) } } @@ -640,10 +640,10 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) throw new SparkException("Default partitioner cannot partition array keys.") } val cg = new CoGroupedRDD[K](Seq(self, other1, other2), partitioner) - cg.mapValues { case Seq(vs, w1s, w2s) => - (vs.asInstanceOf[Seq[V]], - w1s.asInstanceOf[Seq[W1]], - w2s.asInstanceOf[Seq[W2]]) + cg.mapValues { case Array(vs, w1s, w2s) => + (vs.asInstanceOf[Iterable[V]], + w1s.asInstanceOf[Iterable[W1]], + w2s.asInstanceOf[Iterable[W2]]) } } @@ -746,14 +746,14 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) self.partitioner match { case Some(p) => val index = p.getPartition(key) - def process(it: Iterator[(K, V)]): Seq[V] = { + val process = (it: Iterator[(K, V)]) => { val buf = new ArrayBuffer[V] - for ((k, v) <- it if k == key) { - buf += v + for (pair <- it if pair._1 == key) { + buf += pair._2 } buf - } - val res = self.context.runJob(self, process _, Array(index), false) + } : Seq[V] + val res = self.context.runJob(self, process, Array(index), false) res(0) case None => self.filter(_._1 == key).map(_._2).collect() @@ -876,7 +876,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) jobFormat.checkOutputSpecs(job) } - def writeShard(context: TaskContext, iter: Iterator[(K,V)]): Int = { + val writeShard = (context: TaskContext, iter: Iterator[(K,V)]) => { // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it // around by taking a mod. We expect that no task will be attempted 2 billion times. 
val attemptNumber = (context.attemptId % Int.MaxValue).toInt @@ -894,22 +894,21 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]] try { while (iter.hasNext) { - val (k, v) = iter.next() - writer.write(k, v) + val pair = iter.next() + writer.write(pair._1, pair._2) } - } - finally { + } finally { writer.close(hadoopContext) } committer.commitTask(hadoopContext) - return 1 - } + 1 + } : Int val jobAttemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = true, 0, 0) val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId) val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext) jobCommitter.setupJob(jobTaskContext) - self.context.runJob(self, writeShard _) + self.context.runJob(self, writeShard) jobCommitter.commitJob(jobTaskContext) } @@ -948,7 +947,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val writer = new SparkHadoopWriter(hadoopConf) writer.preSetup() - def writeToFile(context: TaskContext, iter: Iterator[(K, V)]) { + val writeToFile = (context: TaskContext, iter: Iterator[(K, V)]) => { // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it // around by taking a mod. We expect that no task will be attempted 2 billion times. val attemptNumber = (context.attemptId % Int.MaxValue).toInt @@ -957,19 +956,18 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) writer.open() try { var count = 0 - while(iter.hasNext) { + while (iter.hasNext) { val record = iter.next() count += 1 writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef]) } - } - finally { + } finally { writer.close() } writer.commit() } - self.context.runJob(self, writeToFile _) + self.context.runJob(self, writeToFile) writer.commitJob() } diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala index b5b8a5706deb3..a637d6f15b7e5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala @@ -39,6 +39,7 @@ class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long) * * @param prev RDD to be sampled * @param sampler a random sampler + * @param preservesPartitioning whether the sampler preserves the partitioner of the parent RDD * @param seed random seed * @tparam T input RDD item type * @tparam U sampled RDD item type @@ -46,9 +47,12 @@ class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long) private[spark] class PartitionwiseSampledRDD[T: ClassTag, U: ClassTag]( prev: RDD[T], sampler: RandomSampler[T, U], + @transient preservesPartitioning: Boolean, @transient seed: Long = Utils.random.nextLong) extends RDD[U](prev) { + @transient override val partitioner = if (preservesPartitioning) prev.partitioner else None + override def getPartitions: Array[Partition] = { val random = new Random(seed) firstParent[T].partitions.map(x => new PartitionwiseSampledRDDPartition(x, random.nextLong())) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 4e841bc992bff..a6abc49c5359e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -328,7 +328,7 @@ abstract class RDD[T: ClassTag]( : RDD[T] = { if (shuffle) { /** Distributes elements evenly across output partitions, starting from a random partition. 
*/ - def distributePartition(index: Int, items: Iterator[T]): Iterator[(Int, T)] = { + val distributePartition = (index: Int, items: Iterator[T]) => { var position = (new Random(index)).nextInt(numPartitions) items.map { t => // Note that the hash code of the key will just be the key itself. The HashPartitioner @@ -336,7 +336,7 @@ abstract class RDD[T: ClassTag]( position = position + 1 (position, t) } - } + } : Iterator[(Int, T)] // include a shuffle step so that our upstream tasks are still distributed new CoalescedRDD( @@ -354,11 +354,11 @@ abstract class RDD[T: ClassTag]( def sample(withReplacement: Boolean, fraction: Double, seed: Long = Utils.random.nextLong): RDD[T] = { - require(fraction >= 0.0, "Invalid fraction value: " + fraction) + require(fraction >= 0.0, "Negative fraction value: " + fraction) if (withReplacement) { - new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), seed) + new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), true, seed) } else { - new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](fraction), seed) + new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](fraction), true, seed) } } @@ -374,7 +374,7 @@ abstract class RDD[T: ClassTag]( val sum = weights.sum val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) normalizedCumWeights.sliding(2).map { x => - new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](x(0), x(1)), seed) + new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](x(0), x(1)), true, seed) }.toArray } @@ -509,6 +509,10 @@ abstract class RDD[T: ClassTag]( /** * Return an RDD of grouped items. Each group consists of a key and a sequence of elements * mapping to that key. + * + * Note: This operation may be very expensive. If you are grouping in order to perform an + * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] + * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. */ def groupBy[K](f: T => K)(implicit kt: ClassTag[K]): RDD[(K, Iterable[T])] = groupBy[K](f, defaultPartitioner(this)) @@ -516,6 +520,10 @@ abstract class RDD[T: ClassTag]( /** * Return an RDD of grouped elements. Each group consists of a key and a sequence of elements * mapping to that key. + * + * Note: This operation may be very expensive. If you are grouping in order to perform an + * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] + * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. */ def groupBy[K](f: T => K, numPartitions: Int)(implicit kt: ClassTag[K]): RDD[(K, Iterable[T])] = groupBy(f, new HashPartitioner(numPartitions)) @@ -523,6 +531,10 @@ abstract class RDD[T: ClassTag]( /** * Return an RDD of grouped items. Each group consists of a key and a sequence of elements * mapping to that key. + * + * Note: This operation may be very expensive. If you are grouping in order to perform an + * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] + * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. */ def groupBy[K](f: T => K, p: Partitioner)(implicit kt: ClassTag[K], ord: Ordering[K] = null) : RDD[(K, Iterable[T])] = { @@ -574,6 +586,9 @@ abstract class RDD[T: ClassTag]( /** * Return a new RDD by applying a function to each partition of this RDD. 
+ * + * `preservesPartitioning` indicates whether the input function preserves the partitioner, which + * should be `false` unless this is a pair RDD and the input function doesn't modify the keys. */ def mapPartitions[U: ClassTag]( f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = { @@ -584,6 +599,9 @@ abstract class RDD[T: ClassTag]( /** * Return a new RDD by applying a function to each partition of this RDD, while tracking the index * of the original partition. + * + * `preservesPartitioning` indicates whether the input function preserves the partitioner, which + * should be `false` unless this is a pair RDD and the input function doesn't modify the keys. */ def mapPartitionsWithIndex[U: ClassTag]( f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = { @@ -595,6 +613,9 @@ abstract class RDD[T: ClassTag]( * :: DeveloperApi :: * Return a new RDD by applying a function to each partition of this RDD. This is a variant of * mapPartitions that also passes the TaskContext into the closure. + * + * `preservesPartitioning` indicates whether the input function preserves the partitioner, which + * should be `false` unless this is a pair RDD and the input function doesn't modify the keys. */ @DeveloperApi def mapPartitionsWithContext[U: ClassTag]( @@ -677,7 +698,7 @@ abstract class RDD[T: ClassTag]( * a map on the other). */ def zip[U: ClassTag](other: RDD[U]): RDD[(T, U)] = { - zipPartitions(other, true) { (thisIter, otherIter) => + zipPartitions(other, preservesPartitioning = false) { (thisIter, otherIter) => new Iterator[(T, U)] { def hasNext = (thisIter.hasNext, otherIter.hasNext) match { case (true, true) => true @@ -733,14 +754,16 @@ abstract class RDD[T: ClassTag]( * Applies a function f to all elements of this RDD. */ def foreach(f: T => Unit) { - sc.runJob(this, (iter: Iterator[T]) => iter.foreach(f)) + val cleanF = sc.clean(f) + sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF)) } /** * Applies a function f to each partition of this RDD. */ def foreachPartition(f: Iterator[T] => Unit) { - sc.runJob(this, (iter: Iterator[T]) => f(iter)) + val cleanF = sc.clean(f) + sc.runJob(this, (iter: Iterator[T]) => cleanF(iter)) } /** @@ -907,19 +930,19 @@ abstract class RDD[T: ClassTag]( throw new SparkException("countByValue() does not support arrays") } // TODO: This should perhaps be distributed by default. - def countPartition(iter: Iterator[T]): Iterator[OpenHashMap[T,Long]] = { + val countPartition = (iter: Iterator[T]) => { val map = new OpenHashMap[T,Long] iter.foreach { t => map.changeValue(t, 1L, _ + 1L) } Iterator(map) - } - def mergeMaps(m1: OpenHashMap[T,Long], m2: OpenHashMap[T,Long]): OpenHashMap[T,Long] = { + }: Iterator[OpenHashMap[T,Long]] + val mergeMaps = (m1: OpenHashMap[T,Long], m2: OpenHashMap[T,Long]) => { m2.foreach { case (key, value) => m1.changeValue(key, value, _ + value) } m1 - } + }: OpenHashMap[T,Long] val myResult = mapPartitions(countPartition).reduce(mergeMaps) // Convert to a Scala mutable map val mutableResult = scala.collection.mutable.Map[T,Long]() @@ -1202,7 +1225,7 @@ abstract class RDD[T: ClassTag]( /** User code that created this RDD (e.g. `textFile`, `parallelize`). 
*/ @transient private[spark] val creationSite = Utils.getCallSite - private[spark] def getCreationSite: String = Option(creationSite).map(_.short).getOrElse("") + private[spark] def getCreationSite: String = Option(creationSite).map(_.shortForm).getOrElse("") private[spark] def elementClassTag: ClassTag[T] = classTag[T] @@ -1257,11 +1280,55 @@ abstract class RDD[T: ClassTag]( /** A description of this RDD and its recursive dependencies for debugging. */ def toDebugString: String = { - def debugString(rdd: RDD[_], prefix: String = ""): Seq[String] = { - Seq(prefix + rdd + " (" + rdd.partitions.size + " partitions)") ++ - rdd.dependencies.flatMap(d => debugString(d.rdd, prefix + " ")) + // Apply a different rule to the last child + def debugChildren(rdd: RDD[_], prefix: String): Seq[String] = { + val len = rdd.dependencies.length + len match { + case 0 => Seq.empty + case 1 => + val d = rdd.dependencies.head + debugString(d.rdd, prefix, d.isInstanceOf[ShuffleDependency[_,_,_]], true) + case _ => + val frontDeps = rdd.dependencies.take(len - 1) + val frontDepStrings = frontDeps.flatMap( + d => debugString(d.rdd, prefix, d.isInstanceOf[ShuffleDependency[_,_,_]])) + + val lastDep = rdd.dependencies.last + val lastDepStrings = + debugString(lastDep.rdd, prefix, lastDep.isInstanceOf[ShuffleDependency[_,_,_]], true) + + (frontDepStrings ++ lastDepStrings) + } + } + // The first RDD in the dependency stack has no parents, so no need for a +- + def firstDebugString(rdd: RDD[_]): Seq[String] = { + val partitionStr = "(" + rdd.partitions.size + ")" + val leftOffset = (partitionStr.length - 1) / 2 + val nextPrefix = (" " * leftOffset) + "|" + (" " * (partitionStr.length - leftOffset)) + Seq(partitionStr + " " + rdd) ++ debugChildren(rdd, nextPrefix) + } + def shuffleDebugString(rdd: RDD[_], prefix: String = "", isLastChild: Boolean): Seq[String] = { + val partitionStr = "(" + rdd.partitions.size + ")" + val leftOffset = (partitionStr.length - 1) / 2 + val thisPrefix = prefix.replaceAll("\\|\\s+$", "") + val nextPrefix = ( + thisPrefix + + (if (isLastChild) " " else "| ") + + (" " * leftOffset) + "|" + (" " * (partitionStr.length - leftOffset))) + Seq(thisPrefix + "+-" + partitionStr + " " + rdd) ++ debugChildren(rdd, nextPrefix) + } + def debugString(rdd: RDD[_], + prefix: String = "", + isShuffle: Boolean = true, + isLastChild: Boolean = false): Seq[String] = { + if (isShuffle) { + shuffleDebugString(rdd, prefix, isLastChild) + } + else { + Seq(prefix + rdd) ++ debugChildren(rdd, prefix) + } } - debugString(this).mkString("\n") + firstDebugString(this).mkString("\n") } override def toString: String = "%s%s[%d] at %s".format( diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index bf02f68d0d3d3..da4a8c3dc22b1 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -21,6 +21,7 @@ import scala.reflect.ClassTag import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.SortOrder.SortOrder import org.apache.spark.serializer.Serializer private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition { @@ -51,6 +52,8 @@ class ShuffledRDD[K, V, C, P <: Product2[K, C] : ClassTag]( private var mapSideCombine: Boolean = false + private var sortOrder: Option[SortOrder] = None + /** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */ def 
setSerializer(serializer: Serializer): ShuffledRDD[K, V, C, P] = { this.serializer = Option(serializer) @@ -75,8 +78,15 @@ class ShuffledRDD[K, V, C, P <: Product2[K, C] : ClassTag]( this } + /** Set sort order for RDD's sorting. */ + def setSortOrder(sortOrder: SortOrder): ShuffledRDD[K, V, C, P] = { + this.sortOrder = Option(sortOrder) + this + } + override def getDependencies: Seq[Dependency[_]] = { - List(new ShuffleDependency(prev, part, serializer, keyOrdering, aggregator, mapSideCombine)) + List(new ShuffleDependency(prev, part, serializer, + keyOrdering, aggregator, mapSideCombine, sortOrder)) } override val partitioner = Some(part) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index f72bfde572c96..dc6142ab79d03 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -85,12 +85,9 @@ class DAGScheduler( private val nextStageId = new AtomicInteger(0) private[scheduler] val jobIdToStageIds = new HashMap[Int, HashSet[Int]] - private[scheduler] val stageIdToJobIds = new HashMap[Int, HashSet[Int]] private[scheduler] val stageIdToStage = new HashMap[Int, Stage] private[scheduler] val shuffleToMapStage = new HashMap[Int, Stage] private[scheduler] val jobIdToActiveJob = new HashMap[Int, ActiveJob] - private[scheduler] val resultStageToJob = new HashMap[Stage, ActiveJob] - private[scheduler] val stageToInfos = new HashMap[Stage, StageInfo] // Stages we need to run whose parents aren't done private[scheduler] val waitingStages = new HashSet[Stage] @@ -101,9 +98,6 @@ class DAGScheduler( // Stages that must be resubmitted due to fetch failures private[scheduler] val failedStages = new HashSet[Stage] - // Missing tasks from each stage - private[scheduler] val pendingTasks = new HashMap[Stage, HashSet[Task[_]]] - private[scheduler] val activeJobs = new HashSet[ActiveJob] // Contains the locations that each RDD's partitions are cached on @@ -223,7 +217,6 @@ class DAGScheduler( new Stage(id, rdd, numTasks, shuffleDep, getParentStages(rdd, jobId), jobId, callSite) stageIdToStage(id) = stage updateJobIdStageIdMaps(jobId, stage) - stageToInfos(stage) = StageInfo.fromStage(stage) stage } @@ -315,13 +308,12 @@ class DAGScheduler( */ private def updateJobIdStageIdMaps(jobId: Int, stage: Stage) { def updateJobIdStageIdMapsList(stages: List[Stage]) { - if (!stages.isEmpty) { + if (stages.nonEmpty) { val s = stages.head - stageIdToJobIds.getOrElseUpdate(s.id, new HashSet[Int]()) += jobId + s.jobIds += jobId jobIdToStageIds.getOrElseUpdate(jobId, new HashSet[Int]()) += s.id - val parents = getParentStages(s.rdd, jobId) - val parentsWithoutThisJobId = parents.filter(p => - !stageIdToJobIds.get(p.id).exists(_.contains(jobId))) + val parents: List[Stage] = getParentStages(s.rdd, jobId) + val parentsWithoutThisJobId = parents.filter { ! _.jobIds.contains(jobId) } updateJobIdStageIdMapsList(parentsWithoutThisJobId ++ stages.tail) } } @@ -333,16 +325,15 @@ class DAGScheduler( * handle cancelling tasks or notifying the SparkListener about finished jobs/stages/tasks. * * @param job The job whose state to cleanup. - * @param resultStage Specifies the result stage for the job; if set to None, this method - * searches resultStagesToJob to find and cleanup the appropriate result stage. 
*/ - private def cleanupStateForJobAndIndependentStages(job: ActiveJob, resultStage: Option[Stage]) { + private def cleanupStateForJobAndIndependentStages(job: ActiveJob) { val registeredStages = jobIdToStageIds.get(job.jobId) if (registeredStages.isEmpty || registeredStages.get.isEmpty) { logError("No stages registered for job " + job.jobId) } else { - stageIdToJobIds.filterKeys(stageId => registeredStages.get.contains(stageId)).foreach { - case (stageId, jobSet) => + stageIdToStage.filterKeys(stageId => registeredStages.get.contains(stageId)).foreach { + case (stageId, stage) => + val jobSet = stage.jobIds if (!jobSet.contains(job.jobId)) { logError( "Job %d not registered for stage %d even though that stage was registered for the job" @@ -355,14 +346,9 @@ class DAGScheduler( logDebug("Removing running stage %d".format(stageId)) runningStages -= stage } - stageToInfos -= stage for ((k, v) <- shuffleToMapStage.find(_._2 == stage)) { shuffleToMapStage.remove(k) } - if (pendingTasks.contains(stage) && !pendingTasks(stage).isEmpty) { - logDebug("Removing pending status for stage %d".format(stageId)) - } - pendingTasks -= stage if (waitingStages.contains(stage)) { logDebug("Removing stage %d from waiting set.".format(stageId)) waitingStages -= stage @@ -374,7 +360,6 @@ class DAGScheduler( } // data structures based on StageId stageIdToStage -= stageId - stageIdToJobIds -= stageId ShuffleMapTask.removeStage(stageId) ResultTask.removeStage(stageId) @@ -393,19 +378,7 @@ class DAGScheduler( jobIdToStageIds -= job.jobId jobIdToActiveJob -= job.jobId activeJobs -= job - - if (resultStage.isEmpty) { - // Clean up result stages. - val resultStagesForJob = resultStageToJob.keySet.filter( - stage => resultStageToJob(stage).jobId == job.jobId) - if (resultStagesForJob.size != 1) { - logWarning( - s"${resultStagesForJob.size} result stages for job ${job.jobId} (expect exactly 1)") - } - resultStageToJob --= resultStagesForJob - } else { - resultStageToJob -= resultStage.get - } + job.finalStage.resultOfJob = None } /** @@ -455,7 +428,7 @@ class DAGScheduler( waiter.awaitResult() match { case JobSucceeded => {} case JobFailed(exception: Exception) => - logInfo("Failed to run " + callSite.short) + logInfo("Failed to run " + callSite.shortForm) throw exception } } @@ -591,9 +564,10 @@ class DAGScheduler( job.listener.jobFailed(exception) } finally { val s = job.finalStage - stageIdToJobIds -= s.id // clean up data structures that were populated for a local job, - stageIdToStage -= s.id // but that won't get cleaned up via the normal paths through - stageToInfos -= s // completion events or stage abort + // clean up data structures that were populated for a local job, + // but that won't get cleaned up via the normal paths through + // completion events or stage abort + stageIdToStage -= s.id jobIdToStageIds -= job.jobId listenerBus.post(SparkListenerJobEnd(job.jobId, jobResult)) } @@ -605,12 +579,8 @@ class DAGScheduler( // That should take care of at least part of the priority inversion problem with // cross-job dependencies. 
private def activeJobForStage(stage: Stage): Option[Int] = { - if (stageIdToJobIds.contains(stage.id)) { - val jobsThatUseStage: Array[Int] = stageIdToJobIds(stage.id).toArray.sorted - jobsThatUseStage.find(jobIdToActiveJob.contains) - } else { - None - } + val jobsThatUseStage: Array[Int] = stage.jobIds.toArray.sorted + jobsThatUseStage.find(jobIdToActiveJob.contains) } private[scheduler] def handleJobGroupCancelled(groupId: String) { @@ -642,9 +612,8 @@ class DAGScheduler( // is in the process of getting stopped. val stageFailedMessage = "Stage cancelled because SparkContext was shut down" runningStages.foreach { stage => - val info = stageToInfos(stage) - info.stageFailed(stageFailedMessage) - listenerBus.post(SparkListenerStageCompleted(info)) + stage.info.stageFailed(stageFailedMessage) + listenerBus.post(SparkListenerStageCompleted(stage.info)) } listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error))) } @@ -679,7 +648,7 @@ class DAGScheduler( val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties) clearCacheLocs() logInfo("Got job %s (%s) with %d output partitions (allowLocal=%s)".format( - job.jobId, callSite.short, partitions.length, allowLocal)) + job.jobId, callSite.shortForm, partitions.length, allowLocal)) logInfo("Final stage: " + finalStage + "(" + finalStage.name + ")") logInfo("Parents of final stage: " + finalStage.parents) logInfo("Missing parents: " + getMissingParentStages(finalStage)) @@ -690,7 +659,7 @@ class DAGScheduler( } else { jobIdToActiveJob(jobId) = job activeJobs += job - resultStageToJob(finalStage) = job + finalStage.resultOfJob = Some(job) listenerBus.post(SparkListenerJobStart(job.jobId, jobIdToStageIds(jobId).toArray, properties)) submitStage(finalStage) @@ -710,7 +679,6 @@ class DAGScheduler( if (missing == Nil) { logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents") submitMissingTasks(stage, jobId.get) - runningStages += stage } else { for (parent <- missing) { submitStage(parent) @@ -728,8 +696,7 @@ class DAGScheduler( private def submitMissingTasks(stage: Stage, jobId: Int) { logDebug("submitMissingTasks(" + stage + ")") // Get our pending tasks and remember them in our pendingTasks entry - val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet) - myPending.clear() + stage.pendingTasks.clear() var tasks = ArrayBuffer[Task[_]]() if (stage.isShuffleMap) { for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) { @@ -738,7 +705,7 @@ class DAGScheduler( } } else { // This is a final stage; figure out its job's missing partitions - val job = resultStageToJob(stage) + val job = stage.resultOfJob.get for (id <- 0 until job.numPartitions if !job.finished(id)) { val partition = job.partitions(id) val locs = getPreferredLocs(stage.rdd, partition) @@ -753,11 +720,14 @@ class DAGScheduler( null } - // must be run listener before possible NotSerializableException - // should be "StageSubmitted" first and then "JobEnded" - listenerBus.post(SparkListenerStageSubmitted(stageToInfos(stage), properties)) - if (tasks.size > 0) { + runningStages += stage + // SparkListenerStageSubmitted should be posted before testing whether tasks are + // serializable. If tasks are not serializable, a SparkListenerStageCompleted event + // will be posted, which should always come after a corresponding SparkListenerStageSubmitted + // event. + listenerBus.post(SparkListenerStageSubmitted(stage.info, properties)) + // Preemptively serialize a task to make sure it can be serialized. 
We are catching this // exception here because it would be fairly hard to catch the non-serializable exception // down the road, where we have several different implementations for local scheduler and @@ -776,11 +746,11 @@ class DAGScheduler( } logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")") - myPending ++= tasks - logDebug("New pending tasks: " + myPending) + stage.pendingTasks ++= tasks + logDebug("New pending tasks: " + stage.pendingTasks) taskScheduler.submitTasks( new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties)) - stageToInfos(stage).submissionTime = Some(clock.getTime()) + stage.info.submissionTime = Some(clock.getTime()) } else { logDebug("Stage " + stage + " is actually done; %b %d %d".format( stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions)) @@ -805,26 +775,25 @@ class DAGScheduler( val stage = stageIdToStage(task.stageId) def markStageAsFinished(stage: Stage) = { - val serviceTime = stageToInfos(stage).submissionTime match { + val serviceTime = stage.info.submissionTime match { case Some(t) => "%.03f".format((clock.getTime() - t) / 1000.0) case _ => "Unknown" } logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime)) - stageToInfos(stage).completionTime = Some(clock.getTime()) - listenerBus.post(SparkListenerStageCompleted(stageToInfos(stage))) + stage.info.completionTime = Some(clock.getTime()) + listenerBus.post(SparkListenerStageCompleted(stage.info)) runningStages -= stage } event.reason match { case Success => - logInfo("Completed " + task) if (event.accumUpdates != null) { // TODO: fail the stage if the accumulator update fails... Accumulators.add(event.accumUpdates) // TODO: do this only if task wasn't resubmitted } - pendingTasks(stage) -= task + stage.pendingTasks -= task task match { case rt: ResultTask[_, _] => - resultStageToJob.get(stage) match { + stage.resultOfJob match { case Some(job) => if (!job.finished(rt.outputId)) { job.finished(rt.outputId) = true @@ -832,7 +801,7 @@ class DAGScheduler( // If the whole job has finished, remove it if (job.numFinished == job.numPartitions) { markStageAsFinished(stage) - cleanupStateForJobAndIndependentStages(job, Some(stage)) + cleanupStateForJobAndIndependentStages(job) listenerBus.post(SparkListenerJobEnd(job.jobId, JobSucceeded)) } @@ -859,7 +828,7 @@ class DAGScheduler( } else { stage.addOutputLoc(smt.partitionId, status) } - if (runningStages.contains(stage) && pendingTasks(stage).isEmpty) { + if (runningStages.contains(stage) && stage.pendingTasks.isEmpty) { markStageAsFinished(stage) logInfo("looking for newly runnable stages") logInfo("running: " + runningStages) @@ -908,7 +877,7 @@ class DAGScheduler( case Resubmitted => logInfo("Resubmitted " + task + ", so marking it as still running") - pendingTasks(stage) += task + stage.pendingTasks += task case FetchFailed(bmAddress, shuffleId, mapId, reduceId) => // Mark the stage that the reducer was in as unrunnable @@ -993,13 +962,14 @@ class DAGScheduler( } private[scheduler] def handleStageCancellation(stageId: Int) { - if (stageIdToJobIds.contains(stageId)) { - val jobsThatUseStage: Array[Int] = stageIdToJobIds(stageId).toArray - jobsThatUseStage.foreach(jobId => { - handleJobCancellation(jobId, "because Stage %s was cancelled".format(stageId)) - }) - } else { - logInfo("No active jobs to kill for Stage " + stageId) + stageIdToStage.get(stageId) match { + case Some(stage) => + val jobsThatUseStage: Array[Int] = stage.jobIds.toArray + 
jobsThatUseStage.foreach { jobId => + handleJobCancellation(jobId, s"because Stage $stageId was cancelled") + } + case None => + logInfo("No active jobs to kill for Stage " + stageId) } submitWaitingStages() } @@ -1008,8 +978,8 @@ class DAGScheduler( if (!jobIdToStageIds.contains(jobId)) { logDebug("Trying to cancel unregistered job " + jobId) } else { - failJobAndIndependentStages(jobIdToActiveJob(jobId), - "Job %d cancelled %s".format(jobId, reason), None) + failJobAndIndependentStages( + jobIdToActiveJob(jobId), "Job %d cancelled %s".format(jobId, reason)) } submitWaitingStages() } @@ -1023,26 +993,21 @@ class DAGScheduler( // Skip all the actions if the stage has been removed. return } - val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq - stageToInfos(failedStage).completionTime = Some(clock.getTime()) - for (resultStage <- dependentStages) { - val job = resultStageToJob(resultStage) - failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason", - Some(resultStage)) + val dependentJobs: Seq[ActiveJob] = + activeJobs.filter(job => stageDependsOn(job.finalStage, failedStage)).toSeq + failedStage.info.completionTime = Some(clock.getTime()) + for (job <- dependentJobs) { + failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason") } - if (dependentStages.isEmpty) { + if (dependentJobs.isEmpty) { logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done") } } /** * Fails a job and all stages that are only used by that job, and cleans up relevant state. - * - * @param resultStage The result stage for the job, if known. Used to cleanup state for the job - * slightly more efficiently than when not specified. */ - private def failJobAndIndependentStages(job: ActiveJob, failureReason: String, - resultStage: Option[Stage]) { + private def failJobAndIndependentStages(job: ActiveJob, failureReason: String) { val error = new SparkException(failureReason) var ableToCancelStages = true @@ -1056,7 +1021,7 @@ class DAGScheduler( logError("No stages registered for job " + job.jobId) } stages.foreach { stageId => - val jobsForStage = stageIdToJobIds.get(stageId) + val jobsForStage: Option[HashSet[Int]] = stageIdToStage.get(stageId).map(_.jobIds) if (jobsForStage.isEmpty || !jobsForStage.get.contains(job.jobId)) { logError( "Job %d not registered for stage %d even though that stage was registered for the job" @@ -1070,9 +1035,8 @@ class DAGScheduler( if (runningStages.contains(stage)) { try { // cancelTasks will fail if a SchedulerBackend does not implement killTask taskScheduler.cancelTasks(stageId, shouldInterruptThread) - val stageInfo = stageToInfos(stage) - stageInfo.stageFailed(failureReason) - listenerBus.post(SparkListenerStageCompleted(stageToInfos(stage))) + stage.info.stageFailed(failureReason) + listenerBus.post(SparkListenerStageCompleted(stage.info)) } catch { case e: UnsupportedOperationException => logInfo(s"Could not cancel tasks for stage $stageId", e) @@ -1085,7 +1049,7 @@ class DAGScheduler( if (ableToCancelStages) { job.listener.jobFailed(error) - cleanupStateForJobAndIndependentStages(job, resultStage) + cleanupStateForJobAndIndependentStages(job) listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error))) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index a90b0d475c04e..ae6ca9f4e7bf5 100644 --- 
a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -63,6 +63,13 @@ private[spark] class EventLoggingListener( // For testing. Keep track of all JSON serialized events that have been logged. private[scheduler] val loggedEvents = new ArrayBuffer[JValue] + /** + * Return the unique application directory, without the base directory. + */ + def getApplicationLogDir(): String = { + name + } + /** * Begin logging events. * If compression is used, log a file that indicates which compression library is used. diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 8ec482a6f6d9c..800905413d145 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -17,6 +17,8 @@ package org.apache.spark.scheduler +import scala.collection.mutable.HashSet + import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.BlockManagerId @@ -56,8 +58,22 @@ private[spark] class Stage( val numPartitions = rdd.partitions.size val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil) var numAvailableOutputs = 0 + + /** Set of jobs that this stage belongs to. */ + val jobIds = new HashSet[Int] + + /** For the final stage of a job (consisting only of ResultTasks), the link to the ActiveJob. */ + var resultOfJob: Option[ActiveJob] = None + var pendingTasks = new HashSet[Task[_]] + private var nextAttemptId = 0 + val name = callSite.shortForm + val details = callSite.longForm + + /** Pointer to the [StageInfo] object, set by DAGScheduler. */ + var info: StageInfo = StageInfo.fromStage(this) + def isAvailable: Boolean = { if (!isShuffleMap) { true @@ -108,9 +124,6 @@ private[spark] class Stage( def attemptId: Int = nextAttemptId - val name = callSite.short - val details = callSite.long - override def toString = "Stage " + id override def hashCode(): Int = id diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala index 29de0453ac19a..ca0595f35143e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala @@ -84,6 +84,8 @@ class TaskInfo( } } + def id: String = s"$index.$attempt" + def duration: Long = { if (!finished) { throw new UnsupportedOperationException("duration() called on unfinished task") diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 4b6d6da5a6e61..be3673c48eda8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -88,6 +88,8 @@ private[spark] class TaskSchedulerImpl( // in turn is used to decide when we can attain data locality on a given host private val executorsByHost = new HashMap[String, HashSet[String]] + protected val hostsByRack = new HashMap[String, HashSet[String]] + private val executorIdToHost = new HashMap[String, String] // Listener object to pass upcalls into @@ -223,6 +225,9 @@ private[spark] class TaskSchedulerImpl( executorAdded(o.executorId, o.host) newExecAvail = true } + for (rack <- getRackForHost(o.host)) { + hostsByRack.getOrElseUpdate(rack, new HashSet[String]()) += o.host + } } // Randomly shuffle offers 
to avoid always placing tasks on the same set of workers. @@ -418,6 +423,12 @@ private[spark] class TaskSchedulerImpl( execs -= executorId if (execs.isEmpty) { executorsByHost -= host + for (rack <- getRackForHost(host); hosts <- hostsByRack.get(rack)) { + hosts -= host + if (hosts.isEmpty) { + hostsByRack -= rack + } + } } executorIdToHost -= executorId rootPool.executorLost(executorId, host) @@ -435,6 +446,10 @@ private[spark] class TaskSchedulerImpl( executorsByHost.contains(host) } + def hasHostAliveOnRack(rack: String): Boolean = synchronized { + hostsByRack.contains(rack) + } + def isExecutorAlive(execId: String): Boolean = synchronized { activeExecutorIds.contains(execId) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 059cc9085a2e7..8b5e8cb802a45 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -26,8 +26,7 @@ import scala.collection.mutable.HashSet import scala.math.max import scala.math.min -import org.apache.spark.{ExceptionFailure, ExecutorLostFailure, FetchFailed, Logging, Resubmitted, - SparkEnv, Success, TaskEndReason, TaskKilled, TaskResultLost, TaskState} +import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.executor.TaskMetrics import org.apache.spark.util.{Clock, SystemClock} @@ -52,8 +51,8 @@ private[spark] class TaskSetManager( val taskSet: TaskSet, val maxTaskFailures: Int, clock: Clock = SystemClock) - extends Schedulable with Logging -{ + extends Schedulable with Logging { + val conf = sched.sc.conf /* @@ -191,7 +190,9 @@ private[spark] class TaskSetManager( addTo(pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer)) for (rack <- sched.getRackForHost(loc.host)) { addTo(pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer)) - hadAliveLocations = true + if(sched.hasHostAliveOnRack(rack)){ + hadAliveLocations = true + } } } @@ -401,14 +402,11 @@ private[spark] class TaskSetManager( // Found a task; do some bookkeeping and return a task description val task = tasks(index) val taskId = sched.newTaskId() - // Figure out whether this should count as a preferred launch - logInfo("Starting task %s:%d as TID %s on executor %s: %s (%s)".format( - taskSet.id, index, taskId, execId, host, taskLocality)) // Do various bookkeeping copiesRunning(index) += 1 val attemptNum = taskAttempts(index).size - val info = new TaskInfo( - taskId, index, attemptNum + 1, curTime, execId, host, taskLocality, speculative) + val info = new TaskInfo(taskId, index, attemptNum, curTime, + execId, host, taskLocality, speculative) taskInfos(taskId) = info taskAttempts(index) = info :: taskAttempts(index) // Update our locality level for delay scheduling @@ -427,11 +425,15 @@ private[spark] class TaskSetManager( s"(${serializedTask.limit / 1024} KB). The maximum recommended task size is " + s"${TaskSetManager.TASK_SIZE_TO_WARN_KB} KB.") } - val timeTaken = clock.getTime() - startTime addRunningTask(taskId) - logInfo("Serialized task %s:%d as %d bytes in %d ms".format( - taskSet.id, index, serializedTask.limit, timeTaken)) - val taskName = "task %s:%d".format(taskSet.id, index) + + // We used to log the time it takes to serialize the task, but task size is already + // a good proxy to task serialization time. 
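// Illustrative note (values made up): with TaskInfo.id defined as s"$index.$attempt", the
// start-of-task log line produced just below now reads roughly like
//   Starting task 3.0 in stage 1.0 (TID 42, somehost, PROCESS_LOCAL, 2048 bytes)
// i.e. tasks are identified as "<index>.<attempt> in stage <taskSet.id>" rather than by TID alone.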
+ // val timeTaken = clock.getTime() - startTime + val taskName = s"task ${info.id} in stage ${taskSet.id}" + logInfo("Starting %s (TID %d, %s, %s, %d bytes)".format( + taskName, taskId, host, taskLocality, serializedTask.limit)) + sched.dagScheduler.taskStarted(task, info) return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask)) } @@ -490,19 +492,19 @@ private[spark] class TaskSetManager( info.markSuccessful() removeRunningTask(tid) sched.dagScheduler.taskEnded( - tasks(index), Success, result.value, result.accumUpdates, info, result.metrics) + tasks(index), Success, result.value(), result.accumUpdates, info, result.metrics) if (!successful(index)) { tasksSuccessful += 1 - logInfo("Finished TID %s in %d ms on %s (progress: %d/%d)".format( - tid, info.duration, info.host, tasksSuccessful, numTasks)) + logInfo("Finished task %s in stage %s (TID %d) in %d ms on %s (%d/%d)".format( + info.id, taskSet.id, info.taskId, info.duration, info.host, tasksSuccessful, numTasks)) // Mark successful and stop if all the tasks have succeeded. successful(index) = true if (tasksSuccessful == numTasks) { isZombie = true } } else { - logInfo("Ignorning task-finished event for TID " + tid + " because task " + - index + " has already completed successfully") + logInfo("Ignoring task-finished event for " + info.id + " in stage " + taskSet.id + + " because task " + index + " has already completed successfully") } failedExecutors.remove(index) maybeFinishTaskSet() @@ -521,14 +523,13 @@ private[spark] class TaskSetManager( info.markFailed() val index = info.index copiesRunning(index) -= 1 - if (!isZombie) { - logWarning("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index)) - } var taskMetrics : TaskMetrics = null - var failureReason: String = null + + val failureReason = s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid, ${info.host}): " + + reason.asInstanceOf[TaskFailedReason].toErrorString reason match { case fetchFailed: FetchFailed => - logWarning("Loss was due to fetch failure from " + fetchFailed.bmAddress) + logWarning(failureReason) if (!successful(index)) { successful(index) = true tasksSuccessful += 1 @@ -536,23 +537,17 @@ private[spark] class TaskSetManager( // Not adding to failed executors for FetchFailed. isZombie = true - case TaskKilled => - // Not adding to failed executors for TaskKilled. - logWarning("Task %d was killed.".format(tid)) - case ef: ExceptionFailure => - taskMetrics = ef.metrics.getOrElse(null) - if (ef.className == classOf[NotSerializableException].getName()) { + taskMetrics = ef.metrics.orNull + if (ef.className == classOf[NotSerializableException].getName) { // If the task result wasn't serializable, there's no point in trying to re-execute it. 
- logError("Task %s:%s had a not serializable result: %s; not retrying".format( - taskSet.id, index, ef.description)) - abort("Task %s:%s had a not serializable result: %s".format( - taskSet.id, index, ef.description)) + logError("Task %s in stage %s (TID %d) had a not serializable result: %s; not retrying" + .format(info.id, taskSet.id, tid, ef.description)) + abort("Task %s in stage %s (TID %d) had a not serializable result: %s".format( + info.id, taskSet.id, tid, ef.description)) return } val key = ef.description - failureReason = "Exception failure in TID %s on host %s: %s\n%s".format( - tid, info.host, ef.description, ef.stackTrace.map(" " + _).mkString("\n")) val now = clock.getTime() val (printFull, dupCount) = { if (recentExceptions.contains(key)) { @@ -570,19 +565,18 @@ private[spark] class TaskSetManager( } } if (printFull) { - val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString)) - logWarning("Loss was due to %s\n%s\n%s".format( - ef.className, ef.description, locs.mkString("\n"))) + logWarning(failureReason) } else { - logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount)) + logInfo( + s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid) on executor ${info.host}: " + + s"${ef.className} (${ef.description}) [duplicate $dupCount]") } - case TaskResultLost => - failureReason = "Lost result for TID %s on host %s".format(tid, info.host) + case e: TaskFailedReason => // TaskResultLost, TaskKilled, and others logWarning(failureReason) - case _ => - failureReason = "TID %s on host %s failed for unknown reason".format(tid, info.host) + case e: TaskEndReason => + logError("Unknown TaskEndReason: " + e) } // always add to failed executors failedExecutors.getOrElseUpdate(index, new HashMap[String, Long]()). @@ -593,10 +587,10 @@ private[spark] class TaskSetManager( assert (null != failureReason) numFailures(index) += 1 if (numFailures(index) >= maxTaskFailures) { - logError("Task %s:%d failed %d times; aborting job".format( - taskSet.id, index, maxTaskFailures)) - abort("Task %s:%d failed %d times, most recent failure: %s\nDriver stacktrace:".format( - taskSet.id, index, maxTaskFailures, failureReason)) + logError("Task %d in stage %s failed %d times; aborting job".format( + index, taskSet.id, maxTaskFailures)) + abort("Task %d in stage %s failed %d times, most recent failure: %s\nDriver stacktrace:" + .format(index, taskSet.id, maxTaskFailures, failureReason)) return } } @@ -709,8 +703,8 @@ private[spark] class TaskSetManager( if (!successful(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold && !speculatableTasks.contains(index)) { logInfo( - "Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format( - taskSet.id, index, info.host, threshold)) + "Marking task %d in stage %s (on %s) as speculatable because it ran more than %.0f ms" + .format(index, taskSet.id, info.host, threshold)) speculatableTasks += index foundTasks = true } @@ -748,7 +742,8 @@ private[spark] class TaskSetManager( pendingTasksForHost.keySet.exists(sched.hasExecutorsAliveOnHost(_))) { levels += NODE_LOCAL } - if (!pendingTasksForRack.isEmpty && getLocalityWait(RACK_LOCAL) != 0) { + if (!pendingTasksForRack.isEmpty && getLocalityWait(RACK_LOCAL) != 0 && + pendingTasksForRack.keySet.exists(sched.hasHostAliveOnRack(_))) { levels += RACK_LOCAL } levels += ANY @@ -761,7 +756,8 @@ private[spark] class TaskSetManager( def newLocAvail(index: Int): Boolean = { for (loc <- tasks(index).preferredLocations) { if 
(sched.hasExecutorsAliveOnHost(loc.host) || - sched.getRackForHost(loc.host).isDefined) { + (sched.getRackForHost(loc.host).isDefined && + sched.hasHostAliveOnRack(sched.getRackForHost(loc.host).get))) { return true } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 318e16552201c..6abf6d930c155 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -66,4 +66,7 @@ private[spark] object CoarseGrainedClusterMessages { case class RemoveExecutor(executorId: String, reason: String) extends CoarseGrainedClusterMessage + case class AddWebUIFilter(filterName:String, filterParams: String, proxyBase :String) + extends CoarseGrainedClusterMessage + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 0f5545e2ed65f..9f085eef46720 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -31,6 +31,7 @@ import org.apache.spark.{SparkEnv, Logging, SparkException, TaskState} import org.apache.spark.scheduler.{SchedulerBackend, SlaveLost, TaskDescription, TaskSchedulerImpl, WorkerOffer} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.util.{SerializableBuffer, AkkaUtils, Utils} +import org.apache.spark.ui.JettyUtils /** * A scheduler backend that waits for coarse grained executors to connect to it through Akka. @@ -136,6 +137,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A removeExecutor(executorId, reason) sender ! true + case AddWebUIFilter(filterName, filterParams, proxyBase) => + addWebUIFilter(filterName, filterParams, proxyBase) + sender ! true case DisassociatedEvent(_, address, _) => addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client disassociated")) @@ -276,6 +280,20 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A } false } + + // Add filters to the SparkUI + def addWebUIFilter(filterName: String, filterParams: String, proxyBase: String) { + if (proxyBase != null && proxyBase.nonEmpty) { + System.setProperty("spark.ui.proxyBase", proxyBase) + } + + if (Seq(filterName, filterParams).forall(t => t != null && t.nonEmpty)) { + logInfo(s"Add WebUI Filter. 
$filterName, $filterParams, $proxyBase") + conf.set("spark.ui.filters", filterName) + conf.set(s"spark.$filterName.params", filterParams) + JettyUtils.addFilters(scheduler.sc.ui.getHandlers, conf) + } + } } private[spark] object CoarseGrainedSchedulerBackend { diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index 9b95ccca0443e..5b897597fa285 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -57,7 +57,7 @@ private[spark] class LocalActor( case StatusUpdate(taskId, state, serializedData) => scheduler.statusUpdate(taskId, state, serializedData) if (TaskState.isFinished(state)) { - freeCores += 1 + freeCores += scheduler.CPUS_PER_TASK reviveOffers() } @@ -68,8 +68,8 @@ private[spark] class LocalActor( def reviveOffers() { val offers = Seq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores)) for (task <- scheduler.resourceOffers(offers).flatten) { - freeCores -= 1 - executor.launchTask(executorBackend, task.taskId, task.serializedTask) + freeCores -= scheduler.CPUS_PER_TASK + executor.launchTask(executorBackend, task.taskId, task.name, task.serializedTask) } } } diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index 0a7e1ec539679..a7fa057ee05f7 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -108,7 +108,7 @@ private[spark] class JavaSerializerInstance(counterReset: Int) extends Serialize */ @DeveloperApi class JavaSerializer(conf: SparkConf) extends Serializer with Externalizable { - private var counterReset = conf.getInt("spark.serializer.objectStreamReset", 10000) + private var counterReset = conf.getInt("spark.serializer.objectStreamReset", 100) def newInstance(): SerializerInstance = new JavaSerializerInstance(counterReset) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 1ce4243194798..fa79b25759153 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -31,6 +31,7 @@ import org.apache.spark.scheduler.MapStatus import org.apache.spark.storage._ import org.apache.spark.storage.{GetBlock, GotBlock, PutBlock} import org.apache.spark.util.BoundedPriorityQueue +import org.apache.spark.util.collection.CompactBuffer import scala.reflect.ClassTag @@ -48,6 +49,7 @@ class KryoSerializer(conf: SparkConf) private val bufferSize = conf.getInt("spark.kryoserializer.buffer.mb", 2) * 1024 * 1024 private val referenceTracking = conf.getBoolean("spark.kryo.referenceTracking", true) + private val registrationRequired = conf.getBoolean("spark.kryo.registrationRequired", false) private val registrator = conf.getOption("spark.kryo.registrator") def newKryoOutput() = new KryoOutput(bufferSize) @@ -55,6 +57,7 @@ class KryoSerializer(conf: SparkConf) def newKryo(): Kryo = { val instantiator = new EmptyScalaKryoInstantiator val kryo = instantiator.newKryo() + kryo.setRegistrationRequired(registrationRequired) val classLoader = Thread.currentThread.getContextClassLoader // Allow disabling Kryo reference tracking if user knows their object graphs don't have 
loops. @@ -183,9 +186,11 @@ private[serializer] object KryoSerializer { classOf[GotBlock], classOf[GetBlock], classOf[MapStatus], + classOf[CompactBuffer[_]], classOf[BlockManagerId], classOf[Array[Byte]], - classOf[BoundedPriorityQueue[_]] + classOf[BoundedPriorityQueue[_]], + classOf[SparkConf] ) } diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index a932455776e34..99788828981c7 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -81,10 +81,9 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { shuffleMetrics.shuffleFinishTime = System.currentTimeMillis shuffleMetrics.fetchWaitTime = blockFetcherItr.fetchWaitTime shuffleMetrics.remoteBytesRead = blockFetcherItr.remoteBytesRead - shuffleMetrics.totalBlocksFetched = blockFetcherItr.totalBlocks shuffleMetrics.localBlocksFetched = blockFetcherItr.numLocalBlocks shuffleMetrics.remoteBlocksFetched = blockFetcherItr.numRemoteBlocks - context.taskMetrics.shuffleReadMetrics = Some(shuffleMetrics) + context.taskMetrics.updateShuffleReadMetrics(shuffleMetrics) }) new InterruptibleIterator[T](context, completionIter) diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index d45258c0a492b..76cdb8f4f8e8a 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -18,6 +18,7 @@ package org.apache.spark.shuffle.hash import org.apache.spark.{InterruptibleIterator, TaskContext} +import org.apache.spark.rdd.SortOrder import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} @@ -38,7 +39,7 @@ class HashShuffleReader[K, C]( val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, Serializer.getSerializer(dep.serializer)) - if (dep.aggregator.isDefined) { + val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { if (dep.mapSideCombine) { new InterruptibleIterator(context, dep.aggregator.get.combineCombinersByKey(iter, context)) } else { @@ -49,6 +50,17 @@ class HashShuffleReader[K, C]( } else { iter } + + val sortedIter = for (sortOrder <- dep.sortOrder; ordering <- dep.keyOrdering) yield { + val buf = aggregatedIter.toArray + if (sortOrder == SortOrder.ASCENDING) { + buf.sortWith((x, y) => ordering.lt(x._1, y._1)).iterator + } else { + buf.sortWith((x, y) => ordering.gt(x._1, y._1)).iterator + } + } + + sortedIter.getOrElse(aggregatedIter) } /** Close this reader */ diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala index 408a797088059..69905a960a2ca 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala @@ -46,7 +46,6 @@ import org.apache.spark.util.Utils private[storage] trait BlockFetcherIterator extends Iterator[(BlockId, Option[Iterator[Any]])] with Logging { def initialize() - def totalBlocks: Int def numLocalBlocks: Int def numRemoteBlocks: Int def fetchWaitTime: Long @@ -180,9 +179,9 @@ object BlockFetcherIterator { if (curRequestSize >= 
targetRequestSize) { // Add this FetchRequest remoteRequests += new FetchRequest(address, curBlocks) - curRequestSize = 0 curBlocks = new ArrayBuffer[(BlockId, Long)] logDebug(s"Creating fetch request of $curRequestSize at $address") + curRequestSize = 0 } } // Add in the final request @@ -192,7 +191,7 @@ object BlockFetcherIterator { } } logInfo("Getting " + _numBlocksToFetch + " non-empty blocks out of " + - totalBlocks + " blocks") + (numLocal + numRemote) + " blocks") remoteRequests } @@ -235,7 +234,6 @@ object BlockFetcherIterator { logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") } - override def totalBlocks: Int = numLocal + numRemote override def numLocalBlocks: Int = numLocal override def numRemoteBlocks: Int = numRemote override def fetchWaitTime: Long = _fetchWaitTime diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 0db0a5bc7341b..d746526639e58 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -38,7 +38,7 @@ import org.apache.spark.util._ private[spark] sealed trait BlockValues private[spark] case class ByteBufferValues(buffer: ByteBuffer) extends BlockValues private[spark] case class IteratorValues(iterator: Iterator[Any]) extends BlockValues -private[spark] case class ArrayBufferValues(buffer: ArrayBuffer[Any]) extends BlockValues +private[spark] case class ArrayValues(buffer: Array[Any]) extends BlockValues /* Class for returning a fetched block and associated metrics. */ private[spark] class BlockResult( @@ -71,9 +71,9 @@ private[spark] class BlockManager( // Actual storage of where blocks are kept private var tachyonInitialized = false - private[storage] val memoryStore = new MemoryStore(this, maxMemory) - private[storage] val diskStore = new DiskStore(this, diskBlockManager) - private[storage] lazy val tachyonStore: TachyonStore = { + private[spark] val memoryStore = new MemoryStore(this, maxMemory) + private[spark] val diskStore = new DiskStore(this, diskBlockManager) + private[spark] lazy val tachyonStore: TachyonStore = { val storeDir = conf.get("spark.tachyonStore.baseDir", "/tmp_spark_tachyon") val appFolderName = conf.get("spark.tachyonStore.folderName") val tachyonStorePath = s"$storeDir/$appFolderName/${this.executorId}" @@ -463,16 +463,17 @@ private[spark] class BlockManager( val values = dataDeserialize(blockId, bytes) if (level.deserialized) { // Cache the values before returning them - // TODO: Consider creating a putValues that also takes in a iterator? - val valuesBuffer = new ArrayBuffer[Any] - valuesBuffer ++= values - memoryStore.putValues(blockId, valuesBuffer, level, returnValues = true).data - match { - case Left(values2) => - return Some(new BlockResult(values2, DataReadMethod.Disk, info.size)) - case _ => - throw new SparkException("Memory store did not return back an iterator") - } + val putResult = memoryStore.putIterator( + blockId, values, level, returnValues = true, allowPersistToDisk = false) + // The put may or may not have succeeded, depending on whether there was enough + // space to unroll the block. Either way, the put here should return an iterator. 
+ putResult.data match { + case Left(it) => + return Some(new BlockResult(it, DataReadMethod.Disk, info.size)) + case _ => + // This only happens if we dropped the values back to disk (which is never) + throw new SparkException("Memory store did not return an iterator!") + } } else { return Some(new BlockResult(values, DataReadMethod.Disk, info.size)) } @@ -561,13 +562,14 @@ private[spark] class BlockManager( iter } - def put( + def putIterator( blockId: BlockId, values: Iterator[Any], level: StorageLevel, - tellMaster: Boolean): Seq[(BlockId, BlockStatus)] = { + tellMaster: Boolean = true, + effectiveStorageLevel: Option[StorageLevel] = None): Seq[(BlockId, BlockStatus)] = { require(values != null, "Values is null") - doPut(blockId, IteratorValues(values), level, tellMaster) + doPut(blockId, IteratorValues(values), level, tellMaster, effectiveStorageLevel) } /** @@ -589,13 +591,14 @@ private[spark] class BlockManager( * Put a new block of values to the block manager. * Return a list of blocks updated as a result of this put. */ - def put( + def putArray( blockId: BlockId, - values: ArrayBuffer[Any], + values: Array[Any], level: StorageLevel, - tellMaster: Boolean = true): Seq[(BlockId, BlockStatus)] = { + tellMaster: Boolean = true, + effectiveStorageLevel: Option[StorageLevel] = None): Seq[(BlockId, BlockStatus)] = { require(values != null, "Values is null") - doPut(blockId, ArrayBufferValues(values), level, tellMaster) + doPut(blockId, ArrayValues(values), level, tellMaster, effectiveStorageLevel) } /** @@ -606,19 +609,33 @@ private[spark] class BlockManager( blockId: BlockId, bytes: ByteBuffer, level: StorageLevel, - tellMaster: Boolean = true): Seq[(BlockId, BlockStatus)] = { + tellMaster: Boolean = true, + effectiveStorageLevel: Option[StorageLevel] = None): Seq[(BlockId, BlockStatus)] = { require(bytes != null, "Bytes is null") - doPut(blockId, ByteBufferValues(bytes), level, tellMaster) + doPut(blockId, ByteBufferValues(bytes), level, tellMaster, effectiveStorageLevel) } + /** + * Put the given block according to the given level in one of the block stores, replicating + * the values if necessary. + * + * The effective storage level refers to the level according to which the block will actually be + * handled. This allows the caller to specify an alternate behavior of doPut while preserving + * the original level specified by the user. + */ private def doPut( blockId: BlockId, data: BlockValues, level: StorageLevel, - tellMaster: Boolean = true): Seq[(BlockId, BlockStatus)] = { + tellMaster: Boolean = true, + effectiveStorageLevel: Option[StorageLevel] = None) + : Seq[(BlockId, BlockStatus)] = { require(blockId != null, "BlockId is null") require(level != null && level.isValid, "StorageLevel is null or invalid") + effectiveStorageLevel.foreach { level => + require(level != null && level.isValid, "Effective StorageLevel is null or invalid") + } // Return value val updatedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] @@ -657,13 +674,16 @@ private[spark] class BlockManager( // Size of the block in bytes var size = 0L + // The level we actually use to put the block + val putLevel = effectiveStorageLevel.getOrElse(level) + // If we're storing bytes, then initiate the replication before storing them locally. // This is faster as data is already serialized and ready to send. 
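A minimal sketch of how a caller could use the new `effectiveStorageLevel` parameter introduced in the put signatures above, assuming a `BlockManager` instance named `blockManager` is in scope; the block id and values are made up and this call is not part of the patch:

    import org.apache.spark.storage.{RDDBlockId, StorageLevel}

    // The user asked for MEMORY_AND_DISK, but this particular put is forced to disk only;
    // the level reported to the master is still the user-specified one.
    val userLevel = StorageLevel.MEMORY_AND_DISK
    blockManager.putIterator(
      RDDBlockId(rddId = 0, splitIndex = 3),
      Iterator("a", "b", "c"),
      userLevel,
      tellMaster = true,
      effectiveStorageLevel = Some(StorageLevel.DISK_ONLY))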
val replicationFuture = data match { - case b: ByteBufferValues if level.replication > 1 => + case b: ByteBufferValues if putLevel.replication > 1 => // Duplicate doesn't copy the bytes, but just creates a wrapper val bufferView = b.buffer.duplicate() - Future { replicate(blockId, bufferView, level) } + Future { replicate(blockId, bufferView, putLevel) } case _ => null } @@ -676,18 +696,18 @@ private[spark] class BlockManager( // returnValues - Whether to return the values put // blockStore - The type of storage to put these values into val (returnValues, blockStore: BlockStore) = { - if (level.useMemory) { + if (putLevel.useMemory) { // Put it in memory first, even if it also has useDisk set to true; // We will drop it to disk later if the memory store can't hold it. (true, memoryStore) - } else if (level.useOffHeap) { + } else if (putLevel.useOffHeap) { // Use tachyon for off-heap storage (false, tachyonStore) - } else if (level.useDisk) { + } else if (putLevel.useDisk) { // Don't get back the bytes from put unless we replicate them - (level.replication > 1, diskStore) + (putLevel.replication > 1, diskStore) } else { - assert(level == StorageLevel.NONE) + assert(putLevel == StorageLevel.NONE) throw new BlockException( blockId, s"Attempted to put block $blockId without specifying storage level!") } @@ -696,22 +716,22 @@ private[spark] class BlockManager( // Actually put the values val result = data match { case IteratorValues(iterator) => - blockStore.putValues(blockId, iterator, level, returnValues) - case ArrayBufferValues(array) => - blockStore.putValues(blockId, array, level, returnValues) + blockStore.putIterator(blockId, iterator, putLevel, returnValues) + case ArrayValues(array) => + blockStore.putArray(blockId, array, putLevel, returnValues) case ByteBufferValues(bytes) => bytes.rewind() - blockStore.putBytes(blockId, bytes, level) + blockStore.putBytes(blockId, bytes, putLevel) } size = result.size result.data match { - case Left (newIterator) if level.useMemory => valuesAfterPut = newIterator + case Left (newIterator) if putLevel.useMemory => valuesAfterPut = newIterator case Right (newBytes) => bytesAfterPut = newBytes case _ => } // Keep track of which blocks are dropped from memory - if (level.useMemory) { + if (putLevel.useMemory) { result.droppedBlocks.foreach { updatedBlocks += _ } } @@ -742,7 +762,7 @@ private[spark] class BlockManager( // Either we're storing bytes and we asynchronously started replication, or we're storing // values and need to serialize and replicate them now: - if (level.replication > 1) { + if (putLevel.replication > 1) { data match { case ByteBufferValues(bytes) => if (replicationFuture != null) { @@ -758,7 +778,7 @@ private[spark] class BlockManager( } bytesAfterPut = dataSerialize(blockId, valuesAfterPut) } - replicate(blockId, bytesAfterPut, level) + replicate(blockId, bytesAfterPut, putLevel) logDebug("Put block %s remotely took %s" .format(blockId, Utils.getUsedTimeMs(remoteStartTime))) } @@ -766,7 +786,7 @@ private[spark] class BlockManager( BlockManager.dispose(bytesAfterPut) - if (level.replication > 1) { + if (putLevel.replication > 1) { logDebug("Putting block %s with replication took %s" .format(blockId, Utils.getUsedTimeMs(startTimeMs))) } else { @@ -818,7 +838,7 @@ private[spark] class BlockManager( value: Any, level: StorageLevel, tellMaster: Boolean = true): Seq[(BlockId, BlockStatus)] = { - put(blockId, Iterator(value), level, tellMaster) + putIterator(blockId, Iterator(value), level, tellMaster) } /** @@ -829,7 +849,7 @@ 
private[spark] class BlockManager( */ def dropFromMemory( blockId: BlockId, - data: Either[ArrayBuffer[Any], ByteBuffer]): Option[BlockStatus] = { + data: Either[Array[Any], ByteBuffer]): Option[BlockStatus] = { logInfo(s"Dropping block $blockId from memory") val info = blockInfo.get(blockId).orNull @@ -853,7 +873,7 @@ private[spark] class BlockManager( logInfo(s"Writing block $blockId to disk") data match { case Left(elements) => - diskStore.putValues(blockId, elements, level, returnValues = false) + diskStore.putArray(blockId, elements, level, returnValues = false) case Right(bytes) => diskStore.putBytes(blockId, bytes, level) } @@ -1068,9 +1088,11 @@ private[spark] class BlockManager( private[spark] object BlockManager extends Logging { private val ID_GENERATOR = new IdGenerator + /** Return the total amount of storage memory available. */ private def getMaxMemory(conf: SparkConf): Long = { val memoryFraction = conf.getDouble("spark.storage.memoryFraction", 0.6) - (Runtime.getRuntime.maxMemory * memoryFraction).toLong + val safetyFraction = conf.getDouble("spark.storage.safetyFraction", 0.9) + (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong } def getHeartBeatFrequency(conf: SparkConf): Long = diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 6aed322eeb185..de1cc5539fb48 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -336,6 +336,10 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus case None => blockManagerIdByExecutor(id.executorId) = id } + + logInfo("Registering block manager %s with %s RAM".format( + id.hostPort, Utils.bytesToString(maxMemSize))) + blockManagerInfo(id) = new BlockManagerInfo(id, System.currentTimeMillis(), maxMemSize, slaveActor) } @@ -432,9 +436,6 @@ private[spark] class BlockManagerInfo( // Mapping from block id to its status. 
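To put the getMaxMemory change above in concrete terms, a worked example with an assumed heap size (the fractions are the defaults shown in the diff; the 10 GB figure is invented):

    // Storage memory is now heap * memoryFraction * safetyFraction.
    val heapBytes = 10L * 1024 * 1024 * 1024   // assumed 10 GB executor heap
    val memoryFraction = 0.6
    val safetyFraction = 0.9
    val storageBytes = (heapBytes * memoryFraction * safetyFraction).toLong
    // ~5.4 GB advertised for block storage, versus ~6 GB before the safety
    // fraction was applied, leaving headroom for size-estimation error.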
private val _blocks = new JHashMap[BlockId, BlockStatus] - logInfo("Registering block manager %s with %s RAM".format( - blockManagerId.hostPort, Utils.bytesToString(maxMem))) - def getStatus(blockId: BlockId) = Option(_blocks.get(blockId)) def updateLastSeenMs() { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala index b9b53b1a2f118..69985c9759e2d 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala @@ -37,15 +37,15 @@ private[spark] abstract class BlockStore(val blockManager: BlockManager) extends * @return a PutResult that contains the size of the data, as well as the values put if * returnValues is true (if not, the result's data field can be null) */ - def putValues( + def putIterator( blockId: BlockId, values: Iterator[Any], level: StorageLevel, returnValues: Boolean): PutResult - def putValues( + def putArray( blockId: BlockId, - values: ArrayBuffer[Any], + values: Array[Any], level: StorageLevel, returnValues: Boolean): PutResult diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 673fc19c060a4..2e7ed7538e6e5 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -43,7 +43,7 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD /* Create one local directory for each path mentioned in spark.local.dir; then, inside this * directory, create multiple subdirectories that we will hash files into, in order to avoid * having really large inodes at the top level. */ - private val localDirs: Array[File] = createLocalDirs() + val localDirs: Array[File] = createLocalDirs() if (localDirs.isEmpty) { logError("Failed to create any local dir.") System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR) diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index ebff0cb5ba153..c83261dd91b36 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -21,8 +21,6 @@ import java.io.{FileOutputStream, RandomAccessFile} import java.nio.ByteBuffer import java.nio.channels.FileChannel.MapMode -import scala.collection.mutable.ArrayBuffer - import org.apache.spark.Logging import org.apache.spark.serializer.Serializer import org.apache.spark.util.Utils @@ -30,7 +28,7 @@ import org.apache.spark.util.Utils /** * Stores BlockManager blocks on disk. 
*/ -private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManager) +private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManager) extends BlockStore(blockManager) with Logging { val minMemoryMapBytes = blockManager.conf.getLong("spark.storage.memoryMapThreshold", 2 * 4096L) @@ -57,15 +55,15 @@ private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManage PutResult(bytes.limit(), Right(bytes.duplicate())) } - override def putValues( + override def putArray( blockId: BlockId, - values: ArrayBuffer[Any], + values: Array[Any], level: StorageLevel, returnValues: Boolean): PutResult = { - putValues(blockId, values.toIterator, level, returnValues) + putIterator(blockId, values.toIterator, level, returnValues) } - override def putValues( + override def putIterator( blockId: BlockId, values: Iterator[Any], level: StorageLevel, diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index 71f66c826c5b3..28f675c2bbb1e 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -20,27 +20,45 @@ package org.apache.spark.storage import java.nio.ByteBuffer import java.util.LinkedHashMap +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.spark.util.{SizeEstimator, Utils} +import org.apache.spark.util.collection.SizeTrackingVector private case class MemoryEntry(value: Any, size: Long, deserialized: Boolean) /** - * Stores blocks in memory, either as ArrayBuffers of deserialized Java objects or as + * Stores blocks in memory, either as Arrays of deserialized Java objects or as * serialized ByteBuffers. */ -private class MemoryStore(blockManager: BlockManager, maxMemory: Long) +private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) extends BlockStore(blockManager) { + private val conf = blockManager.conf private val entries = new LinkedHashMap[BlockId, MemoryEntry](32, 0.75f, true) + @volatile private var currentMemory = 0L - // Object used to ensure that only one thread is putting blocks and if necessary, dropping - // blocks from the memory store. - private val putLock = new Object() + + // Ensure only one thread is putting, and if necessary, dropping blocks at any given time + private val accountingLock = new Object + + // A mapping from thread ID to amount of memory used for unrolling a block (in bytes) + // All accesses of this map are assumed to have manually synchronized on `accountingLock` + private val unrollMemoryMap = mutable.HashMap[Long, Long]() + + /** + * The amount of space ensured for unrolling values in memory, shared across all cores. + * This space is not reserved in advance, but allocated dynamically by dropping existing blocks. + */ + private val maxUnrollMemory: Long = { + val unrollFraction = conf.getDouble("spark.storage.unrollFraction", 0.2) + (maxMemory * unrollFraction).toLong + } logInfo("MemoryStore started with capacity %s".format(Utils.bytesToString(maxMemory))) + /** Free memory not occupied by existing blocks. Note that this does not include unroll memory. 
*/ def freeMemory: Long = maxMemory - currentMemory override def getSize(blockId: BlockId): Long = { @@ -55,20 +73,16 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) bytes.rewind() if (level.deserialized) { val values = blockManager.dataDeserialize(blockId, bytes) - val elements = new ArrayBuffer[Any] - elements ++= values - val sizeEstimate = SizeEstimator.estimate(elements.asInstanceOf[AnyRef]) - val putAttempt = tryToPut(blockId, elements, sizeEstimate, deserialized = true) - PutResult(sizeEstimate, Left(values.toIterator), putAttempt.droppedBlocks) + putIterator(blockId, values, level, returnValues = true) } else { val putAttempt = tryToPut(blockId, bytes, bytes.limit, deserialized = false) PutResult(bytes.limit(), Right(bytes.duplicate()), putAttempt.droppedBlocks) } } - override def putValues( + override def putArray( blockId: BlockId, - values: ArrayBuffer[Any], + values: Array[Any], level: StorageLevel, returnValues: Boolean): PutResult = { if (level.deserialized) { @@ -82,14 +96,52 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } } - override def putValues( + override def putIterator( blockId: BlockId, values: Iterator[Any], level: StorageLevel, returnValues: Boolean): PutResult = { - val valueEntries = new ArrayBuffer[Any]() - valueEntries ++= values - putValues(blockId, valueEntries, level, returnValues) + putIterator(blockId, values, level, returnValues, allowPersistToDisk = true) + } + + /** + * Attempt to put the given block in memory store. + * + * There may not be enough space to fully unroll the iterator in memory, in which case we + * optionally drop the values to disk if + * (1) the block's storage level specifies useDisk, and + * (2) `allowPersistToDisk` is true. + * + * One scenario in which `allowPersistToDisk` is false is when the BlockManager reads a block + * back from disk and attempts to cache it in memory. In this case, we should not persist the + * block back on disk again, as it is already in disk store. + */ + private[storage] def putIterator( + blockId: BlockId, + values: Iterator[Any], + level: StorageLevel, + returnValues: Boolean, + allowPersistToDisk: Boolean): PutResult = { + val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] + val unrolledValues = unrollSafely(blockId, values, droppedBlocks) + unrolledValues match { + case Left(arrayValues) => + // Values are fully unrolled in memory, so store them as an array + val res = putArray(blockId, arrayValues, level, returnValues) + droppedBlocks ++= res.droppedBlocks + PutResult(res.size, res.data, droppedBlocks) + case Right(iteratorValues) => + // Not enough space to unroll this block; drop to disk if applicable + logWarning(s"Not enough space to store block $blockId in memory! 
" + + s"Free memory is $freeMemory bytes.") + if (level.useDisk && allowPersistToDisk) { + logWarning(s"Persisting block $blockId to disk instead.") + val res = blockManager.diskStore.putIterator(blockId, iteratorValues, level, returnValues) + PutResult(res.size, res.data, droppedBlocks) + } else { + PutResult(0, Left(iteratorValues), droppedBlocks) + } + } } override def getBytes(blockId: BlockId): Option[ByteBuffer] = { @@ -99,7 +151,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) if (entry == null) { None } else if (entry.deserialized) { - Some(blockManager.dataSerialize(blockId, entry.value.asInstanceOf[ArrayBuffer[Any]].iterator)) + Some(blockManager.dataSerialize(blockId, entry.value.asInstanceOf[Array[Any]].iterator)) } else { Some(entry.value.asInstanceOf[ByteBuffer].duplicate()) // Doesn't actually copy the data } @@ -112,7 +164,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) if (entry == null) { None } else if (entry.deserialized) { - Some(entry.value.asInstanceOf[ArrayBuffer[Any]].iterator) + Some(entry.value.asInstanceOf[Array[Any]].iterator) } else { val buffer = entry.value.asInstanceOf[ByteBuffer].duplicate() // Doesn't actually copy data Some(blockManager.dataDeserialize(blockId, buffer)) @@ -140,6 +192,93 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) logInfo("MemoryStore cleared") } + /** + * Unroll the given block in memory safely. + * + * The safety of this operation refers to avoiding potential OOM exceptions caused by + * unrolling the entirety of the block in memory at once. This is achieved by periodically + * checking whether the memory restrictions for unrolling blocks are still satisfied, + * stopping immediately if not. This check is a safeguard against the scenario in which + * there is not enough free memory to accommodate the entirety of a single block. + * + * This method returns either an array with the contents of the entire block or an iterator + * containing the values of the block (if the array would have exceeded available memory). + */ + def unrollSafely( + blockId: BlockId, + values: Iterator[Any], + droppedBlocks: ArrayBuffer[(BlockId, BlockStatus)]) + : Either[Array[Any], Iterator[Any]] = { + + // Number of elements unrolled so far + var elementsUnrolled = 0 + // Whether there is still enough memory for us to continue unrolling this block + var keepUnrolling = true + // Initial per-thread memory to request for unrolling blocks (bytes). Exposed for testing. 
+ val initialMemoryThreshold = conf.getLong("spark.storage.unrollMemoryThreshold", 1024 * 1024) + // How often to check whether we need to request more memory + val memoryCheckPeriod = 16 + // Memory currently reserved by this thread for this particular unrolling operation + var memoryThreshold = initialMemoryThreshold + // Memory to request as a multiple of current vector size + val memoryGrowthFactor = 1.5 + // Previous unroll memory held by this thread, for releasing later (only at the very end) + val previousMemoryReserved = currentUnrollMemoryForThisThread + // Underlying vector for unrolling the block + var vector = new SizeTrackingVector[Any] + + // Request enough memory to begin unrolling + keepUnrolling = reserveUnrollMemoryForThisThread(initialMemoryThreshold) + + // Unroll this block safely, checking whether we have exceeded our threshold periodically + try { + while (values.hasNext && keepUnrolling) { + vector += values.next() + if (elementsUnrolled % memoryCheckPeriod == 0) { + // If our vector's size has exceeded the threshold, request more memory + val currentSize = vector.estimateSize() + if (currentSize >= memoryThreshold) { + val amountToRequest = (currentSize * (memoryGrowthFactor - 1)).toLong + // Hold the accounting lock, in case another thread concurrently puts a block that + // takes up the unrolling space we just ensured here + accountingLock.synchronized { + if (!reserveUnrollMemoryForThisThread(amountToRequest)) { + // If the first request is not granted, try again after ensuring free space + // If there is still not enough space, give up and drop the partition + val spaceToEnsure = maxUnrollMemory - currentUnrollMemory + if (spaceToEnsure > 0) { + val result = ensureFreeSpace(blockId, spaceToEnsure) + droppedBlocks ++= result.droppedBlocks + } + keepUnrolling = reserveUnrollMemoryForThisThread(amountToRequest) + } + } + // New threshold is currentSize * memoryGrowthFactor + memoryThreshold = currentSize + amountToRequest + } + } + elementsUnrolled += 1 + } + + if (keepUnrolling) { + // We successfully unrolled the entirety of this block + Left(vector.toArray) + } else { + // We ran out of space while unrolling the values for this block + Right(vector.iterator ++ values) + } + + } finally { + // If we return an array, the values returned do not depend on the underlying vector and + // we can immediately free up space for other threads. Otherwise, if we return an iterator, + // we release the memory claimed by this thread later on when the task finishes. + if (keepUnrolling) { + val amountToRelease = currentUnrollMemoryForThisThread - previousMemoryReserved + releaseUnrollMemoryForThisThread(amountToRelease) + } + } + } + /** * Return the RDD ID that a given block ID is from, or None if it is not an RDD block. */ @@ -149,10 +288,10 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) /** * Try to put in a set of values, if we can free up enough space. The value should either be - * an ArrayBuffer if deserialized is true or a ByteBuffer otherwise. Its (possibly estimated) - * size must also be passed by the caller. + * an Array if deserialized is true or a ByteBuffer otherwise. Its (possibly estimated) size + * must also be passed by the caller. * - * Lock on the object putLock to ensure that all the put requests and its associated block + * Synchronize on `accountingLock` to ensure that all the put requests and its associated block * dropping is done by only on thread at a time. 
Otherwise while one thread is dropping * blocks to free memory for one block, another thread may use up the freed space for * another block. @@ -174,7 +313,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) var putSuccess = false val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] - putLock.synchronized { + accountingLock.synchronized { val freeSpaceResult = ensureFreeSpace(blockId, size) val enoughFreeSpace = freeSpaceResult.success droppedBlocks ++= freeSpaceResult.droppedBlocks @@ -193,7 +332,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) // Tell the block manager that we couldn't put it in memory so that it can drop it to // disk if the block allows disk storage. val data = if (deserialized) { - Left(value.asInstanceOf[ArrayBuffer[Any]]) + Left(value.asInstanceOf[Array[Any]]) } else { Right(value.asInstanceOf[ByteBuffer].duplicate()) } @@ -210,12 +349,14 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) * from the same RDD (which leads to a wasteful cyclic replacement pattern for RDDs that * don't fit into memory that we want to avoid). * - * Assume that a lock is held by the caller to ensure only one thread is dropping blocks. - * Otherwise, the freed space may fill up before the caller puts in their new value. + * Assume that `accountingLock` is held by the caller to ensure only one thread is dropping + * blocks. Otherwise, the freed space may fill up before the caller puts in their new value. * * Return whether there is enough free space, along with the blocks dropped in the process. */ - private def ensureFreeSpace(blockIdToAdd: BlockId, space: Long): ResultWithDroppedBlocks = { + private def ensureFreeSpace( + blockIdToAdd: BlockId, + space: Long): ResultWithDroppedBlocks = { logInfo(s"ensureFreeSpace($space) called with curMem=$currentMemory, maxMem=$maxMemory") val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] @@ -225,9 +366,12 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) return ResultWithDroppedBlocks(success = false, droppedBlocks) } - if (maxMemory - currentMemory < space) { + // Take into account the amount of memory currently occupied by unrolling blocks + val actualFreeMemory = freeMemory - currentUnrollMemory + + if (actualFreeMemory < space) { val rddToAdd = getRddId(blockIdToAdd) - val selectedBlocks = new ArrayBuffer[BlockId]() + val selectedBlocks = new ArrayBuffer[BlockId] var selectedMemory = 0L // This is synchronized to ensure that the set of entries is not changed @@ -235,7 +379,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) // can lead to exceptions. entries.synchronized { val iterator = entries.entrySet().iterator() - while (maxMemory - (currentMemory - selectedMemory) < space && iterator.hasNext) { + while (actualFreeMemory + selectedMemory < space && iterator.hasNext) { val pair = iterator.next() val blockId = pair.getKey if (rddToAdd.isEmpty || rddToAdd != getRddId(blockId)) { @@ -245,7 +389,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } } - if (maxMemory - (currentMemory - selectedMemory) >= space) { + if (actualFreeMemory + selectedMemory >= space) { logInfo(s"${selectedBlocks.size} blocks selected for dropping") for (blockId <- selectedBlocks) { val entry = entries.synchronized { entries.get(blockId) } @@ -254,7 +398,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) // future safety. 
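To make the new unroll bookkeeping concrete, an illustrative walk-through with assumed sizes; the constants (1 MB initial request, a check every 16 elements, a 1.5 growth factor) are the ones defined in unrollSafely above, while the block and memory sizes are invented:

    // 1. Growing the unroll reservation: a periodic check finds the vector estimated at
    //    4 MB while this thread's current threshold is still 3 MB.
    val currentSize = 4L * 1024 * 1024                                     // assumed
    val memoryGrowthFactor = 1.5
    val amountToRequest = (currentSize * (memoryGrowthFactor - 1)).toLong  // 2 MB more
    val newThreshold = currentSize + amountToRequest                       // 6 MB

    // 2. Making room for a new block: ensureFreeSpace now subtracts unroll memory first.
    //    Assume maxMemory = 100 MB, cached blocks = 80 MB, unroll memory in use = 15 MB,
    //    and an incoming block that needs 10 MB (values below are in MB).
    val freeMemory = 100L - 80L               // 20 MB of unused storage memory
    val actualFreeMemory = freeMemory - 15L   // 5 MB once unroll space is excluded
    // Eviction continues until actualFreeMemory + selectedMemory >= 10, so at least 5 MB
    // of existing blocks must be dropped; the old check ignored unroll memory and would
    // have seen 20 MB free and dropped nothing.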
if (entry != null) { val data = if (entry.deserialized) { - Left(entry.value.asInstanceOf[ArrayBuffer[Any]]) + Left(entry.value.asInstanceOf[Array[Any]]) } else { Right(entry.value.asInstanceOf[ByteBuffer].duplicate()) } @@ -275,8 +419,56 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) override def contains(blockId: BlockId): Boolean = { entries.synchronized { entries.containsKey(blockId) } } + + /** + * Reserve additional memory for unrolling blocks used by this thread. + * Return whether the request is granted. + */ + private[spark] def reserveUnrollMemoryForThisThread(memory: Long): Boolean = { + accountingLock.synchronized { + val granted = freeMemory > currentUnrollMemory + memory + if (granted) { + val threadId = Thread.currentThread().getId + unrollMemoryMap(threadId) = unrollMemoryMap.getOrElse(threadId, 0L) + memory + } + granted + } + } + + /** + * Release memory used by this thread for unrolling blocks. + * If the amount is not specified, remove the current thread's allocation altogether. + */ + private[spark] def releaseUnrollMemoryForThisThread(memory: Long = -1L): Unit = { + val threadId = Thread.currentThread().getId + accountingLock.synchronized { + if (memory < 0) { + unrollMemoryMap.remove(threadId) + } else { + unrollMemoryMap(threadId) = unrollMemoryMap.getOrElse(threadId, memory) - memory + // If this thread claims no more unroll memory, release it completely + if (unrollMemoryMap(threadId) <= 0) { + unrollMemoryMap.remove(threadId) + } + } + } + } + + /** + * Return the amount of memory currently occupied for unrolling blocks across all threads. + */ + private[spark] def currentUnrollMemory: Long = accountingLock.synchronized { + unrollMemoryMap.values.sum + } + + /** + * Return the amount of memory currently occupied for unrolling blocks by this thread. + */ + private[spark] def currentUnrollMemoryForThisThread: Long = accountingLock.synchronized { + unrollMemoryMap.getOrElse(Thread.currentThread().getId, 0L) + } } -private case class ResultWithDroppedBlocks( +private[spark] case class ResultWithDroppedBlocks( success: Boolean, droppedBlocks: Seq[(BlockId, BlockStatus)]) diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala b/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala index d8ff4ff6bd42c..932b5616043b4 100644 --- a/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala @@ -20,8 +20,6 @@ package org.apache.spark.storage import java.io.IOException import java.nio.ByteBuffer -import scala.collection.mutable.ArrayBuffer - import tachyon.client.{ReadType, WriteType} import org.apache.spark.Logging @@ -30,7 +28,7 @@ import org.apache.spark.util.Utils /** * Stores BlockManager blocks on Tachyon. 
*/ -private class TachyonStore( +private[spark] class TachyonStore( blockManager: BlockManager, tachyonManager: TachyonBlockManager) extends BlockStore(blockManager: BlockManager) with Logging { @@ -45,15 +43,15 @@ private class TachyonStore( putIntoTachyonStore(blockId, bytes, returnValues = true) } - override def putValues( + override def putArray( blockId: BlockId, - values: ArrayBuffer[Any], + values: Array[Any], level: StorageLevel, returnValues: Boolean): PutResult = { - putValues(blockId, values.toIterator, level, returnValues) + putIterator(blockId, values.toIterator, level, returnValues) } - override def putValues( + override def putIterator( blockId: BlockId, values: Iterator[Any], level: StorageLevel, diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala index 328be158db680..75c2e09a6bbb8 100644 --- a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala +++ b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala @@ -48,7 +48,7 @@ private[spark] object ThreadingTest { val block = (1 to blockSize).map(_ => Random.nextInt()) val level = randomLevel() val startTime = System.currentTimeMillis() - manager.put(blockId, block.iterator, level, tellMaster = true) + manager.putIterator(blockId, block.iterator, level, tellMaster = true) println("Pushed block " + blockId + " in " + (System.currentTimeMillis - startTime) + " ms") queue.add((blockId, block)) } diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala index 37708d75489c8..9ced9b8107ebf 100644 --- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala +++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala @@ -20,9 +20,9 @@ package org.apache.spark.ui private[spark] object ToolTips { val SCHEDULER_DELAY = """Scheduler delay includes time to ship the task from the scheduler to - the executor, and time the time to send a message from the executor to the scheduler stating - that the task has completed. When the scheduler becomes overloaded, task completion messages - become queued up, and scheduler delay increases.""" + the executor, and time to send the task result from the executor to the scheduler. If + scheduler delay is large, consider decreasing the size of tasks or decreasing the size + of task results.""" val INPUT = "Bytes read from Hadoop or from Spark storage." 
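A caller-side sketch of the put to putIterator rename, mirroring the ThreadingTest change above; it assumes a `BlockManager` instance named `manager` is in scope, and the block id and values are made up:

    import org.apache.spark.storage.{StorageLevel, TestBlockId}

    val blockId = TestBlockId("example-block")
    val values = (1 to 1000).iterator
    // Previously: manager.put(blockId, values, StorageLevel.MEMORY_AND_DISK, tellMaster = true)
    manager.putIterator(blockId, values, StorageLevel.MEMORY_AND_DISK, tellMaster = true)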
diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 9cb50d9b83dda..715cc2f4df8dd 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -136,11 +136,20 @@ private[spark] object UIUtils extends Logging { } // Yarn has to go through a proxy so the base uri is provided and has to be on all links - val uiRoot : String = Option(System.getenv("APPLICATION_WEB_PROXY_BASE")).getOrElse("") + def uiRoot: String = { + if (System.getenv("APPLICATION_WEB_PROXY_BASE") != null) { + System.getenv("APPLICATION_WEB_PROXY_BASE") + } else if (System.getProperty("spark.ui.proxyBase") != null) { + System.getProperty("spark.ui.proxyBase") + } + else { + "" + } + } def prependBaseUri(basePath: String = "", resource: String = "") = uiRoot + basePath + resource - val commonHeaderNodes = { + def commonHeaderNodes = { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index 52020954ea57c..0cc51c873727d 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -21,6 +21,7 @@ import scala.collection.mutable import scala.xml.Node import org.apache.spark.ui.{ToolTips, UIUtils} +import org.apache.spark.ui.jobs.UIData.StageUIData import org.apache.spark.util.Utils /** Page showing executor summary */ @@ -64,11 +65,9 @@ private[ui] class ExecutorTable(stageId: Int, parent: JobProgressTab) { executorIdToAddress.put(executorId, address) } - val executorIdToSummary = listener.stageIdToExecutorSummaries.get(stageId) - executorIdToSummary match { - case Some(x) => - x.toSeq.sortBy(_._1).map { case (k, v) => { - // scalastyle:off + listener.stageIdToData.get(stageId) match { + case Some(stageData: StageUIData) => + stageData.executorSummary.toSeq.sortBy(_._1).map { case (k, v) => {k} {executorIdToAddress.getOrElse(k, "CANNOT FIND ADDRESS")} @@ -76,16 +75,20 @@ private[ui] class ExecutorTable(stageId: Int, parent: JobProgressTab) { {v.failedTasks + v.succeededTasks} {v.failedTasks} {v.succeededTasks} - {Utils.bytesToString(v.inputBytes)} - {Utils.bytesToString(v.shuffleRead)} - {Utils.bytesToString(v.shuffleWrite)} - {Utils.bytesToString(v.memoryBytesSpilled)} - {Utils.bytesToString(v.diskBytesSpilled)} + + {Utils.bytesToString(v.inputBytes)} + + {Utils.bytesToString(v.shuffleRead)} + + {Utils.bytesToString(v.shuffleWrite)} + + {Utils.bytesToString(v.memoryBytesSpilled)} + + {Utils.bytesToString(v.diskBytesSpilled)} - // scalastyle:on } - } - case _ => Seq[Node]() + case None => + Seq.empty[Node] } } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 2286a7f952f28..efb527b4f03e6 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -25,6 +25,7 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.storage.BlockManagerId +import org.apache.spark.ui.jobs.UIData._ /** * :: DeveloperApi :: @@ -35,7 +36,7 @@ import org.apache.spark.storage.BlockManagerId * updating the internal data structures concurrently. 
*/ @DeveloperApi -class JobProgressListener(conf: SparkConf) extends SparkListener { +class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { import JobProgressListener._ @@ -46,20 +47,8 @@ class JobProgressListener(conf: SparkConf) extends SparkListener { val completedStages = ListBuffer[StageInfo]() val failedStages = ListBuffer[StageInfo]() - // TODO: Should probably consolidate all following into a single hash map. - val stageIdToTime = HashMap[Int, Long]() - val stageIdToInputBytes = HashMap[Int, Long]() - val stageIdToShuffleRead = HashMap[Int, Long]() - val stageIdToShuffleWrite = HashMap[Int, Long]() - val stageIdToMemoryBytesSpilled = HashMap[Int, Long]() - val stageIdToDiskBytesSpilled = HashMap[Int, Long]() - val stageIdToTasksActive = HashMap[Int, HashMap[Long, TaskInfo]]() - val stageIdToTasksComplete = HashMap[Int, Int]() - val stageIdToTasksFailed = HashMap[Int, Int]() - val stageIdToTaskData = HashMap[Int, HashMap[Long, TaskUIData]]() - val stageIdToExecutorSummaries = HashMap[Int, HashMap[String, ExecutorSummary]]() - val stageIdToPool = HashMap[Int, String]() - val stageIdToDescription = HashMap[Int, String]() + val stageIdToData = new HashMap[Int, StageUIData] + val poolToActiveStages = HashMap[String, HashMap[Int, StageInfo]]() val executorIdToBlockManagerId = HashMap[String, BlockManagerId]() @@ -71,8 +60,12 @@ class JobProgressListener(conf: SparkConf) extends SparkListener { override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) = synchronized { val stage = stageCompleted.stageInfo val stageId = stage.stageId - // Remove by stageId, rather than by StageInfo, in case the StageInfo is from storage - poolToActiveStages(stageIdToPool(stageId)).remove(stageId) + val stageData = stageIdToData.getOrElseUpdate(stageId, { + logWarning("Stage completed for unknown stage " + stageId) + new StageUIData + }) + + poolToActiveStages.get(stageData.schedulingPool).foreach(_.remove(stageId)) activeStages.remove(stageId) if (stage.failureReason.isEmpty) { completedStages += stage @@ -87,21 +80,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener { private def trimIfNecessary(stages: ListBuffer[StageInfo]) = synchronized { if (stages.size > retainedStages) { val toRemove = math.max(retainedStages / 10, 1) - stages.take(toRemove).foreach { s => - stageIdToTime.remove(s.stageId) - stageIdToInputBytes.remove(s.stageId) - stageIdToShuffleRead.remove(s.stageId) - stageIdToShuffleWrite.remove(s.stageId) - stageIdToMemoryBytesSpilled.remove(s.stageId) - stageIdToDiskBytesSpilled.remove(s.stageId) - stageIdToTasksActive.remove(s.stageId) - stageIdToTasksComplete.remove(s.stageId) - stageIdToTasksFailed.remove(s.stageId) - stageIdToTaskData.remove(s.stageId) - stageIdToExecutorSummaries.remove(s.stageId) - stageIdToPool.remove(s.stageId) - stageIdToDescription.remove(s.stageId) - } + stages.take(toRemove).foreach { s => stageIdToData.remove(s.stageId) } stages.trimStart(toRemove) } } @@ -114,26 +93,27 @@ class JobProgressListener(conf: SparkConf) extends SparkListener { val poolName = Option(stageSubmitted.properties).map { p => p.getProperty("spark.scheduler.pool", DEFAULT_POOL_NAME) }.getOrElse(DEFAULT_POOL_NAME) - stageIdToPool(stage.stageId) = poolName - val description = Option(stageSubmitted.properties).flatMap { + val stageData = stageIdToData.getOrElseUpdate(stage.stageId, new StageUIData) + stageData.schedulingPool = poolName + + stageData.description = Option(stageSubmitted.properties).flatMap { p => 
Option(p.getProperty(SparkContext.SPARK_JOB_DESCRIPTION)) } - description.map(d => stageIdToDescription(stage.stageId) = d) val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashMap[Int, StageInfo]()) stages(stage.stageId) = stage } override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized { - val sid = taskStart.stageId val taskInfo = taskStart.taskInfo if (taskInfo != null) { - val tasksActive = stageIdToTasksActive.getOrElseUpdate(sid, new HashMap[Long, TaskInfo]()) - tasksActive(taskInfo.taskId) = taskInfo - val taskMap = stageIdToTaskData.getOrElse(sid, HashMap[Long, TaskUIData]()) - taskMap(taskInfo.taskId) = new TaskUIData(taskInfo) - stageIdToTaskData(sid) = taskMap + val stageData = stageIdToData.getOrElseUpdate(taskStart.stageId, { + logWarning("Task start for unknown stage " + taskStart.stageId) + new StageUIData + }) + stageData.numActiveTasks += 1 + stageData.taskData.put(taskInfo.taskId, new TaskUIData(taskInfo)) } } @@ -143,88 +123,76 @@ class JobProgressListener(conf: SparkConf) extends SparkListener { } override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { - val sid = taskEnd.stageId val info = taskEnd.taskInfo - if (info != null) { + val stageData = stageIdToData.getOrElseUpdate(taskEnd.stageId, { + logWarning("Task end for unknown stage " + taskEnd.stageId) + new StageUIData + }) + // create executor summary map if necessary - val executorSummaryMap = stageIdToExecutorSummaries.getOrElseUpdate(key = sid, - op = new HashMap[String, ExecutorSummary]()) + val executorSummaryMap = stageData.executorSummary executorSummaryMap.getOrElseUpdate(key = info.executorId, op = new ExecutorSummary) - val executorSummary = executorSummaryMap.get(info.executorId) - executorSummary match { - case Some(y) => { - // first update failed-task, succeed-task - taskEnd.reason match { - case Success => - y.succeededTasks += 1 - case _ => - y.failedTasks += 1 - } - - // update duration - y.taskTime += info.duration - - val metrics = taskEnd.taskMetrics - if (metrics != null) { - metrics.inputMetrics.foreach { y.inputBytes += _.bytesRead } - metrics.shuffleReadMetrics.foreach { y.shuffleRead += _.remoteBytesRead } - metrics.shuffleWriteMetrics.foreach { y.shuffleWrite += _.shuffleBytesWritten } - y.memoryBytesSpilled += metrics.memoryBytesSpilled - y.diskBytesSpilled += metrics.diskBytesSpilled - } + executorSummaryMap.get(info.executorId).foreach { y => + // first update failed-task, succeed-task + taskEnd.reason match { + case Success => + y.succeededTasks += 1 + case _ => + y.failedTasks += 1 + } + + // update duration + y.taskTime += info.duration + + val metrics = taskEnd.taskMetrics + if (metrics != null) { + metrics.inputMetrics.foreach { y.inputBytes += _.bytesRead } + metrics.shuffleReadMetrics.foreach { y.shuffleRead += _.remoteBytesRead } + metrics.shuffleWriteMetrics.foreach { y.shuffleWrite += _.shuffleBytesWritten } + y.memoryBytesSpilled += metrics.memoryBytesSpilled + y.diskBytesSpilled += metrics.diskBytesSpilled } - case _ => {} } - val tasksActive = stageIdToTasksActive.getOrElseUpdate(sid, new HashMap[Long, TaskInfo]()) - // Remove by taskId, rather than by TaskInfo, in case the TaskInfo is from storage - tasksActive.remove(info.taskId) + stageData.numActiveTasks -= 1 val (errorMessage, metrics): (Option[String], Option[TaskMetrics]) = taskEnd.reason match { case org.apache.spark.Success => - stageIdToTasksComplete(sid) = stageIdToTasksComplete.getOrElse(sid, 0) + 1 + stageData.numCompleteTasks += 1 (None, 
Option(taskEnd.taskMetrics)) case e: ExceptionFailure => // Handle ExceptionFailure because we might have metrics - stageIdToTasksFailed(sid) = stageIdToTasksFailed.getOrElse(sid, 0) + 1 + stageData.numFailedTasks += 1 (Some(e.toErrorString), e.metrics) case e: TaskFailedReason => // All other failure cases - stageIdToTasksFailed(sid) = stageIdToTasksFailed.getOrElse(sid, 0) + 1 + stageData.numFailedTasks += 1 (Some(e.toErrorString), None) } - stageIdToTime.getOrElseUpdate(sid, 0L) - val time = metrics.map(_.executorRunTime).getOrElse(0L) - stageIdToTime(sid) += time - stageIdToInputBytes.getOrElseUpdate(sid, 0L) + val taskRunTime = metrics.map(_.executorRunTime).getOrElse(0L) + stageData.executorRunTime += taskRunTime val inputBytes = metrics.flatMap(_.inputMetrics).map(_.bytesRead).getOrElse(0L) - stageIdToInputBytes(sid) += inputBytes + stageData.inputBytes += inputBytes - stageIdToShuffleRead.getOrElseUpdate(sid, 0L) val shuffleRead = metrics.flatMap(_.shuffleReadMetrics).map(_.remoteBytesRead).getOrElse(0L) - stageIdToShuffleRead(sid) += shuffleRead + stageData.shuffleReadBytes += shuffleRead - stageIdToShuffleWrite.getOrElseUpdate(sid, 0L) val shuffleWrite = metrics.flatMap(_.shuffleWriteMetrics).map(_.shuffleBytesWritten).getOrElse(0L) - stageIdToShuffleWrite(sid) += shuffleWrite + stageData.shuffleWriteBytes += shuffleWrite - stageIdToMemoryBytesSpilled.getOrElseUpdate(sid, 0L) val memoryBytesSpilled = metrics.map(_.memoryBytesSpilled).getOrElse(0L) - stageIdToMemoryBytesSpilled(sid) += memoryBytesSpilled + stageData.memoryBytesSpilled += memoryBytesSpilled - stageIdToDiskBytesSpilled.getOrElseUpdate(sid, 0L) val diskBytesSpilled = metrics.map(_.diskBytesSpilled).getOrElse(0L) - stageIdToDiskBytesSpilled(sid) += diskBytesSpilled + stageData.diskBytesSpilled += diskBytesSpilled - val taskMap = stageIdToTaskData.getOrElse(sid, HashMap[Long, TaskUIData]()) - taskMap(info.taskId) = new TaskUIData(info, metrics, errorMessage) - stageIdToTaskData(sid) = taskMap + stageData.taskData(info.taskId) = new TaskUIData(info, metrics, errorMessage) } - } + } // end of onTaskEnd override def onEnvironmentUpdate(environmentUpdate: SparkListenerEnvironmentUpdate) { synchronized { @@ -252,12 +220,6 @@ class JobProgressListener(conf: SparkConf) extends SparkListener { } -@DeveloperApi -case class TaskUIData( - taskInfo: TaskInfo, - taskMetrics: Option[TaskMetrics] = None, - errorMessage: Option[String] = None) - private object JobProgressListener { val DEFAULT_POOL_NAME = "default" val DEFAULT_RETAINED_STAGES = 1000 diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 8c3821bd7c3eb..cab26b9e2f7d3 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -23,6 +23,7 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node import org.apache.spark.ui.{ToolTips, WebUIPage, UIUtils} +import org.apache.spark.ui.jobs.UIData._ import org.apache.spark.util.{Utils, Distribution} /** Page showing statistics and task list for a given stage */ @@ -34,8 +35,9 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { def render(request: HttpServletRequest): Seq[Node] = { listener.synchronized { val stageId = request.getParameter("id").toInt + val stageDataOption = listener.stageIdToData.get(stageId) - if (!listener.stageIdToTaskData.contains(stageId)) { + if (stageDataOption.isEmpty || 
stageDataOption.get.taskData.isEmpty) { val content =

Summary Metrics

No tasks have started yet @@ -45,23 +47,14 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { "Details for Stage %s".format(stageId), parent.headerTabs, parent) } - val tasks = listener.stageIdToTaskData(stageId).values.toSeq.sortBy(_.taskInfo.launchTime) + val stageData = stageDataOption.get + val tasks = stageData.taskData.values.toSeq.sortBy(_.taskInfo.launchTime) val numCompleted = tasks.count(_.taskInfo.finished) - val inputBytes = listener.stageIdToInputBytes.getOrElse(stageId, 0L) - val hasInput = inputBytes > 0 - val shuffleReadBytes = listener.stageIdToShuffleRead.getOrElse(stageId, 0L) - val hasShuffleRead = shuffleReadBytes > 0 - val shuffleWriteBytes = listener.stageIdToShuffleWrite.getOrElse(stageId, 0L) - val hasShuffleWrite = shuffleWriteBytes > 0 - val memoryBytesSpilled = listener.stageIdToMemoryBytesSpilled.getOrElse(stageId, 0L) - val diskBytesSpilled = listener.stageIdToDiskBytesSpilled.getOrElse(stageId, 0L) - val hasBytesSpilled = memoryBytesSpilled > 0 && diskBytesSpilled > 0 - - var activeTime = 0L - val now = System.currentTimeMillis - val tasksActive = listener.stageIdToTasksActive(stageId).values - tasksActive.foreach(activeTime += _.timeRunning(now)) + val hasInput = stageData.inputBytes > 0 + val hasShuffleRead = stageData.shuffleReadBytes > 0 + val hasShuffleWrite = stageData.shuffleWriteBytes > 0 + val hasBytesSpilled = stageData.memoryBytesSpilled > 0 && stageData.diskBytesSpilled > 0 // scalastyle:off val summary = @@ -69,34 +62,34 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
  • Total task time across all tasks:
-     {UIUtils.formatDuration(listener.stageIdToTime.getOrElse(stageId, 0L) + activeTime)}
+     {UIUtils.formatDuration(stageData.executorRunTime)}
  • {if (hasInput)
  • Input:
-     {Utils.bytesToString(inputBytes)}
+     {Utils.bytesToString(stageData.inputBytes)}
  • } {if (hasShuffleRead)
  • Shuffle read:
-     {Utils.bytesToString(shuffleReadBytes)}
+     {Utils.bytesToString(stageData.shuffleReadBytes)}
  • } {if (hasShuffleWrite)
  • Shuffle write:
-     {Utils.bytesToString(shuffleWriteBytes)}
+     {Utils.bytesToString(stageData.shuffleWriteBytes)}
  • } {if (hasBytesSpilled)
  • Shuffle spill (memory):
-     {Utils.bytesToString(memoryBytesSpilled)}
+     {Utils.bytesToString(stageData.memoryBytesSpilled)}
  • Shuffle spill (disk):
-     {Utils.bytesToString(diskBytesSpilled)}
+     {Utils.bytesToString(stageData.diskBytesSpilled)}
  • }
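A simplified sketch of the read path that pages such as this one now use: everything for a stage comes from a single `StageUIData` record in `stageIdToData` rather than a dozen parallel `stageIdTo*` maps. It assumes a `JobProgressListener` named `listener` and a `stageId` are in scope:

    import org.apache.spark.util.Utils

    listener.stageIdToData.get(stageId) match {
      case Some(stage) =>
        println(s"tasks: ${stage.numCompleteTasks} succeeded / ${stage.numFailedTasks} failed")
        println(s"shuffle: ${Utils.bytesToString(stage.shuffleReadBytes)} read, " +
          s"${Utils.bytesToString(stage.shuffleWriteBytes)} written")
      case None =>
        println(s"No data available for stage $stageId")
    }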
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index fd8d0b5cdde00..3dcfaf76e4aba 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -17,12 +17,12 @@ package org.apache.spark.ui.jobs -import java.util.Date - -import scala.collection.mutable.HashMap import scala.xml.Node +import scala.xml.Text + +import java.util.Date -import org.apache.spark.scheduler.{StageInfo, TaskInfo} +import org.apache.spark.scheduler.StageInfo import org.apache.spark.ui.{ToolTips, UIUtils} import org.apache.spark.util.Utils @@ -71,14 +71,14 @@ private[ui] class StageTableBase( } - private def makeProgressBar(started: Int, completed: Int, failed: String, total: Int): Seq[Node] = + private def makeProgressBar(started: Int, completed: Int, failed: Int, total: Int): Seq[Node] = { val completeWidth = "width: %s%%".format((completed.toDouble/total)*100) val startWidth = "width: %s%%".format((started.toDouble/total)*100)
-      {completed}/{total} {failed}
+      {completed}/{total} { if (failed > 0) s"($failed failed)" else "" }
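An assumed example of how the failed: Int signature changes the rendered label (the counts are invented):

    // From inside StageTableBase: 3 running, 7 completed, 2 failed out of 12 tasks renders
    // the bar label as "7/12 (2 failed)"; with failed == 0 it renders just "7/12".
    makeProgressBar(started = 3, completed = 7, failed = 2, total = 12)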
@@ -100,21 +100,42 @@ private[ui] class StageTableBase( {s.name} + val cachedRddInfos = s.rddInfos.filter(_.numCachedPartitions > 0) val details = if (s.details.nonEmpty) { - +show details - - + +details + ++ + + } + + val stageDesc = for { + stageData <- listener.stageIdToData.get(s.stageId) + desc <- stageData.description + } yield { +
{desc}
+    }
-    listener.stageIdToDescription.get(s.stageId)
-      .map(d =>
-        {d}
-        {nameLink} {killLink}
-      )
-      .getOrElse(
-        {nameLink} {killLink} {details}
-      )
+    {stageDesc.getOrElse("")} {killLink} {nameLink} {details}
} protected def stageRow(s: StageInfo): Seq[Node] = { - val poolName = listener.stageIdToPool.get(s.stageId) + val stageDataOption = listener.stageIdToData.get(s.stageId) + if (stageDataOption.isEmpty) { + return {s.stageId}No data available for this stage + } + + val stageData = stageDataOption.get val submissionTime = s.submissionTime match { case Some(t) => UIUtils.formatDate(new Date(t)) case None => "Unknown" @@ -124,35 +145,20 @@ private[ui] class StageTableBase( if (finishTime > t) finishTime - t else System.currentTimeMillis - t } val formattedDuration = duration.map(d => UIUtils.formatDuration(d)).getOrElse("Unknown") - val startedTasks = - listener.stageIdToTasksActive.getOrElse(s.stageId, HashMap[Long, TaskInfo]()).size - val completedTasks = listener.stageIdToTasksComplete.getOrElse(s.stageId, 0) - val failedTasks = listener.stageIdToTasksFailed.getOrElse(s.stageId, 0) match { - case f if f > 0 => "(%s failed)".format(f) - case _ => "" - } - val totalTasks = s.numTasks - val inputSortable = listener.stageIdToInputBytes.getOrElse(s.stageId, 0L) - val inputRead = inputSortable match { - case 0 => "" - case b => Utils.bytesToString(b) - } - val shuffleReadSortable = listener.stageIdToShuffleRead.getOrElse(s.stageId, 0L) - val shuffleRead = shuffleReadSortable match { - case 0 => "" - case b => Utils.bytesToString(b) - } - val shuffleWriteSortable = listener.stageIdToShuffleWrite.getOrElse(s.stageId, 0L) - val shuffleWrite = shuffleWriteSortable match { - case 0 => "" - case b => Utils.bytesToString(b) - } + + val inputRead = stageData.inputBytes + val inputReadWithUnit = if (inputRead > 0) Utils.bytesToString(inputRead) else "" + val shuffleRead = stageData.shuffleReadBytes + val shuffleReadWithUnit = if (shuffleRead > 0) Utils.bytesToString(shuffleRead) else "" + val shuffleWrite = stageData.shuffleWriteBytes + val shuffleWriteWithUnit = if (shuffleWrite > 0) Utils.bytesToString(shuffleWrite) else "" + {s.stageId} ++ {if (isFairScheduler) { - {poolName.get} + .format(UIUtils.prependBaseUri(basePath), stageData.schedulingPool)}> + {stageData.schedulingPool} } else { @@ -162,11 +168,12 @@ private[ui] class StageTableBase( {submissionTime} {formattedDuration} - {makeProgressBar(startedTasks, completedTasks, failedTasks, totalTasks)} + {makeProgressBar(stageData.numActiveTasks, stageData.numCompleteTasks, + stageData.numFailedTasks, s.numTasks)} - {inputRead} - {shuffleRead} - {shuffleWrite} + {inputReadWithUnit} + {shuffleReadWithUnit} + {shuffleWriteWithUnit} } /** Render an HTML row that represents a stage */ diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala new file mode 100644 index 0000000000000..be11a11695b01 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.jobs + +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.scheduler.TaskInfo + +import scala.collection.mutable.HashMap + +private[jobs] object UIData { + + class ExecutorSummary { + var taskTime : Long = 0 + var failedTasks : Int = 0 + var succeededTasks : Int = 0 + var inputBytes : Long = 0 + var shuffleRead : Long = 0 + var shuffleWrite : Long = 0 + var memoryBytesSpilled : Long = 0 + var diskBytesSpilled : Long = 0 + } + + class StageUIData { + var numActiveTasks: Int = _ + var numCompleteTasks: Int = _ + var numFailedTasks: Int = _ + + var executorRunTime: Long = _ + + var inputBytes: Long = _ + var shuffleReadBytes: Long = _ + var shuffleWriteBytes: Long = _ + var memoryBytesSpilled: Long = _ + var diskBytesSpilled: Long = _ + + var schedulingPool: String = "" + var description: Option[String] = None + + var taskData = new HashMap[Long, TaskUIData] + var executorSummary = new HashMap[String, ExecutorSummary] + } + + case class TaskUIData( + taskInfo: TaskInfo, + taskMetrics: Option[TaskMetrics] = None, + errorMessage: Option[String] = None) +} diff --git a/core/src/main/scala/org/apache/spark/util/CollectionsUtil.scala b/core/src/main/scala/org/apache/spark/util/CollectionsUtils.scala similarity index 88% rename from core/src/main/scala/org/apache/spark/util/CollectionsUtil.scala rename to core/src/main/scala/org/apache/spark/util/CollectionsUtils.scala index e4c254b9dd6b9..85da2842e8ddb 100644 --- a/core/src/main/scala/org/apache/spark/util/CollectionsUtil.scala +++ b/core/src/main/scala/org/apache/spark/util/CollectionsUtils.scala @@ -19,11 +19,11 @@ package org.apache.spark.util import java.util -import scala.Array -import scala.reflect._ +import scala.reflect.{classTag, ClassTag} private[spark] object CollectionsUtils { def makeBinarySearch[K : Ordering : ClassTag] : (Array[K], K) => Int = { + // For primitive keys, we can use the natural ordering. Otherwise, use the Ordering comparator. 
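For illustration only (not part of the patch), a minimal standalone sketch of the non-primitive branch that follows, assuming a pre-sorted Array[AnyRef] of keys: the implicit Ordering is viewed as a java.util.Comparator so that Arrays.binarySearch can handle key types that do not implement Comparable.

    import java.util.{Arrays, Comparator}

    val keys: Array[AnyRef] = Array("apple", "cherry", "melon")            // already sorted
    val cmp = implicitly[Ordering[String]].asInstanceOf[Comparator[Any]]   // Ordering extends Comparator
    val idx = Arrays.binarySearch(keys, "cherry", cmp)                     // idx == 1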
classTag[K] match { case ClassTag.Float => (l, x) => util.Arrays.binarySearch(l.asInstanceOf[Array[Float]], x.asInstanceOf[Float]) @@ -40,7 +40,8 @@ private[spark] object CollectionsUtils { case ClassTag.Long => (l, x) => util.Arrays.binarySearch(l.asInstanceOf[Array[Long]], x.asInstanceOf[Long]) case _ => - (l, x) => util.Arrays.binarySearch(l.asInstanceOf[Array[AnyRef]], x) + val comparator = implicitly[Ordering[K]].asInstanceOf[java.util.Comparator[Any]] + (l, x) => util.Arrays.binarySearch(l.asInstanceOf[Array[AnyRef]], x, comparator) } } } diff --git a/core/src/main/scala/org/apache/spark/util/FileLogger.scala b/core/src/main/scala/org/apache/spark/util/FileLogger.scala index 6a95dc06e155d..9dcdafdd6350e 100644 --- a/core/src/main/scala/org/apache/spark/util/FileLogger.scala +++ b/core/src/main/scala/org/apache/spark/util/FileLogger.scala @@ -196,6 +196,5 @@ private[spark] class FileLogger( def stop() { hadoopDataStream.foreach(_.close()) writer.foreach(_.close()) - fileSystem.close() } } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 47eb44b530379..bb6079154aafe 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -237,7 +237,6 @@ private[spark] object JsonProtocol { def shuffleReadMetricsToJson(shuffleReadMetrics: ShuffleReadMetrics): JValue = { ("Shuffle Finish Time" -> shuffleReadMetrics.shuffleFinishTime) ~ - ("Total Blocks Fetched" -> shuffleReadMetrics.totalBlocksFetched) ~ ("Remote Blocks Fetched" -> shuffleReadMetrics.remoteBlocksFetched) ~ ("Local Blocks Fetched" -> shuffleReadMetrics.localBlocksFetched) ~ ("Fetch Wait Time" -> shuffleReadMetrics.fetchWaitTime) ~ @@ -258,7 +257,8 @@ private[spark] object JsonProtocol { val reason = Utils.getFormattedClassName(taskEndReason) val json = taskEndReason match { case fetchFailed: FetchFailed => - val blockManagerAddress = blockManagerIdToJson(fetchFailed.bmAddress) + val blockManagerAddress = Option(fetchFailed.bmAddress). 
+ map(blockManagerIdToJson).getOrElse(JNothing) ("Block Manager Address" -> blockManagerAddress) ~ ("Shuffle ID" -> fetchFailed.shuffleId) ~ ("Map ID" -> fetchFailed.mapId) ~ @@ -527,8 +527,9 @@ private[spark] object JsonProtocol { metrics.resultSerializationTime = (json \ "Result Serialization Time").extract[Long] metrics.memoryBytesSpilled = (json \ "Memory Bytes Spilled").extract[Long] metrics.diskBytesSpilled = (json \ "Disk Bytes Spilled").extract[Long] - metrics.shuffleReadMetrics = - Utils.jsonOption(json \ "Shuffle Read Metrics").map(shuffleReadMetricsFromJson) + Utils.jsonOption(json \ "Shuffle Read Metrics").map { shuffleReadMetrics => + metrics.updateShuffleReadMetrics(shuffleReadMetricsFromJson(shuffleReadMetrics)) + } metrics.shuffleWriteMetrics = Utils.jsonOption(json \ "Shuffle Write Metrics").map(shuffleWriteMetricsFromJson) metrics.inputMetrics = @@ -547,7 +548,6 @@ private[spark] object JsonProtocol { def shuffleReadMetricsFromJson(json: JValue): ShuffleReadMetrics = { val metrics = new ShuffleReadMetrics metrics.shuffleFinishTime = (json \ "Shuffle Finish Time").extract[Long] - metrics.totalBlocksFetched = (json \ "Total Blocks Fetched").extract[Int] metrics.remoteBlocksFetched = (json \ "Remote Blocks Fetched").extract[Int] metrics.localBlocksFetched = (json \ "Local Blocks Fetched").extract[Int] metrics.fetchWaitTime = (json \ "Fetch Wait Time").extract[Long] diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index 08465575309c6..bce3b3afe9aba 100644 --- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -180,7 +180,7 @@ private[spark] object SizeEstimator extends Logging { } } - // Estimat the size of arrays larger than ARRAY_SIZE_FOR_SAMPLING by sampling. + // Estimate the size of arrays larger than ARRAY_SIZE_FOR_SAMPLING by sampling. private val ARRAY_SIZE_FOR_SAMPLING = 200 private val ARRAY_SAMPLE_SIZE = 100 // should be lower than ARRAY_SIZE_FOR_SAMPLING diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index a2454e120a8ab..8cbb9050f393b 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -21,7 +21,7 @@ import java.io._ import java.net.{InetAddress, Inet4Address, NetworkInterface, URI, URL, URLConnection} import java.nio.ByteBuffer import java.util.{Locale, Random, UUID} -import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadPoolExecutor} +import java.util.concurrent.{ThreadFactory, ConcurrentHashMap, Executors, ThreadPoolExecutor} import scala.collection.JavaConversions._ import scala.collection.Map @@ -44,7 +44,7 @@ import org.apache.spark.executor.ExecutorUncaughtExceptionHandler import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} /** CallSite represents a place in user code. It can have a short and a long form. */ -private[spark] case class CallSite(val short: String, val long: String) +private[spark] case class CallSite(shortForm: String, longForm: String) /** * Various utility methods used by Spark. @@ -553,19 +553,19 @@ private[spark] object Utils extends Logging { new ThreadFactoryBuilder().setDaemon(true) /** - * Wrapper over newCachedThreadPool. Thread names are formatted as prefix-ID, where ID is a - * unique, sequentially assigned integer. 
+ * Create a thread factory that names threads with a prefix and also sets the threads to daemon. */ - def newDaemonCachedThreadPool(prefix: String): ThreadPoolExecutor = { - val threadFactory = daemonThreadFactoryBuilder.setNameFormat(prefix + "-%d").build() - Executors.newCachedThreadPool(threadFactory).asInstanceOf[ThreadPoolExecutor] + def namedThreadFactory(prefix: String): ThreadFactory = { + daemonThreadFactoryBuilder.setNameFormat(prefix + "-%d").build() } /** - * Return the string to tell how long has passed in milliseconds. + * Wrapper over newCachedThreadPool. Thread names are formatted as prefix-ID, where ID is a + * unique, sequentially assigned integer. */ - def getUsedTimeMs(startTimeMs: Long): String = { - " " + (System.currentTimeMillis - startTimeMs) + " ms" + def newDaemonCachedThreadPool(prefix: String): ThreadPoolExecutor = { + val threadFactory = namedThreadFactory(prefix) + Executors.newCachedThreadPool(threadFactory).asInstanceOf[ThreadPoolExecutor] } /** @@ -573,10 +573,17 @@ private[spark] object Utils extends Logging { * unique, sequentially assigned integer. */ def newDaemonFixedThreadPool(nThreads: Int, prefix: String): ThreadPoolExecutor = { - val threadFactory = daemonThreadFactoryBuilder.setNameFormat(prefix + "-%d").build() + val threadFactory = namedThreadFactory(prefix) Executors.newFixedThreadPool(nThreads, threadFactory).asInstanceOf[ThreadPoolExecutor] } + /** + * Return the string to tell how long has passed in milliseconds. + */ + def getUsedTimeMs(startTimeMs: Long): String = { + " " + (System.currentTimeMillis - startTimeMs) + " ms" + } + private def listFilesSafely(file: File): Seq[File] = { val files = file.listFiles() if (files == null) { @@ -809,7 +816,12 @@ private[spark] object Utils extends Logging { */ def getCallSite: CallSite = { val trace = Thread.currentThread.getStackTrace() - .filterNot(_.getMethodName.contains("getStackTrace")) + .filterNot { ste:StackTraceElement => + // When running under some profilers, the current stack trace might contain some bogus + // frames. This is intended to ensure that we don't crash in these situations by + // ignoring any frames that we can't examine. + (ste == null || ste.getMethodName == null || ste.getMethodName.contains("getStackTrace")) + } // Keep crawling up the stack trace until we find the first function not inside of the spark // package. We track the last (shallowest) contiguous Spark method. This might be an RDD @@ -843,8 +855,8 @@ private[spark] object Utils extends Logging { } val callStackDepth = System.getProperty("spark.callstack.depth", "20").toInt CallSite( - short = "%s at %s:%s".format(lastSparkMethod, firstUserFile, firstUserLine), - long = callStack.take(callStackDepth).mkString("\n")) + shortForm = "%s at %s:%s".format(lastSparkMethod, firstUserFile, firstUserLine), + longForm = callStack.take(callStackDepth).mkString("\n")) } /** Return a string containing part of a file from byte 'start' to 'end'. */ @@ -1286,4 +1298,19 @@ private[spark] object Utils extends Logging { } } + /** Return a nice string representation of the exception, including the stack trace. */ + def exceptionString(e: Exception): String = { + if (e == null) "" else exceptionString(getFormattedClassName(e), e.getMessage, e.getStackTrace) + } + + /** Return a nice string representation of the exception, including the stack trace. 
*/ + def exceptionString( + className: String, + description: String, + stackTrace: Array[StackTraceElement]): String = { + val desc = if (description == null) "" else description + val st = if (stackTrace == null) "" else stackTrace.map(" " + _).mkString("\n") + s"$className: $desc\n$st" + } + } diff --git a/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala index 1a6f1c2b55799..290282c9c2e28 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala @@ -254,26 +254,21 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) * Return an iterator of the map in sorted order. This provides a way to sort the map without * using additional memory, at the expense of destroying the validity of the map. */ - def destructiveSortedIterator(cmp: Comparator[(K, V)]): Iterator[(K, V)] = { + def destructiveSortedIterator(keyComparator: Comparator[K]): Iterator[(K, V)] = { destroyed = true // Pack KV pairs into the front of the underlying array var keyIndex, newIndex = 0 while (keyIndex < capacity) { if (data(2 * keyIndex) != null) { - data(newIndex) = (data(2 * keyIndex), data(2 * keyIndex + 1)) + data(2 * newIndex) = data(2 * keyIndex) + data(2 * newIndex + 1) = data(2 * keyIndex + 1) newIndex += 1 } keyIndex += 1 } assert(curSize == newIndex + (if (haveNullValue) 1 else 0)) - // Sort by the given ordering - val rawOrdering = new Comparator[AnyRef] { - def compare(x: AnyRef, y: AnyRef): Int = { - cmp.compare(x.asInstanceOf[(K, V)], y.asInstanceOf[(K, V)]) - } - } - Arrays.sort(data, 0, newIndex, rawOrdering) + new Sorter(new KVArraySortDataFormat[K, AnyRef]).sort(data, 0, newIndex, keyComparator) new Iterator[(K, V)] { var i = 0 @@ -284,7 +279,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) nullValueReady = false (null.asInstanceOf[K], nullValue) } else { - val item = data(i).asInstanceOf[(K, V)] + val item = (data(2 * i).asInstanceOf[K], data(2 * i + 1).asInstanceOf[V]) i += 1 item } diff --git a/core/src/main/scala/org/apache/spark/util/collection/CompactBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/CompactBuffer.scala new file mode 100644 index 0000000000000..d44e15e3c97ea --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/CompactBuffer.scala @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +/** + * An append-only buffer similar to ArrayBuffer, but more memory-efficient for small buffers. + * ArrayBuffer always allocates an Object array to store the data, with 16 entries by default, + * so it has about 80-100 bytes of overhead. 
In contrast, CompactBuffer can keep up to two + * elements in fields of the main object, and only allocates an Array[AnyRef] if there are more + * entries than that. This makes it more efficient for operations like groupBy where we expect + * some keys to have very few elements. + */ +private[spark] class CompactBuffer[T] extends Seq[T] with Serializable { + // First two elements + private var element0: T = _ + private var element1: T = _ + + // Number of elements, including our two in the main object + private var curSize = 0 + + // Array for extra elements + private var otherElements: Array[AnyRef] = null + + def apply(position: Int): T = { + if (position < 0 || position >= curSize) { + throw new IndexOutOfBoundsException + } + if (position == 0) { + element0 + } else if (position == 1) { + element1 + } else { + otherElements(position - 2).asInstanceOf[T] + } + } + + private def update(position: Int, value: T): Unit = { + if (position < 0 || position >= curSize) { + throw new IndexOutOfBoundsException + } + if (position == 0) { + element0 = value + } else if (position == 1) { + element1 = value + } else { + otherElements(position - 2) = value.asInstanceOf[AnyRef] + } + } + + def += (value: T): CompactBuffer[T] = { + val newIndex = curSize + if (newIndex == 0) { + element0 = value + curSize = 1 + } else if (newIndex == 1) { + element1 = value + curSize = 2 + } else { + growToSize(curSize + 1) + otherElements(newIndex - 2) = value.asInstanceOf[AnyRef] + } + this + } + + def ++= (values: TraversableOnce[T]): CompactBuffer[T] = { + values match { + // Optimize merging of CompactBuffers, used in cogroup and groupByKey + case compactBuf: CompactBuffer[T] => + val oldSize = curSize + // Copy the other buffer's size and elements to local variables in case it is equal to us + val itsSize = compactBuf.curSize + val itsElements = compactBuf.otherElements + growToSize(curSize + itsSize) + if (itsSize == 1) { + this(oldSize) = compactBuf.element0 + } else if (itsSize == 2) { + this(oldSize) = compactBuf.element0 + this(oldSize + 1) = compactBuf.element1 + } else if (itsSize > 2) { + this(oldSize) = compactBuf.element0 + this(oldSize + 1) = compactBuf.element1 + // At this point our size is also above 2, so just copy its array directly into ours. + // Note that since we added two elements above, the index in this.otherElements that we + // should copy to is oldSize. + System.arraycopy(itsElements, 0, otherElements, oldSize, itsSize - 2) + } + + case _ => + values.foreach(e => this += e) + } + this + } + + override def length: Int = curSize + + override def size: Int = curSize + + override def iterator: Iterator[T] = new Iterator[T] { + private var pos = 0 + override def hasNext: Boolean = pos < curSize + override def next(): T = { + if (!hasNext) { + throw new NoSuchElementException + } + pos += 1 + apply(pos - 1) + } + } + + /** Increase our size to newSize and grow the backing array if needed. */ + private def growToSize(newSize: Int): Unit = { + if (newSize < 0) { + throw new UnsupportedOperationException("Can't grow buffer past Int.MaxValue elements") + } + val capacity = if (otherElements != null) otherElements.length + 2 else 2 + if (newSize > capacity) { + var newArrayLen = 8 + while (newSize - 2 > newArrayLen) { + newArrayLen *= 2 + if (newArrayLen == Int.MinValue) { + // Prevent overflow if we double from 2^30 to 2^31, which will become Int.MinValue. + // Note that we set the new array length to Int.MaxValue - 2 so that our capacity + // calculation above still gives a positive integer. 
+ newArrayLen = Int.MaxValue - 2 + } + } + val newArray = new Array[AnyRef](newArrayLen) + if (otherElements != null) { + System.arraycopy(otherElements, 0, newArray, 0, otherElements.length) + } + otherElements = newArray + } + curSize = newSize + } +} + +private[spark] object CompactBuffer { + def apply[T](): CompactBuffer[T] = new CompactBuffer[T] + + def apply[T](value: T): CompactBuffer[T] = { + val buf = new CompactBuffer[T] + buf += value + } +} diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 292d0962f4fdb..6f263c39d1435 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -30,6 +30,7 @@ import org.apache.spark.{Logging, SparkEnv} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.serializer.Serializer import org.apache.spark.storage.{BlockId, BlockManager} +import org.apache.spark.util.collection.ExternalAppendOnlyMap.HashComparator /** * :: DeveloperApi :: @@ -66,8 +67,6 @@ class ExternalAppendOnlyMap[K, V, C]( blockManager: BlockManager = SparkEnv.get.blockManager) extends Iterable[(K, C)] with Serializable with Logging { - import ExternalAppendOnlyMap._ - private var currentMap = new SizeTrackingAppendOnlyMap[K, C] private val spilledMaps = new ArrayBuffer[DiskMapIterator] private val sparkConf = SparkEnv.get.conf @@ -75,7 +74,7 @@ class ExternalAppendOnlyMap[K, V, C]( // Collective memory threshold shared across all running tasks private val maxMemoryThreshold = { - val memoryFraction = sparkConf.getDouble("spark.shuffle.memoryFraction", 0.3) + val memoryFraction = sparkConf.getDouble("spark.shuffle.memoryFraction", 0.2) val safetyFraction = sparkConf.getDouble("spark.shuffle.safetyFraction", 0.8) (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong } @@ -105,48 +104,75 @@ class ExternalAppendOnlyMap[K, V, C]( private var _diskBytesSpilled = 0L private val fileBufferSize = sparkConf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024 - private val comparator = new KCComparator[K, C] + private val keyComparator = new HashComparator[K] private val ser = serializer.newInstance() + private val threadId = Thread.currentThread().getId /** * Insert the given key and value into the map. + */ + def insert(key: K, value: V): Unit = { + insertAll(Iterator((key, value))) + } + + /** + * Insert the given iterator of keys and values into the map. * - * If the underlying map is about to grow, check if the global pool of shuffle memory has + * When the underlying map needs to grow, check if the global pool of shuffle memory has * enough room for this to happen. If so, allocate the memory required to grow the map; * otherwise, spill the in-memory map to disk. * * The shuffle memory usage of the first trackMemoryThreshold entries is not tracked. 
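As a quick aside, a hedged usage sketch of the CompactBuffer class introduced above (illustrative only; the class is private[spark], so this assumes code inside the org.apache.spark package): the first two appended elements live in fields, and the backing Array[AnyRef] is only allocated once a third element arrives.

    import org.apache.spark.util.collection.CompactBuffer

    val buf = CompactBuffer("a")           // one element, no array allocated yet
    buf += "b"                             // still held in the two inline fields
    buf ++= Seq("c", "d")                  // third element triggers the array allocation
    assert(buf.size == 4 && buf(2) == "c")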
*/ - def insert(key: K, value: V) { + def insertAll(entries: Iterator[Product2[K, V]]): Unit = { + // An update function for the map that we reuse across entries to avoid allocating + // a new closure each time + var curEntry: Product2[K, V] = null val update: (Boolean, C) => C = (hadVal, oldVal) => { - if (hadVal) mergeValue(oldVal, value) else createCombiner(value) + if (hadVal) mergeValue(oldVal, curEntry._2) else createCombiner(curEntry._2) } - if (numPairsInMemory > trackMemoryThreshold && currentMap.atGrowThreshold) { - val mapSize = currentMap.estimateSize() - var shouldSpill = false - val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap - - // Atomically check whether there is sufficient memory in the global pool for - // this map to grow and, if possible, allocate the required amount - shuffleMemoryMap.synchronized { - val threadId = Thread.currentThread().getId - val previouslyOccupiedMemory = shuffleMemoryMap.get(threadId) - val availableMemory = maxMemoryThreshold - - (shuffleMemoryMap.values.sum - previouslyOccupiedMemory.getOrElse(0L)) - - // Assume map growth factor is 2x - shouldSpill = availableMemory < mapSize * 2 - if (!shouldSpill) { - shuffleMemoryMap(threadId) = mapSize * 2 + + while (entries.hasNext) { + curEntry = entries.next() + if (numPairsInMemory > trackMemoryThreshold && currentMap.atGrowThreshold) { + val mapSize = currentMap.estimateSize() + var shouldSpill = false + val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap + + // Atomically check whether there is sufficient memory in the global pool for + // this map to grow and, if possible, allocate the required amount + shuffleMemoryMap.synchronized { + val previouslyOccupiedMemory = shuffleMemoryMap.get(threadId) + val availableMemory = maxMemoryThreshold - + (shuffleMemoryMap.values.sum - previouslyOccupiedMemory.getOrElse(0L)) + + // Assume map growth factor is 2x + shouldSpill = availableMemory < mapSize * 2 + if (!shouldSpill) { + shuffleMemoryMap(threadId) = mapSize * 2 + } + } + // Do not synchronize spills + if (shouldSpill) { + spill(mapSize) } } - // Do not synchronize spills - if (shouldSpill) { - spill(mapSize) - } + currentMap.changeValue(curEntry._1, update) + numPairsInMemory += 1 } - currentMap.changeValue(key, update) - numPairsInMemory += 1 + } + + /** + * Insert the given iterable of keys and values into the map. + * + * When the underlying map needs to grow, check if the global pool of shuffle memory has + * enough room for this to happen. If so, allocate the memory required to grow the map; + * otherwise, spill the in-memory map to disk. + * + * The shuffle memory usage of the first trackMemoryThreshold entries is not tracked. 
+ */ + def insertAll(entries: Iterable[Product2[K, V]]): Unit = { + insertAll(entries.iterator) } /** @@ -154,8 +180,8 @@ class ExternalAppendOnlyMap[K, V, C]( */ private def spill(mapSize: Long) { spillCount += 1 - logWarning("Spilling in-memory map of %d MB to disk (%d time%s so far)" - .format(mapSize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else "")) + logWarning("Thread %d spilling in-memory map of %d MB to disk (%d time%s so far)" + .format(threadId, mapSize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else "")) val (blockId, file) = diskBlockManager.createTempBlock() var writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize) var objectsWritten = 0 @@ -173,7 +199,7 @@ class ExternalAppendOnlyMap[K, V, C]( } try { - val it = currentMap.destructiveSortedIterator(comparator) + val it = currentMap.destructiveSortedIterator(keyComparator) while (it.hasNext) { val kv = it.next() writer.write(kv) @@ -231,7 +257,7 @@ class ExternalAppendOnlyMap[K, V, C]( // Input streams are derived both from the in-memory map and spilled maps on disk // The in-memory map is sorted in place, while the spilled maps are already in sorted order - private val sortedMap = currentMap.destructiveSortedIterator(comparator) + private val sortedMap = currentMap.destructiveSortedIterator(keyComparator) private val inputStreams = (Seq(sortedMap) ++ spilledMaps).map(it => it.buffered) inputStreams.foreach { it => @@ -252,7 +278,7 @@ class ExternalAppendOnlyMap[K, V, C]( if (it.hasNext) { var kc = it.next() kcPairs += kc - val minHash = getKeyHashCode(kc) + val minHash = hashKey(kc) while (it.hasNext && it.head._1.hashCode() == minHash) { kc = it.next() kcPairs += kc @@ -268,10 +294,10 @@ class ExternalAppendOnlyMap[K, V, C]( private def mergeIfKeyExists(key: K, baseCombiner: C, buffer: StreamBuffer): C = { var i = 0 while (i < buffer.pairs.length) { - val (k, c) = buffer.pairs(i) - if (k == key) { + val pair = buffer.pairs(i) + if (pair._1 == key) { buffer.pairs.remove(i) - return mergeCombiners(baseCombiner, c) + return mergeCombiners(baseCombiner, pair._2) } i += 1 } @@ -293,10 +319,12 @@ class ExternalAppendOnlyMap[K, V, C]( } // Select a key from the StreamBuffer that holds the lowest key hash val minBuffer = mergeHeap.dequeue() - val (minPairs, minHash) = (minBuffer.pairs, minBuffer.minKeyHash) + val minPairs = minBuffer.pairs + val minHash = minBuffer.minKeyHash val minPair = minPairs.remove(0) - var (minKey, minCombiner) = minPair - assert(getKeyHashCode(minPair) == minHash) + val minKey = minPair._1 + var minCombiner = minPair._2 + assert(hashKey(minPair) == minHash) // For all other streams that may have this key (i.e. have the same minimum key hash), // merge in the corresponding value (if any) from that stream @@ -337,7 +365,7 @@ class ExternalAppendOnlyMap[K, V, C]( // Invalid if there are no more pairs in this stream def minKeyHash: Int = { assert(pairs.length > 0) - getKeyHashCode(pairs.head) + hashKey(pairs.head) } override def compareTo(other: StreamBuffer): Int = { @@ -421,25 +449,27 @@ class ExternalAppendOnlyMap[K, V, C]( file.delete() } } + + /** Convenience function to hash the given (K, C) pair by the key. */ + private def hashKey(kc: (K, C)): Int = ExternalAppendOnlyMap.hash(kc._1) } private[spark] object ExternalAppendOnlyMap { /** - * Return the key hash code of the given (key, combiner) pair. - * If the key is null, return a special hash code. + * Return the hash code of the given object. If the object is null, return a special hash code. 
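For illustration, an assumed usage sketch of the insertAll path shown above (not part of the patch; it requires a live SparkEnv, i.e. a running SparkContext, since the map pulls its configuration and block manager from SparkEnv.get):

    import org.apache.spark.util.collection.ExternalAppendOnlyMap

    // Group integer values per key; the combiner type C is Seq[Int].
    val map = new ExternalAppendOnlyMap[String, Int, Seq[Int]](
      (v: Int) => Seq(v),                         // createCombiner
      (buf: Seq[Int], v: Int) => buf :+ v,        // mergeValue
      (b1: Seq[Int], b2: Seq[Int]) => b1 ++ b2)   // mergeCombiners
    map.insertAll(Iterator(("a", 1), ("b", 2), ("a", 3)))
    val grouped = map.iterator.toMap              // roughly Map("a" -> Seq(1, 3), "b" -> Seq(2))

If the in-memory map outgrows its share of the shuffle memory pool during insertAll, sorted runs are spilled to disk and merged back when the map is iterated.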
*/ - private def getKeyHashCode[K, C](kc: (K, C)): Int = { - if (kc._1 == null) 0 else kc._1.hashCode() + private def hash[T](obj: T): Int = { + if (obj == null) 0 else obj.hashCode() } /** - * A comparator for (key, combiner) pairs based on their key hash codes. + * A comparator which sorts arbitrary keys based on their hash codes. */ - private class KCComparator[K, C] extends Comparator[(K, C)] { - def compare(kc1: (K, C), kc2: (K, C)): Int = { - val hash1 = getKeyHashCode(kc1) - val hash2 = getKeyHashCode(kc2) + private class HashComparator[K] extends Comparator[K] { + def compare(key1: K, key2: K): Int = { + val hash1 = hash(key1) + val hash2 = hash(key2) if (hash1 < hash2) -1 else if (hash1 == hash2) 0 else 1 } } diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala index b84eb65c62bc7..7e76d060d6000 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala @@ -36,7 +36,7 @@ class PrimitiveVector[@specialized(Long, Int, Double) V: ClassTag](initialSize: _array(index) } - def +=(value: V) { + def +=(value: V): Unit = { if (_numElements == _array.length) { resize(_array.length * 2) } @@ -50,6 +50,19 @@ class PrimitiveVector[@specialized(Long, Int, Double) V: ClassTag](initialSize: def size: Int = _numElements + def iterator: Iterator[V] = new Iterator[V] { + var index = 0 + override def hasNext: Boolean = index < _numElements + override def next(): V = { + if (!hasNext) { + throw new NoSuchElementException + } + val value = _array(index) + index += 1 + value + } + } + /** Gets the underlying array backing this vector. */ def array: Array[V] = _array diff --git a/core/src/main/scala/org/apache/spark/util/collection/SizeTracker.scala b/core/src/main/scala/org/apache/spark/util/collection/SizeTracker.scala new file mode 100644 index 0000000000000..3eb1010dc1e8d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/SizeTracker.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import scala.collection.mutable + +import org.apache.spark.util.SizeEstimator + +/** + * A general interface for collections to keep track of their estimated sizes in bytes. + * We sample with a slow exponential back-off using the SizeEstimator to amortize the time, + * as each call to SizeEstimator is somewhat expensive (order of a few milliseconds). + */ +private[spark] trait SizeTracker { + + import SizeTracker._ + + /** + * Controls the base of the exponential which governs the rate of sampling. + * E.g., a value of 2 would mean we sample at 1, 2, 4, 8, ... 
elements. + */ + private val SAMPLE_GROWTH_RATE = 1.1 + + /** Samples taken since last resetSamples(). Only the last two are kept for extrapolation. */ + private val samples = new mutable.Queue[Sample] + + /** The average number of bytes per update between our last two samples. */ + private var bytesPerUpdate: Double = _ + + /** Total number of insertions and updates into the map since the last resetSamples(). */ + private var numUpdates: Long = _ + + /** The value of 'numUpdates' at which we will take our next sample. */ + private var nextSampleNum: Long = _ + + resetSamples() + + /** + * Reset samples collected so far. + * This should be called after the collection undergoes a dramatic change in size. + */ + protected def resetSamples(): Unit = { + numUpdates = 1 + nextSampleNum = 1 + samples.clear() + takeSample() + } + + /** + * Callback to be invoked after every update. + */ + protected def afterUpdate(): Unit = { + numUpdates += 1 + if (nextSampleNum == numUpdates) { + takeSample() + } + } + + /** + * Take a new sample of the current collection's size. + */ + private def takeSample(): Unit = { + samples.enqueue(Sample(SizeEstimator.estimate(this), numUpdates)) + // Only use the last two samples to extrapolate + if (samples.size > 2) { + samples.dequeue() + } + val bytesDelta = samples.toList.reverse match { + case latest :: previous :: tail => + (latest.size - previous.size).toDouble / (latest.numUpdates - previous.numUpdates) + // If fewer than 2 samples, assume no change + case _ => 0 + } + bytesPerUpdate = math.max(0, bytesDelta) + nextSampleNum = math.ceil(numUpdates * SAMPLE_GROWTH_RATE).toLong + } + + /** + * Estimate the current size of the collection in bytes. O(1) time. + */ + def estimateSize(): Long = { + assert(samples.nonEmpty) + val extrapolatedDelta = bytesPerUpdate * (numUpdates - samples.last.numUpdates) + (samples.last.size + extrapolatedDelta).toLong + } +} + +private object SizeTracker { + case class Sample(size: Long, numUpdates: Long) +} diff --git a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala index 204330dad48b9..de61e1d17fe10 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala @@ -17,85 +17,24 @@ package org.apache.spark.util.collection -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.util.SizeEstimator -import org.apache.spark.util.collection.SizeTrackingAppendOnlyMap.Sample - /** - * Append-only map that keeps track of its estimated size in bytes. - * We sample with a slow exponential back-off using the SizeEstimator to amortize the time, - * as each call to SizeEstimator can take a sizable amount of time (order of a few milliseconds). + * An append-only map that keeps track of its estimated size in bytes. */ -private[spark] class SizeTrackingAppendOnlyMap[K, V] extends AppendOnlyMap[K, V] { - - /** - * Controls the base of the exponential which governs the rate of sampling. - * E.g., a value of 2 would mean we sample at 1, 2, 4, 8, ... elements. - */ - private val SAMPLE_GROWTH_RATE = 1.1 - - /** All samples taken since last resetSamples(). Only the last two are used for extrapolation. */ - private val samples = new ArrayBuffer[Sample]() - - /** Total number of insertions and updates into the map since the last resetSamples(). 
*/ - private var numUpdates: Long = _ - - /** The value of 'numUpdates' at which we will take our next sample. */ - private var nextSampleNum: Long = _ - - /** The average number of bytes per update between our last two samples. */ - private var bytesPerUpdate: Double = _ - - resetSamples() - - /** Called after the map grows in size, as this can be a dramatic change for small objects. */ - def resetSamples() { - numUpdates = 1 - nextSampleNum = 1 - samples.clear() - takeSample() - } +private[spark] class SizeTrackingAppendOnlyMap[K, V] extends AppendOnlyMap[K, V] with SizeTracker { override def update(key: K, value: V): Unit = { super.update(key, value) - numUpdates += 1 - if (nextSampleNum == numUpdates) { takeSample() } + super.afterUpdate() } override def changeValue(key: K, updateFunc: (Boolean, V) => V): V = { val newValue = super.changeValue(key, updateFunc) - numUpdates += 1 - if (nextSampleNum == numUpdates) { takeSample() } + super.afterUpdate() newValue } - /** Takes a new sample of the current map's size. */ - def takeSample() { - samples += Sample(SizeEstimator.estimate(this), numUpdates) - // Only use the last two samples to extrapolate. If fewer than 2 samples, assume no change. - bytesPerUpdate = math.max(0, samples.toSeq.reverse match { - case latest :: previous :: tail => - (latest.size - previous.size).toDouble / (latest.numUpdates - previous.numUpdates) - case _ => - 0 - }) - nextSampleNum = math.ceil(numUpdates * SAMPLE_GROWTH_RATE).toLong - } - - override protected def growTable() { + override protected def growTable(): Unit = { super.growTable() resetSamples() } - - /** Estimates the current size of the map in bytes. O(1) time. */ - def estimateSize(): Long = { - assert(samples.nonEmpty) - val extrapolatedDelta = bytesPerUpdate * (numUpdates - samples.last.numUpdates) - (samples.last.size + extrapolatedDelta).toLong - } -} - -private object SizeTrackingAppendOnlyMap { - case class Sample(size: Long, numUpdates: Long) } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorSummary.scala b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingVector.scala similarity index 58% rename from core/src/main/scala/org/apache/spark/ui/jobs/ExecutorSummary.scala rename to core/src/main/scala/org/apache/spark/util/collection/SizeTrackingVector.scala index c4a8996c0b9a9..65a7b4e0d497b 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorSummary.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingVector.scala @@ -15,22 +15,32 @@ * limitations under the License. */ -package org.apache.spark.ui.jobs +package org.apache.spark.util.collection -import org.apache.spark.annotation.DeveloperApi +import scala.reflect.ClassTag /** - * :: DeveloperApi :: - * Class for reporting aggregated metrics for each executor in stage UI. + * An append-only buffer that keeps track of its estimated size in bytes. 
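To make the extrapolation concrete (illustrative numbers only): if the last two samples measured 1,000 bytes after 100 updates and 1,600 bytes after 150 updates, then bytesPerUpdate = (1600 - 1000) / (150 - 100) = 12, and estimateSize() at update 170 returns roughly 1600 + 12 * (170 - 150) = 1840 bytes. A minimal usage sketch of the vector variant (assuming code inside the org.apache.spark package, since these classes are private[spark]):

    val vec = new SizeTrackingVector[Long]
    (1L to 10000L).foreach(vec += _)
    val approxBytes = vec.estimateSize()   // O(1); no full SizeEstimator pass on the hot path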
*/ -@DeveloperApi -class ExecutorSummary { - var taskTime : Long = 0 - var failedTasks : Int = 0 - var succeededTasks : Int = 0 - var inputBytes: Long = 0 - var shuffleRead : Long = 0 - var shuffleWrite : Long = 0 - var memoryBytesSpilled : Long = 0 - var diskBytesSpilled : Long = 0 +private[spark] class SizeTrackingVector[T: ClassTag] + extends PrimitiveVector[T] + with SizeTracker { + + override def +=(value: T): Unit = { + super.+=(value) + super.afterUpdate() + } + + override def resize(newLength: Int): PrimitiveVector[T] = { + super.resize(newLength) + resetSamples() + this + } + + /** + * Return a trimmed version of the underlying array. + */ + def toArray: Array[T] = { + super.iterator.toArray + } } diff --git a/core/src/main/scala/org/apache/spark/util/collection/SortDataFormat.scala b/core/src/main/scala/org/apache/spark/util/collection/SortDataFormat.scala new file mode 100644 index 0000000000000..ac1528969f0be --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/SortDataFormat.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import scala.reflect.ClassTag + +/** + * Abstraction for sorting an arbitrary input buffer of data. This interface requires determining + * the sort key for a given element index, as well as swapping elements and moving data from one + * buffer to another. + * + * Example format: an array of numbers, where each element is also the key. + * See [[KVArraySortDataFormat]] for a more exciting format. + * + * This trait extends Any to ensure it is universal (and thus compiled to a Java interface). + * + * @tparam K Type of the sort key of each element + * @tparam Buffer Internal data structure used by a particular format (e.g., Array[Int]). + */ +// TODO: Making Buffer a real trait would be a better abstraction, but adds some complexity. +private[spark] trait SortDataFormat[K, Buffer] extends Any { + /** Return the sort key for the element at the given index. */ + protected def getKey(data: Buffer, pos: Int): K + + /** Swap two elements. */ + protected def swap(data: Buffer, pos0: Int, pos1: Int): Unit + + /** Copy a single element from src(srcPos) to dst(dstPos). */ + protected def copyElement(src: Buffer, srcPos: Int, dst: Buffer, dstPos: Int): Unit + + /** + * Copy a range of elements starting at src(srcPos) to dst, starting at dstPos. + * Overlapping ranges are allowed. + */ + protected def copyRange(src: Buffer, srcPos: Int, dst: Buffer, dstPos: Int, length: Int): Unit + + /** + * Allocates a Buffer that can hold up to 'length' elements. + * All elements of the buffer should be considered invalid until data is explicitly copied in. 
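For illustration, the trivial format mentioned in the trait's own comment, sketched for a plain Array[Int] whose elements are their own sort keys (assumed code, not part of the patch):

    class IntArraySortDataFormat extends SortDataFormat[Int, Array[Int]] {
      override protected def getKey(data: Array[Int], pos: Int): Int = data(pos)
      override protected def swap(data: Array[Int], pos0: Int, pos1: Int): Unit = {
        val tmp = data(pos0); data(pos0) = data(pos1); data(pos1) = tmp
      }
      override protected def copyElement(src: Array[Int], srcPos: Int, dst: Array[Int], dstPos: Int): Unit = {
        dst(dstPos) = src(srcPos)
      }
      override protected def copyRange(src: Array[Int], srcPos: Int,
          dst: Array[Int], dstPos: Int, length: Int): Unit = {
        System.arraycopy(src, srcPos, dst, dstPos, length)
      }
      override protected def allocate(length: Int): Array[Int] = new Array[Int](length)
    }

    // Assumed usage with the ported Timsort Sorter: sort the whole array by natural Int order.
    val array = Array(3, 1, 2)
    new Sorter(new IntArraySortDataFormat).sort(array, 0, array.length, Ordering.Int)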
+ */ + protected def allocate(length: Int): Buffer +} + +/** + * Supports sorting an array of key-value pairs where the elements of the array alternate between + * keys and values, as used in [[AppendOnlyMap]]. + * + * @tparam K Type of the sort key of each element + * @tparam T Type of the Array we're sorting. Typically this must extend AnyRef, to support cases + * when the keys and values are not the same type. + */ +private[spark] +class KVArraySortDataFormat[K, T <: AnyRef : ClassTag] extends SortDataFormat[K, Array[T]] { + + override protected def getKey(data: Array[T], pos: Int): K = data(2 * pos).asInstanceOf[K] + + override protected def swap(data: Array[T], pos0: Int, pos1: Int) { + val tmpKey = data(2 * pos0) + val tmpVal = data(2 * pos0 + 1) + data(2 * pos0) = data(2 * pos1) + data(2 * pos0 + 1) = data(2 * pos1 + 1) + data(2 * pos1) = tmpKey + data(2 * pos1 + 1) = tmpVal + } + + override protected def copyElement(src: Array[T], srcPos: Int, dst: Array[T], dstPos: Int) { + dst(2 * dstPos) = src(2 * srcPos) + dst(2 * dstPos + 1) = src(2 * srcPos + 1) + } + + override protected def copyRange(src: Array[T], srcPos: Int, + dst: Array[T], dstPos: Int, length: Int) { + System.arraycopy(src, 2 * srcPos, dst, 2 * dstPos, 2 * length) + } + + override protected def allocate(length: Int): Array[T] = { + new Array[T](2 * length) + } +} diff --git a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala index f46ef2d962e98..467fb84dd42e0 100644 --- a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala @@ -17,8 +17,54 @@ package org.apache.spark.util.random +import scala.reflect.ClassTag +import scala.util.Random + private[spark] object SamplingUtils { + /** + * Reservoir sampling implementation that also returns the input size. + * + * @param input input size + * @param k reservoir size + * @param seed random seed + * @return (samples, input size) + */ + def reservoirSampleAndCount[T: ClassTag]( + input: Iterator[T], + k: Int, + seed: Long = Random.nextLong()) + : (Array[T], Int) = { + val reservoir = new Array[T](k) + // Put the first k elements in the reservoir. + var i = 0 + while (i < k && input.hasNext) { + val item = input.next() + reservoir(i) = item + i += 1 + } + + // If we have consumed all the elements, return them. Otherwise do the replacement. + if (i < k) { + // If input size < k, trim the array to return only an array of input size. + val trimReservoir = new Array[T](i) + System.arraycopy(reservoir, 0, trimReservoir, 0, i) + (trimReservoir, i) + } else { + // If input size > k, continue the sampling process. + val rand = new XORShiftRandom(seed) + while (input.hasNext) { + val item = input.next() + val replacementIndex = rand.nextInt(i) + if (replacementIndex < k) { + reservoir(replacementIndex) = item + } + i += 1 + } + (reservoir, i) + } + } + /** * Returns a sampling rate that guarantees a sample of size >= sampleSizeLowerBound 99.99% of * the time. 
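For illustration, assumed usage of the new reservoir-sampling helper from code inside the org.apache.spark package (SamplingUtils is private[spark]):

    import org.apache.spark.util.random.SamplingUtils

    val (sample, count) = SamplingUtils.reservoirSampleAndCount((1 to 100000).iterator, 10, seed = 42L)
    // sample.length == 10 and count == 100000; if the input had fewer than 10 elements,
    // the returned array would be trimmed to the input size instead.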
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index afd575aa6c631..e8bd65f8e4507 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -119,8 +119,7 @@ public void intersection() { JavaRDD intersections = s1.intersection(s2); Assert.assertEquals(3, intersections.count()); - List list = new ArrayList(); - JavaRDD empty = sc.parallelize(list); + JavaRDD empty = sc.emptyRDD(); JavaRDD emptyIntersection = empty.intersection(s2); Assert.assertEquals(0, emptyIntersection.count()); @@ -185,6 +184,12 @@ public void sortByKey() { Assert.assertEquals(new Tuple2(3, 2), sortedPairs.get(2)); } + @Test + public void emptyRDD() { + JavaRDD rdd = sc.emptyRDD(); + Assert.assertEquals("Empty RDD shouldn't have any values", 0, rdd.count()); + } + @Test public void sortBy() { List> pairs = new ArrayList>(); diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala index 7f5d0b061e8b0..9c5f394d3899d 100644 --- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark -import scala.collection.mutable.ArrayBuffer - import org.scalatest.{BeforeAndAfter, FunSuite} import org.scalatest.mock.EasyMockSugar @@ -52,22 +50,21 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar } test("get uncached rdd") { - expecting { - blockManager.get(RDDBlockId(0, 0)).andReturn(None) - blockManager.put(RDDBlockId(0, 0), ArrayBuffer[Any](1, 2, 3, 4), StorageLevel.MEMORY_ONLY, - true).andStubReturn(Seq[(BlockId, BlockStatus)]()) - } - - whenExecuting(blockManager) { - val context = new TaskContext(0, 0, 0) - val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) - assert(value.toList === List(1, 2, 3, 4)) - } + // Do not mock this test, because attempting to match Array[Any], which is not covariant, + // in blockManager.put is a losing battle. You have been warned. 
+ blockManager = sc.env.blockManager + cacheManager = sc.env.cacheManager + val context = new TaskContext(0, 0, 0) + val computeValue = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) + val getValue = blockManager.get(RDDBlockId(rdd.id, split.index)) + assert(computeValue.toList === List(1, 2, 3, 4)) + assert(getValue.isDefined, "Block cached from getOrCompute is not found!") + assert(getValue.get.data.toList === List(1, 2, 3, 4)) } test("get cached rdd") { expecting { - val result = new BlockResult(ArrayBuffer(5, 6, 7).iterator, DataReadMethod.Memory, 12) + val result = new BlockResult(Array(5, 6, 7).iterator, DataReadMethod.Memory, 12) blockManager.get(RDDBlockId(0, 0)).andReturn(Some(result)) } diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index fc00458083a33..d1cb2d9d3a53b 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -156,15 +156,20 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { test("CoGroupedRDD") { val longLineageRDD1 = generateFatPairRDD() + + // Collect the RDD as sequences instead of arrays to enable equality tests in testRDD + val seqCollectFunc = (rdd: RDD[(Int, Array[Iterable[Int]])]) => + rdd.map{case (p, a) => (p, a.toSeq)}.collect(): Any + testRDD(rdd => { CheckpointSuite.cogroup(longLineageRDD1, rdd.map(x => (x % 2, 1)), partitioner) - }) + }, seqCollectFunc) val longLineageRDD2 = generateFatPairRDD() testRDDPartitions(rdd => { CheckpointSuite.cogroup( longLineageRDD2, sc.makeRDD(1 to 2, 2).map(x => (x % 2, 1)), partitioner) - }) + }, seqCollectFunc) } test("ZippedPartitionsRDD") { @@ -235,12 +240,19 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { assert(rdd.partitions.size === 0) } + def defaultCollectFunc[T](rdd: RDD[T]): Any = rdd.collect() + /** * Test checkpointing of the RDD generated by the given operation. It tests whether the * serialized size of the RDD is reduce after checkpointing or not. This function should be called * on all RDDs that have a parent RDD (i.e., do not call on ParallelCollection, BlockRDD, etc.). 
+ * + * @param op an operation to run on the RDD + * @param collectFunc a function for collecting the values in the RDD, in case there are + * non-comparable types like arrays that we want to convert to something that supports == */ - def testRDD[U: ClassTag](op: (RDD[Int]) => RDD[U]) { + def testRDD[U: ClassTag](op: (RDD[Int]) => RDD[U], + collectFunc: RDD[U] => Any = defaultCollectFunc[U] _) { // Generate the final RDD using given RDD operation val baseRDD = generateFatRDD() val operatedRDD = op(baseRDD) @@ -258,13 +270,13 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) operatedRDD.checkpoint() - val result = operatedRDD.collect() + val result = collectFunc(operatedRDD) operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) // Test whether the checkpoint file has been created - assert(sc.checkpointFile[U](operatedRDD.getCheckpointFile.get).collect() === result) + assert(collectFunc(sc.checkpointFile[U](operatedRDD.getCheckpointFile.get)) === result) // Test whether dependencies have been changed from its earlier parent RDD assert(operatedRDD.dependencies.head.rdd != parentRDD) @@ -279,7 +291,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { assert(operatedRDD.partitions.length === numPartitions) // Test whether the data in the checkpointed RDD is same as original - assert(operatedRDD.collect() === result) + assert(collectFunc(operatedRDD) === result) // Test whether serialized size of the RDD has reduced. logInfo("Size of " + rddType + @@ -289,7 +301,6 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { "Size of " + rddType + " did not reduce after checkpointing " + " [" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]" ) - } /** @@ -300,8 +311,12 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { * This function should be called only those RDD whose partitions refer to parent RDD's * partitions (i.e., do not call it on simple RDD like MappedRDD). 
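The collectFunc indirection exists because Scala arrays compare by reference, so a result type such as RDD[(Int, Array[Iterable[Int]])] cannot be checked with === directly; converting to Seq restores structural equality. A two-line illustration (assumed values):

    Array(1, 2, 3) == Array(1, 2, 3)         // false: arrays use reference equality
    Array(1, 2, 3).toSeq == Seq(1, 2, 3)     // true: Seq compares element by element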
* + * @param op an operation to run on the RDD + * @param collectFunc a function for collecting the values in the RDD, in case there are + * non-comparable types like arrays that we want to convert to something that supports == */ - def testRDDPartitions[U: ClassTag](op: (RDD[Int]) => RDD[U]) { + def testRDDPartitions[U: ClassTag](op: (RDD[Int]) => RDD[U], + collectFunc: RDD[U] => Any = defaultCollectFunc[U] _) { // Generate the final RDD using given RDD operation val baseRDD = generateFatRDD() val operatedRDD = op(baseRDD) @@ -316,13 +331,13 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) parentRDDs.foreach(_.checkpoint()) // checkpoint the parent RDD, not the generated one - val result = operatedRDD.collect() // force checkpointing + val result = collectFunc(operatedRDD) // force checkpointing operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) // Test whether the data in the checkpointed RDD is same as original - assert(operatedRDD.collect() === result) + assert(collectFunc(operatedRDD) === result) // Test whether serialized size of the partitions has reduced logInfo("Size of partitions of " + rddType + @@ -436,7 +451,7 @@ object CheckpointSuite { new CoGroupedRDD[K]( Seq(first.asInstanceOf[RDD[(K, _)]], second.asInstanceOf[RDD[(K, _)]]), part - ).asInstanceOf[RDD[(K, Seq[Seq[V]])]] + ).asInstanceOf[RDD[(K, Array[Iterable[V]])]] } } diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala index e755d2e309398..2229e6acc425d 100644 --- a/core/src/test/scala/org/apache/spark/FailureSuite.scala +++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala @@ -104,8 +104,9 @@ class FailureSuite extends FunSuite with LocalSparkContext { results.collect() } assert(thrown.getClass === classOf[SparkException]) - assert(thrown.getMessage.contains("NotSerializableException") || - thrown.getCause.getClass === classOf[NotSerializableException]) + assert(thrown.getMessage.contains("serializable") || + thrown.getCause.getClass === classOf[NotSerializableException], + "Exception does not contain \"serializable\": " + thrown.getMessage) FailureSuiteState.clear() } diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala index 7c30626a0c421..4658a08064280 100644 --- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala +++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala @@ -91,6 +91,17 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMet } } + test("RangePartitioner for keys that are not Comparable (but with Ordering)") { + // Row does not extend Comparable, but has an implicit Ordering defined. 
+ implicit object RowOrdering extends Ordering[Row] { + override def compare(x: Row, y: Row) = x.value - y.value + } + + val rdd = sc.parallelize(1 to 4500).map(x => (Row(x), Row(x))) + val partitioner = new RangePartitioner(1500, rdd) + partitioner.getPartition(Row(100)) + } + test("HashPartitioner not equal to RangePartitioner") { val rdd = sc.parallelize(1 to 10).map(x => (x, x)) val rangeP2 = new RangePartitioner(2, rdd) @@ -177,3 +188,6 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMet // Add other tests here for classes that should be able to handle empty partitions correctly } } + + +private sealed case class Row(value: Int) diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index c4f2f7e34f4d5..eae67c7747e82 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -176,7 +176,9 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext { val data2 = Seq(p(1, "11"), p(1, "12"), p(2, "22"), p(3, "3")) val pairs1: RDD[MutablePair[Int, Int]] = sc.parallelize(data1, 2) val pairs2: RDD[MutablePair[Int, String]] = sc.parallelize(data2, 2) - val results = new CoGroupedRDD[Int](Seq(pairs1, pairs2), new HashPartitioner(2)).collectAsMap() + val results = new CoGroupedRDD[Int](Seq(pairs1, pairs2), new HashPartitioner(2)) + .map(p => (p._1, p._2.map(_.toArray))) + .collectAsMap() assert(results(1)(0).length === 3) assert(results(1)(0).contains(1)) @@ -240,7 +242,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext { } assert(thrown.getClass === classOf[SparkException]) - assert(thrown.getMessage.contains("NotSerializableException")) + assert(thrown.getMessage.toLowerCase.contains("serializable")) } } diff --git a/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala index 1fde4badda949..fb18c3ebfe46f 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala @@ -70,7 +70,7 @@ package object testPackage extends Assertions { def runCallSiteTest(sc: SparkContext) { val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2) val rddCreationSite = rdd.getCreationSite - val curCallSite = sc.getCallSite().short // note: 2 lines after definition of "rdd" + val curCallSite = sc.getCallSite().shortForm // note: 2 lines after definition of "rdd" val rddCreationLine = rddCreationSite match { case CALL_SITE_REGEX(func, file, line) => { diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 565c53e9529ff..f497a5e0a14f0 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -120,6 +120,7 @@ class SparkSubmitSuite extends FunSuite with Matchers { "--archives", "archive1.txt,archive2.txt", "--num-executors", "6", "--name", "beauty", + "--conf", "spark.shuffle.spill=false", "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) @@ -139,6 +140,7 @@ class SparkSubmitSuite extends FunSuite with Matchers { mainClass should be ("org.apache.spark.deploy.yarn.Client") classpath should have length (0) sysProps("spark.app.name") should be ("beauty") + sysProps("spark.shuffle.spill") should be ("false") sysProps("SPARK_SUBMIT") 
should be ("true") } @@ -156,6 +158,7 @@ class SparkSubmitSuite extends FunSuite with Matchers { "--archives", "archive1.txt,archive2.txt", "--num-executors", "6", "--name", "trill", + "--conf", "spark.shuffle.spill=false", "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) @@ -176,6 +179,7 @@ class SparkSubmitSuite extends FunSuite with Matchers { sysProps("spark.yarn.dist.archives") should include regex (".*archive1.txt,.*archive2.txt") sysProps("spark.jars") should include regex (".*one.jar,.*two.jar,.*three.jar,.*thejar.jar") sysProps("SPARK_SUBMIT") should be ("true") + sysProps("spark.shuffle.spill") should be ("false") } test("handles standalone cluster mode") { @@ -186,6 +190,7 @@ class SparkSubmitSuite extends FunSuite with Matchers { "--supervise", "--driver-memory", "4g", "--driver-cores", "5", + "--conf", "spark.shuffle.spill=false", "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) @@ -195,9 +200,10 @@ class SparkSubmitSuite extends FunSuite with Matchers { childArgsStr should include regex ("launch spark://h:p .*thejar.jar org.SomeClass arg1 arg2") mainClass should be ("org.apache.spark.deploy.Client") classpath should have size (0) - sysProps should have size (2) + sysProps should have size (3) sysProps.keys should contain ("spark.jars") sysProps.keys should contain ("SPARK_SUBMIT") + sysProps("spark.shuffle.spill") should be ("false") } test("handles standalone client mode") { @@ -208,6 +214,7 @@ class SparkSubmitSuite extends FunSuite with Matchers { "--total-executor-cores", "5", "--class", "org.SomeClass", "--driver-memory", "4g", + "--conf", "spark.shuffle.spill=false", "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) @@ -218,6 +225,7 @@ class SparkSubmitSuite extends FunSuite with Matchers { classpath(0) should endWith ("thejar.jar") sysProps("spark.executor.memory") should be ("5g") sysProps("spark.cores.max") should be ("5") + sysProps("spark.shuffle.spill") should be ("false") } test("handles mesos client mode") { @@ -228,6 +236,7 @@ class SparkSubmitSuite extends FunSuite with Matchers { "--total-executor-cores", "5", "--class", "org.SomeClass", "--driver-memory", "4g", + "--conf", "spark.shuffle.spill=false", "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) @@ -238,6 +247,7 @@ class SparkSubmitSuite extends FunSuite with Matchers { classpath(0) should endWith ("thejar.jar") sysProps("spark.executor.memory") should be ("5g") sysProps("spark.cores.max") should be ("5") + sysProps("spark.shuffle.spill") should be ("false") } test("launch simple application with spark-submit") { diff --git a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala index 68a0ea36aa545..3f882a724b047 100644 --- a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala +++ b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala @@ -46,7 +46,13 @@ class CompressionCodecSuite extends FunSuite { test("default compression codec") { val codec = CompressionCodec.createCodec(conf) - assert(codec.getClass === classOf[LZFCompressionCodec]) + assert(codec.getClass === classOf[SnappyCompressionCodec]) + testCodec(codec) + } + + test("lz4 compression codec") { + val codec = CompressionCodec.createCodec(conf, classOf[LZ4CompressionCodec].getName) + assert(codec.getClass === classOf[LZ4CompressionCodec]) testCodec(codec) } diff --git 
a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala index 5dd8de319a654..a0483886f8db3 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala @@ -43,7 +43,7 @@ class PartitionwiseSampledRDDSuite extends FunSuite with SharedSparkContext { test("seed distribution") { val rdd = sc.makeRDD(Array(1L, 2L, 3L, 4L), 2) val sampler = new MockSampler - val sample = new PartitionwiseSampledRDD[Long, Long](rdd, sampler, 0L) + val sample = new PartitionwiseSampledRDD[Long, Long](rdd, sampler, false, 0L) assert(sample.distinct().count == 2, "Seeds must be different.") } @@ -52,7 +52,7 @@ class PartitionwiseSampledRDDSuite extends FunSuite with SharedSparkContext { // We want to make sure there are no concurrency issues. val rdd = sc.parallelize(0 until 111, 10) for (sampler <- Seq(new BernoulliSampler[Int](0.5), new PoissonSampler[Int](0.5))) { - val sampled = new PartitionwiseSampledRDD[Int, Int](rdd, sampler) + val sampled = new PartitionwiseSampledRDD[Int, Int](rdd, sampler, true) sampled.zip(sampled).count() } } diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 0f9cbe213ea17..6654ec2d7c656 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -351,6 +351,20 @@ class RDDSuite extends FunSuite with SharedSparkContext { } } + // Test for SPARK-2412 -- ensure that the second pass of the algorithm does not throw an exception + test("coalesced RDDs with locality, fail first pass") { + val initialPartitions = 1000 + val targetLen = 50 + val couponCount = 2 * (math.log(targetLen)*targetLen + targetLen + 0.5).toInt // = 492 + + val blocks = (1 to initialPartitions).map { i => + (i, List(if (i > couponCount) "m2" else "m1")) + } + val data = sc.makeRDD(blocks) + val coalesced = data.coalesce(targetLen) + assert(coalesced.partitions.length == targetLen) + } + test("zipped RDDs") { val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) val zipped = nums.zip(nums.map(_ + 1.0)) @@ -379,6 +393,7 @@ class RDDSuite extends FunSuite with SharedSparkContext { test("mapWith") { import java.util.Random val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2) + @deprecated("suppress compile time deprecation warning", "1.0.0") val randoms = ones.mapWith( (index: Int) => new Random(index + 42)) {(t: Int, prng: Random) => prng.nextDouble * t}.collect() @@ -397,6 +412,7 @@ class RDDSuite extends FunSuite with SharedSparkContext { test("flatMapWith") { import java.util.Random val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2) + @deprecated("suppress compile time deprecation warning", "1.0.0") val randoms = ones.flatMapWith( (index: Int) => new Random(index + 42)) {(t: Int, prng: Random) => @@ -418,6 +434,7 @@ class RDDSuite extends FunSuite with SharedSparkContext { test("filterWith") { import java.util.Random val ints = sc.makeRDD(Array(1, 2, 3, 4, 5, 6), 2) + @deprecated("suppress compile time deprecation warning", "1.0.0") val sample = ints.filterWith( (index: Int) => new Random(index + 42)) {(t: Int, prng: Random) => prng.nextInt(3) == 0}. 
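Editorial note (illustration only, not part of this patch): the SPARK-2412 coalesce test above sizes its input with a coupon-collector style bound, annotated in the diff as "// = 492". A minimal standalone Scala sketch reproducing that arithmetic, so the figure can be checked independently:

    // Sketch of the bound used by the "coalesced RDDs with locality, fail first pass" test.
    // For targetLen = 50 target partitions: 2 * (ln(50)*50 + 50 + 0.5).toInt = 2 * 246 = 492.
    object CouponCountSketch {
      def couponCount(targetLen: Int): Int =
        2 * (math.log(targetLen) * targetLen + targetLen + 0.5).toInt

      def main(args: Array[String]): Unit = {
        println(couponCount(50)) // prints 492, matching the "// = 492" comment in the test
      }
    }
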
@@ -506,6 +523,15 @@ class RDDSuite extends FunSuite with SharedSparkContext { assert(sortedTopK === nums.sorted(ord).take(5)) } + test("sample preserves partitioner") { + val partitioner = new HashPartitioner(2) + val rdd = sc.parallelize(Seq((0, 1), (2, 3))).partitionBy(partitioner) + for (withReplacement <- Seq(true, false)) { + val sampled = rdd.sample(withReplacement, 1.0) + assert(sampled.partitioner === rdd.partitioner) + } + } + test("takeSample") { val n = 1000000 val data = sc.parallelize(1 to n, 2) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 9f498d579a095..9021662bcf712 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -37,6 +37,29 @@ class BuggyDAGEventProcessActor extends Actor { } } +/** + * An RDD for passing to DAGScheduler. These RDDs will use the dependencies and + * preferredLocations (if any) that are passed to them. They are deliberately not executable + * so we can test that DAGScheduler does not try to execute RDDs locally. + */ +class MyRDD( + sc: SparkContext, + numPartitions: Int, + dependencies: List[Dependency[_]], + locations: Seq[Seq[String]] = Nil) extends RDD[(Int, Int)](sc, dependencies) with Serializable { + override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = + throw new RuntimeException("should not be reached") + override def getPartitions = (0 until numPartitions).map(i => new Partition { + override def index = i + }).toArray + override def getPreferredLocations(split: Partition): Seq[String] = + if (locations.isDefinedAt(split.index)) + locations(split.index) + else + Nil + override def toString: String = "DAGSchedulerSuiteRDD " + id +} + class DAGSchedulerSuiteDummyException extends Exception class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with FunSuiteLike @@ -148,34 +171,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F * Type of RDD we use for testing. Note that we should never call the real RDD compute methods. * This is a pair RDD type so it can always be used in ShuffleDependencies. */ - type MyRDD = RDD[(Int, Int)] - - /** - * Create an RDD for passing to DAGScheduler. These RDDs will use the dependencies and - * preferredLocations (if any) that are passed to them. They are deliberately not executable - * so we can test that DAGScheduler does not try to execute RDDs locally. 
- */ - private def makeRdd( - numPartitions: Int, - dependencies: List[Dependency[_]], - locations: Seq[Seq[String]] = Nil - ): MyRDD = { - val maxPartition = numPartitions - 1 - val newRDD = new MyRDD(sc, dependencies) { - override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = - throw new RuntimeException("should not be reached") - override def getPartitions = (0 to maxPartition).map(i => new Partition { - override def index = i - }).toArray - override def getPreferredLocations(split: Partition): Seq[String] = - if (locations.isDefinedAt(split.index)) - locations(split.index) - else - Nil - override def toString: String = "DAGSchedulerSuiteRDD " + id - } - newRDD - } + type PairOfIntsRDD = RDD[(Int, Int)] /** * Process the supplied event as if it were the top of the DAGScheduler event queue, expecting @@ -234,19 +230,19 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F override def taskSucceeded(partition: Int, value: Any) = numResults += 1 override def jobFailed(exception: Exception) = throw exception } - submit(makeRdd(0, Nil), Array(), listener = fakeListener) + submit(new MyRDD(sc, 0, Nil), Array(), listener = fakeListener) assert(numResults === 0) } test("run trivial job") { - submit(makeRdd(1, Nil), Array(0)) + submit(new MyRDD(sc, 1, Nil), Array(0)) complete(taskSets(0), List((Success, 42))) assert(results === Map(0 -> 42)) assertDataStructuresEmpty } test("local job") { - val rdd = new MyRDD(sc, Nil) { + val rdd = new PairOfIntsRDD(sc, Nil) { override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = Array(42 -> 0).iterator override def getPartitions = Array( new Partition { override def index = 0 } ) @@ -260,7 +256,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F } test("local job oom") { - val rdd = new MyRDD(sc, Nil) { + val rdd = new PairOfIntsRDD(sc, Nil) { override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = throw new java.lang.OutOfMemoryError("test local job oom") override def getPartitions = Array( new Partition { override def index = 0 } ) @@ -274,8 +270,8 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F } test("run trivial job w/ dependency") { - val baseRdd = makeRdd(1, Nil) - val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd))) + val baseRdd = new MyRDD(sc, 1, Nil) + val finalRdd = new MyRDD(sc, 1, List(new OneToOneDependency(baseRdd))) submit(finalRdd, Array(0)) complete(taskSets(0), Seq((Success, 42))) assert(results === Map(0 -> 42)) @@ -283,8 +279,8 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F } test("cache location preferences w/ dependency") { - val baseRdd = makeRdd(1, Nil) - val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd))) + val baseRdd = new MyRDD(sc, 1, Nil) + val finalRdd = new MyRDD(sc, 1, List(new OneToOneDependency(baseRdd))) cacheLocations(baseRdd.id -> 0) = Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")) submit(finalRdd, Array(0)) @@ -295,8 +291,22 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F assertDataStructuresEmpty } + test("unserializable task") { + val unserializableRdd = new MyRDD(sc, 1, Nil) { + class UnserializableClass + val unserializable = new UnserializableClass + } + submit(unserializableRdd, Array(0)) + assert(failure.getMessage.startsWith( + "Job aborted due to stage failure: Task not serializable:")) + 
assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + assert(sparkListener.failedStages.contains(0)) + assert(sparkListener.failedStages.size === 1) + assertDataStructuresEmpty + } + test("trivial job failure") { - submit(makeRdd(1, Nil), Array(0)) + submit(new MyRDD(sc, 1, Nil), Array(0)) failed(taskSets(0), "some failure") assert(failure.getMessage === "Job aborted due to stage failure: some failure") assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) @@ -306,7 +316,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F } test("trivial job cancellation") { - val rdd = makeRdd(1, Nil) + val rdd = new MyRDD(sc, 1, Nil) val jobId = submit(rdd, Array(0)) cancel(jobId) assert(failure.getMessage === s"Job $jobId cancelled ") @@ -347,8 +357,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F } dagEventProcessTestActor = TestActorRef[DAGSchedulerEventProcessActor]( Props(classOf[DAGSchedulerEventProcessActor], noKillScheduler))(system) - val rdd = makeRdd(1, Nil) - val jobId = submit(rdd, Array(0)) + val jobId = submit(new MyRDD(sc, 1, Nil), Array(0)) cancel(jobId) // Because the job wasn't actually cancelled, we shouldn't have received a failure message. assert(failure === null) @@ -364,10 +373,10 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F } test("run trivial shuffle") { - val shuffleMapRdd = makeRdd(2, Nil) + val shuffleMapRdd = new MyRDD(sc, 2, Nil) val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) val shuffleId = shuffleDep.shuffleId - val reduceRdd = makeRdd(1, List(shuffleDep)) + val reduceRdd = new MyRDD(sc, 1, List(shuffleDep)) submit(reduceRdd, Array(0)) complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 1)), @@ -380,10 +389,10 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F } test("run trivial shuffle with fetch failure") { - val shuffleMapRdd = makeRdd(2, Nil) + val shuffleMapRdd = new MyRDD(sc, 2, Nil) val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) val shuffleId = shuffleDep.shuffleId - val reduceRdd = makeRdd(2, List(shuffleDep)) + val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) submit(reduceRdd, Array(0, 1)) complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 1)), @@ -406,10 +415,10 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F } test("ignore late map task completions") { - val shuffleMapRdd = makeRdd(2, Nil) + val shuffleMapRdd = new MyRDD(sc, 2, Nil) val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) val shuffleId = shuffleDep.shuffleId - val reduceRdd = makeRdd(2, List(shuffleDep)) + val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) submit(reduceRdd, Array(0, 1)) // pretend we were told hostA went away val oldEpoch = mapOutputTracker.getEpoch @@ -435,9 +444,9 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F } test("run shuffle with map stage failure") { - val shuffleMapRdd = makeRdd(2, Nil) + val shuffleMapRdd = new MyRDD(sc, 2, Nil) val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) - val reduceRdd = makeRdd(2, List(shuffleDep)) + val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) submit(reduceRdd, Array(0, 1)) // Fail the map stage. This should cause the entire job to fail. @@ -472,13 +481,13 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F * without shuffleMapRdd1. 
*/ test("failure of stage used by two jobs") { - val shuffleMapRdd1 = makeRdd(2, Nil) + val shuffleMapRdd1 = new MyRDD(sc, 2, Nil) val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, null) - val shuffleMapRdd2 = makeRdd(2, Nil) + val shuffleMapRdd2 = new MyRDD(sc, 2, Nil) val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, null) - val reduceRdd1 = makeRdd(2, List(shuffleDep1)) - val reduceRdd2 = makeRdd(2, List(shuffleDep1, shuffleDep2)) + val reduceRdd1 = new MyRDD(sc, 2, List(shuffleDep1)) + val reduceRdd2 = new MyRDD(sc, 2, List(shuffleDep1, shuffleDep2)) // We need to make our own listeners for this test, since by default submit uses the same // listener for all jobs, and here we want to capture the failure for each job separately. @@ -511,10 +520,10 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F } test("run trivial shuffle with out-of-band failure and retry") { - val shuffleMapRdd = makeRdd(2, Nil) + val shuffleMapRdd = new MyRDD(sc, 2, Nil) val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) val shuffleId = shuffleDep.shuffleId - val reduceRdd = makeRdd(1, List(shuffleDep)) + val reduceRdd = new MyRDD(sc, 1, List(shuffleDep)) submit(reduceRdd, Array(0)) // blockManagerMaster.removeExecutor("exec-hostA") // pretend we were told hostA went away @@ -534,11 +543,11 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F } test("recursive shuffle failures") { - val shuffleOneRdd = makeRdd(2, Nil) + val shuffleOneRdd = new MyRDD(sc, 2, Nil) val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null) - val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne)) + val shuffleTwoRdd = new MyRDD(sc, 2, List(shuffleDepOne)) val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null) - val finalRdd = makeRdd(1, List(shuffleDepTwo)) + val finalRdd = new MyRDD(sc, 1, List(shuffleDepTwo)) submit(finalRdd, Array(0)) // have the first stage complete normally complete(taskSets(0), Seq( @@ -563,11 +572,11 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F } test("cached post-shuffle") { - val shuffleOneRdd = makeRdd(2, Nil) + val shuffleOneRdd = new MyRDD(sc, 2, Nil) val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null) - val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne)) + val shuffleTwoRdd = new MyRDD(sc, 2, List(shuffleDepOne)) val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null) - val finalRdd = makeRdd(1, List(shuffleDepTwo)) + val finalRdd = new MyRDD(sc, 1, List(shuffleDepTwo)) submit(finalRdd, Array(0)) cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD")) cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC")) @@ -677,15 +686,11 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F BlockManagerId("exec-" + host, host, 12345, 0) private def assertDataStructuresEmpty = { - assert(scheduler.pendingTasks.isEmpty) assert(scheduler.activeJobs.isEmpty) assert(scheduler.failedStages.isEmpty) assert(scheduler.jobIdToActiveJob.isEmpty) assert(scheduler.jobIdToStageIds.isEmpty) - assert(scheduler.stageIdToJobIds.isEmpty) assert(scheduler.stageIdToStage.isEmpty) - assert(scheduler.stageToInfos.isEmpty) - assert(scheduler.resultStageToJob.isEmpty) assert(scheduler.runningStages.isEmpty) assert(scheduler.shuffleToMapStage.isEmpty) assert(scheduler.waitingStages.isEmpty) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala 
b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 71f48e295ecca..3b0b8e2f68c97 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -258,8 +258,8 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers if (stageInfo.rddInfos.exists(_.name == d4.name)) { taskMetrics.shuffleReadMetrics should be ('defined) val sm = taskMetrics.shuffleReadMetrics.get - sm.totalBlocksFetched should be > (0) - sm.localBlocksFetched should be > (0) + sm.totalBlocksFetched should be (128) + sm.localBlocksFetched should be (128) sm.remoteBlocksFetched should be (0) sm.remoteBytesRead should be (0l) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 9ff2a487005c4..c52368b5514db 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -54,6 +54,23 @@ class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler) } } +// Get the rack for a given host +object FakeRackUtil { + private val hostToRack = new mutable.HashMap[String, String]() + + def cleanUp() { + hostToRack.clear() + } + + def assignHostToRack(host: String, rack: String) { + hostToRack(host) = rack + } + + def getRackForHost(host: String) = { + hostToRack.get(host) + } +} + /** * A mock TaskSchedulerImpl implementation that just remembers information about tasks started and * feedback received from the TaskSetManagers. Note that it's important to initialize this with @@ -69,6 +86,9 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex val taskSetsFailed = new ArrayBuffer[String] val executors = new mutable.HashMap[String, String] ++ liveExecutors + for ((execId, host) <- liveExecutors; rack <- getRackForHost(host)) { + hostsByRack.getOrElseUpdate(rack, new mutable.HashSet[String]()) += host + } dagScheduler = new FakeDAGScheduler(sc, this) @@ -82,7 +102,12 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex def addExecutor(execId: String, host: String) { executors.put(execId, host) + for (rack <- getRackForHost(host)) { + hostsByRack.getOrElseUpdate(rack, new mutable.HashSet[String]()) += host + } } + + override def getRackForHost(value: String): Option[String] = FakeRackUtil.getRackForHost(value) } /** @@ -419,6 +444,9 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { } test("new executors get added") { + // Assign host2 to rack2 + FakeRackUtil.cleanUp() + FakeRackUtil.assignHostToRack("host2", "rack2") sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc) val taskSet = FakeTask.createTaskSet(4, @@ -444,8 +472,41 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { manager.executorAdded() // No-pref list now only contains task 3 assert(manager.pendingTasksWithNoPrefs.size === 1) - // Valid locality should contain PROCESS_LOCAL, NODE_LOCAL and ANY - assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, ANY))) + // Valid locality should contain PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL and ANY + assert(manager.myLocalityLevels.sameElements( + Array(PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY))) + FakeRackUtil.cleanUp() + } + + test("test RACK_LOCAL tasks") { + FakeRackUtil.cleanUp() + // 
Assign host1 to rack1 + FakeRackUtil.assignHostToRack("host1", "rack1") + // Assign host2 to rack1 + FakeRackUtil.assignHostToRack("host2", "rack1") + // Assign host3 to rack2 + FakeRackUtil.assignHostToRack("host3", "rack2") + sc = new SparkContext("local", "test") + val sched = new FakeTaskScheduler(sc, + ("execA", "host1"), ("execB", "host2"), ("execC", "host3")) + val taskSet = FakeTask.createTaskSet(2, + Seq(TaskLocation("host1", "execA")), + Seq(TaskLocation("host1", "execA"))) + val clock = new FakeClock + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + + assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY))) + // Set allowed locality to ANY + clock.advance(LOCALITY_WAIT * 3) + // Offer host3 + // No task is scheduled if we restrict locality to RACK_LOCAL + assert(manager.resourceOffer("execC", "host3", RACK_LOCAL) === None) + // Task 0 can be scheduled with ANY + assert(manager.resourceOffer("execC", "host3", ANY).get.index === 0) + // Offer host2 + // Task 1 can be scheduled with RACK_LOCAL + assert(manager.resourceOffer("execB", "host2", RACK_LOCAL).get.index === 1) + FakeRackUtil.cleanUp() } test("do not emit warning when serialized task is small") { diff --git a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala index 5d15a68ac7e4f..aad6599589420 100644 --- a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala @@ -15,15 +15,12 @@ * limitations under the License. */ -package org.apache.spark.serializer; - -import java.io.NotSerializableException +package org.apache.spark.serializer import org.scalatest.FunSuite +import org.apache.spark.{SharedSparkContext, SparkException} import org.apache.spark.rdd.RDD -import org.apache.spark.SparkException -import org.apache.spark.SharedSparkContext /* A trivial (but unserializable) container for trivial functions */ class UnserializableClass { @@ -38,52 +35,50 @@ class ProactiveClosureSerializationSuite extends FunSuite with SharedSparkContex test("throws expected serialization exceptions on actions") { val (data, uc) = fixture - val ex = intercept[SparkException] { - data.map(uc.op(_)).count + data.map(uc.op(_)).count() } - assert(ex.getMessage.contains("Task not serializable")) } // There is probably a cleaner way to eliminate boilerplate here, but we're // iterating over a map from transformation names to functions that perform that // transformation on a given RDD, creating one test case for each - + for (transformation <- - Map("map" -> xmap _, "flatMap" -> xflatMap _, "filter" -> xfilter _, - "mapWith" -> xmapWith _, "mapPartitions" -> xmapPartitions _, + Map("map" -> xmap _, + "flatMap" -> xflatMap _, + "filter" -> xfilter _, + "mapPartitions" -> xmapPartitions _, "mapPartitionsWithIndex" -> xmapPartitionsWithIndex _, - "mapPartitionsWithContext" -> xmapPartitionsWithContext _, - "filterWith" -> xfilterWith _)) { + "mapPartitionsWithContext" -> xmapPartitionsWithContext _)) { val (name, xf) = transformation - + test(s"$name transformations throw proactive serialization exceptions") { val (data, uc) = fixture - val ex = intercept[SparkException] { xf(data, uc) } - assert(ex.getMessage.contains("Task not serializable"), s"RDD.$name doesn't proactively throw NotSerializableException") } } - + private def xmap(x: 
RDD[String], uc: UnserializableClass): RDD[String] = x.map(y=>uc.op(y)) - private def xmapWith(x: RDD[String], uc: UnserializableClass): RDD[String] = - x.mapWith(x => x.toString)((x,y)=>x + uc.op(y)) + private def xflatMap(x: RDD[String], uc: UnserializableClass): RDD[String] = x.flatMap(y=>Seq(uc.op(y))) + private def xfilter(x: RDD[String], uc: UnserializableClass): RDD[String] = x.filter(y=>uc.pred(y)) - private def xfilterWith(x: RDD[String], uc: UnserializableClass): RDD[String] = - x.filterWith(x => x.toString)((x,y)=>uc.pred(y)) + private def xmapPartitions(x: RDD[String], uc: UnserializableClass): RDD[String] = x.mapPartitions(_.map(y=>uc.op(y))) + private def xmapPartitionsWithIndex(x: RDD[String], uc: UnserializableClass): RDD[String] = x.mapPartitionsWithIndex((_, it) => it.map(y=>uc.op(y))) + private def xmapPartitionsWithContext(x: RDD[String], uc: UnserializableClass): RDD[String] = x.mapPartitionsWithContext((_, it) => it.map(y=>uc.op(y))) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 23cb6905bfdeb..dd4fd535d3577 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -31,7 +31,7 @@ import org.scalatest.concurrent.Timeouts._ import org.scalatest.Matchers import org.scalatest.time.SpanSugar._ -import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf, SparkContext} +import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf} import org.apache.spark.executor.DataReadMethod import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} @@ -43,6 +43,7 @@ import scala.language.postfixOps class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter with PrivateMethodTester { + private val conf = new SparkConf(false) var store: BlockManager = null var store2: BlockManager = null @@ -61,21 +62,29 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value) def rdd(rddId: Int, splitId: Int) = RDDBlockId(rddId, splitId) + private def makeBlockManager(maxMem: Long, name: String = ""): BlockManager = { + new BlockManager( + name, actorSystem, master, serializer, maxMem, conf, securityMgr, mapOutputTracker) + } + before { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0, conf = conf, - securityManager = securityMgr) + val (actorSystem, boundPort) = AkkaUtils.createActorSystem( + "test", "localhost", 0, conf = conf, securityManager = securityMgr) this.actorSystem = actorSystem - conf.set("spark.driver.port", boundPort.toString) - - master = new BlockManagerMaster( - actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))), - conf) // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case oldArch = System.setProperty("os.arch", "amd64") conf.set("os.arch", "amd64") conf.set("spark.test.useCompressedOops", "true") conf.set("spark.storage.disableBlockManagerHeartBeat", "true") + conf.set("spark.driver.port", boundPort.toString) + conf.set("spark.storage.unrollFraction", "0.4") + conf.set("spark.storage.unrollMemoryThreshold", "512") + + master = new BlockManagerMaster( + actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))), + conf) + val initialize = 
PrivateMethod[Unit]('initialize) SizeEstimator invokePrivate initialize() } @@ -138,11 +147,10 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter } test("master + 1 manager interaction") { - store = new BlockManager("", actorSystem, master, serializer, 2000, conf, - securityMgr, mapOutputTracker) - val a1 = new Array[Byte](400) - val a2 = new Array[Byte](400) - val a3 = new Array[Byte](400) + store = makeBlockManager(20000) + val a1 = new Array[Byte](4000) + val a2 = new Array[Byte](4000) + val a3 = new Array[Byte](4000) // Putting a1, a2 and a3 in memory and telling master only about a1 and a2 store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) @@ -169,10 +177,8 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter } test("master + 2 managers interaction") { - store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf, - securityMgr, mapOutputTracker) - store2 = new BlockManager("exec2", actorSystem, master, new KryoSerializer(conf), 2000, conf, - securityMgr, mapOutputTracker) + store = makeBlockManager(2000, "exec1") + store2 = makeBlockManager(2000, "exec2") val peers = master.getPeers(store.blockManagerId, 1) assert(peers.size === 1, "master did not return the other manager as a peer") @@ -187,11 +193,10 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter } test("removing block") { - store = new BlockManager("", actorSystem, master, serializer, 2000, conf, - securityMgr, mapOutputTracker) - val a1 = new Array[Byte](400) - val a2 = new Array[Byte](400) - val a3 = new Array[Byte](400) + store = makeBlockManager(20000) + val a1 = new Array[Byte](4000) + val a2 = new Array[Byte](4000) + val a3 = new Array[Byte](4000) // Putting a1, a2 and a3 in memory and telling master only about a1 and a2 store.putSingle("a1-to-remove", a1, StorageLevel.MEMORY_ONLY) @@ -200,8 +205,8 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter // Checking whether blocks are in memory and memory size val memStatus = master.getMemoryStatus.head._2 - assert(memStatus._1 == 2000L, "total memory " + memStatus._1 + " should equal 2000") - assert(memStatus._2 <= 1200L, "remaining memory " + memStatus._2 + " should <= 1200") + assert(memStatus._1 == 20000L, "total memory " + memStatus._1 + " should equal 20000") + assert(memStatus._2 <= 12000L, "remaining memory " + memStatus._2 + " should <= 12000") assert(store.getSingle("a1-to-remove").isDefined, "a1 was not in store") assert(store.getSingle("a2-to-remove").isDefined, "a2 was not in store") assert(store.getSingle("a3-to-remove").isDefined, "a3 was not in store") @@ -230,17 +235,16 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter } eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { val memStatus = master.getMemoryStatus.head._2 - memStatus._1 should equal (2000L) - memStatus._2 should equal (2000L) + memStatus._1 should equal (20000L) + memStatus._2 should equal (20000L) } } test("removing rdd") { - store = new BlockManager("", actorSystem, master, serializer, 2000, conf, - securityMgr, mapOutputTracker) - val a1 = new Array[Byte](400) - val a2 = new Array[Byte](400) - val a3 = new Array[Byte](400) + store = makeBlockManager(20000) + val a1 = new Array[Byte](4000) + val a2 = new Array[Byte](4000) + val a3 = new Array[Byte](4000) // Putting a1, a2 and a3 in memory. 
store.putSingle(rdd(0, 0), a1, StorageLevel.MEMORY_ONLY) store.putSingle(rdd(0, 1), a2, StorageLevel.MEMORY_ONLY) @@ -270,11 +274,9 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter } test("removing broadcast") { - store = new BlockManager("", actorSystem, master, serializer, 2000, conf, - securityMgr, mapOutputTracker) + store = makeBlockManager(2000) val driverStore = store - val executorStore = new BlockManager("executor", actorSystem, master, serializer, 2000, conf, - securityMgr, mapOutputTracker) + val executorStore = makeBlockManager(2000, "executor") val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -343,8 +345,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter test("reregistration on heart beat") { val heartBeat = PrivateMethod[Unit]('heartBeat) - store = new BlockManager("", actorSystem, master, serializer, 2000, conf, - securityMgr, mapOutputTracker) + store = makeBlockManager(2000) val a1 = new Array[Byte](400) store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) @@ -380,8 +381,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter test("reregistration doesn't dead lock") { val heartBeat = PrivateMethod[Unit]('heartBeat) - store = new BlockManager("", actorSystem, master, serializer, 2000, conf, - securityMgr, mapOutputTracker) + store = makeBlockManager(2000) val a1 = new Array[Byte](400) val a2 = List(new Array[Byte](400)) @@ -390,7 +390,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter master.removeExecutor(store.blockManagerId.executorId) val t1 = new Thread { override def run() { - store.put("a2", a2.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.putIterator("a2", a2.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) } } val t2 = new Thread { @@ -418,19 +418,14 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter } test("correct BlockResult returned from get() calls") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr, - mapOutputTracker) - val list1 = List(new Array[Byte](200), new Array[Byte](200)) - val list1ForSizeEstimate = new ArrayBuffer[Any] - list1ForSizeEstimate ++= list1.iterator - val list1SizeEstimate = SizeEstimator.estimate(list1ForSizeEstimate) - val list2 = List(new Array[Byte](50), new Array[Byte](100), new Array[Byte](150)) - val list2ForSizeEstimate = new ArrayBuffer[Any] - list2ForSizeEstimate ++= list2.iterator - val list2SizeEstimate = SizeEstimator.estimate(list2ForSizeEstimate) - store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) - store.put("list2memory", list2.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) - store.put("list2disk", list2.iterator, StorageLevel.DISK_ONLY, tellMaster = true) + store = makeBlockManager(12000) + val list1 = List(new Array[Byte](2000), new Array[Byte](2000)) + val list2 = List(new Array[Byte](500), new Array[Byte](1000), new Array[Byte](1500)) + val list1SizeEstimate = SizeEstimator.estimate(list1.iterator.toArray) + val list2SizeEstimate = SizeEstimator.estimate(list2.iterator.toArray) + store.putIterator("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.putIterator("list2memory", list2.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.putIterator("list2disk", list2.iterator, StorageLevel.DISK_ONLY, tellMaster = true) val list1Get = store.get("list1") assert(list1Get.isDefined, "list1 expected to be in 
store") assert(list1Get.get.data.size === 2) @@ -451,11 +446,10 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter } test("in-memory LRU storage") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, - securityMgr, mapOutputTracker) - val a1 = new Array[Byte](400) - val a2 = new Array[Byte](400) - val a3 = new Array[Byte](400) + store = makeBlockManager(12000) + val a1 = new Array[Byte](4000) + val a2 = new Array[Byte](4000) + val a3 = new Array[Byte](4000) store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY) store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY) @@ -471,11 +465,10 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter } test("in-memory LRU storage with serialization") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, - securityMgr, mapOutputTracker) - val a1 = new Array[Byte](400) - val a2 = new Array[Byte](400) - val a3 = new Array[Byte](400) + store = makeBlockManager(12000) + val a1 = new Array[Byte](4000) + val a2 = new Array[Byte](4000) + val a3 = new Array[Byte](4000) store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_SER) store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY_SER) store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY_SER) @@ -491,11 +484,10 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter } test("in-memory LRU for partitions of same RDD") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, - securityMgr, mapOutputTracker) - val a1 = new Array[Byte](400) - val a2 = new Array[Byte](400) - val a3 = new Array[Byte](400) + store = makeBlockManager(12000) + val a1 = new Array[Byte](4000) + val a2 = new Array[Byte](4000) + val a3 = new Array[Byte](4000) store.putSingle(rdd(0, 1), a1, StorageLevel.MEMORY_ONLY) store.putSingle(rdd(0, 2), a2, StorageLevel.MEMORY_ONLY) store.putSingle(rdd(0, 3), a3, StorageLevel.MEMORY_ONLY) @@ -511,11 +503,10 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter } test("in-memory LRU for partitions of multiple RDDs") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, - securityMgr, mapOutputTracker) - store.putSingle(rdd(0, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY) - store.putSingle(rdd(0, 2), new Array[Byte](400), StorageLevel.MEMORY_ONLY) - store.putSingle(rdd(1, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY) + store = makeBlockManager(12000) + store.putSingle(rdd(0, 1), new Array[Byte](4000), StorageLevel.MEMORY_ONLY) + store.putSingle(rdd(0, 2), new Array[Byte](4000), StorageLevel.MEMORY_ONLY) + store.putSingle(rdd(1, 1), new Array[Byte](4000), StorageLevel.MEMORY_ONLY) // At this point rdd_1_1 should've replaced rdd_0_1 assert(store.memoryStore.contains(rdd(1, 1)), "rdd_1_1 was not in store") assert(!store.memoryStore.contains(rdd(0, 1)), "rdd_0_1 was in store") @@ -523,8 +514,8 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter // Do a get() on rdd_0_2 so that it is the most recently used item assert(store.getSingle(rdd(0, 2)).isDefined, "rdd_0_2 was not in store") // Put in more partitions from RDD 0; they should replace rdd_1_1 - store.putSingle(rdd(0, 3), new Array[Byte](400), StorageLevel.MEMORY_ONLY) - store.putSingle(rdd(0, 4), new Array[Byte](400), StorageLevel.MEMORY_ONLY) + store.putSingle(rdd(0, 3), new Array[Byte](4000), StorageLevel.MEMORY_ONLY) + store.putSingle(rdd(0, 4), new Array[Byte](4000), 
StorageLevel.MEMORY_ONLY) // Now rdd_1_1 should be dropped to add rdd_0_3, but then rdd_0_2 should *not* be dropped // when we try to add rdd_0_4. assert(!store.memoryStore.contains(rdd(1, 1)), "rdd_1_1 was in store") @@ -538,8 +529,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter // TODO Make the spark.test.tachyon.enable true after using tachyon 0.5.0 testing jar. val tachyonUnitTestEnabled = conf.getBoolean("spark.test.tachyon.enable", false) if (tachyonUnitTestEnabled) { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, - securityMgr, mapOutputTracker) + store = makeBlockManager(1200) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -555,8 +545,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter } test("on-disk storage") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, - securityMgr, mapOutputTracker) + store = makeBlockManager(1200) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -569,11 +558,10 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter } test("disk and memory storage") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, - securityMgr, mapOutputTracker) - val a1 = new Array[Byte](400) - val a2 = new Array[Byte](400) - val a3 = new Array[Byte](400) + store = makeBlockManager(12000) + val a1 = new Array[Byte](4000) + val a2 = new Array[Byte](4000) + val a3 = new Array[Byte](4000) store.putSingle("a1", a1, StorageLevel.MEMORY_AND_DISK) store.putSingle("a2", a2, StorageLevel.MEMORY_AND_DISK) store.putSingle("a3", a3, StorageLevel.MEMORY_AND_DISK) @@ -585,11 +573,10 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter } test("disk and memory storage with getLocalBytes") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, - securityMgr, mapOutputTracker) - val a1 = new Array[Byte](400) - val a2 = new Array[Byte](400) - val a3 = new Array[Byte](400) + store = makeBlockManager(12000) + val a1 = new Array[Byte](4000) + val a2 = new Array[Byte](4000) + val a3 = new Array[Byte](4000) store.putSingle("a1", a1, StorageLevel.MEMORY_AND_DISK) store.putSingle("a2", a2, StorageLevel.MEMORY_AND_DISK) store.putSingle("a3", a3, StorageLevel.MEMORY_AND_DISK) @@ -601,11 +588,10 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter } test("disk and memory storage with serialization") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, - securityMgr, mapOutputTracker) - val a1 = new Array[Byte](400) - val a2 = new Array[Byte](400) - val a3 = new Array[Byte](400) + store = makeBlockManager(12000) + val a1 = new Array[Byte](4000) + val a2 = new Array[Byte](4000) + val a3 = new Array[Byte](4000) store.putSingle("a1", a1, StorageLevel.MEMORY_AND_DISK_SER) store.putSingle("a2", a2, StorageLevel.MEMORY_AND_DISK_SER) store.putSingle("a3", a3, StorageLevel.MEMORY_AND_DISK_SER) @@ -617,11 +603,10 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter } test("disk and memory storage with serialization and getLocalBytes") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, - securityMgr, mapOutputTracker) - val a1 = new Array[Byte](400) - val a2 = new Array[Byte](400) - val a3 = new Array[Byte](400) + store = makeBlockManager(12000) + val a1 = new Array[Byte](4000) + val a2 = new Array[Byte](4000) + val a3 = new 
Array[Byte](4000) store.putSingle("a1", a1, StorageLevel.MEMORY_AND_DISK_SER) store.putSingle("a2", a2, StorageLevel.MEMORY_AND_DISK_SER) store.putSingle("a3", a3, StorageLevel.MEMORY_AND_DISK_SER) @@ -633,12 +618,11 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter } test("LRU with mixed storage levels") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, - securityMgr, mapOutputTracker) - val a1 = new Array[Byte](400) - val a2 = new Array[Byte](400) - val a3 = new Array[Byte](400) - val a4 = new Array[Byte](400) + store = makeBlockManager(12000) + val a1 = new Array[Byte](4000) + val a2 = new Array[Byte](4000) + val a3 = new Array[Byte](4000) + val a4 = new Array[Byte](4000) // First store a1 and a2, both in memory, and a3, on disk only store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_SER) store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY_SER) @@ -656,14 +640,13 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter } test("in-memory LRU with streams") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, - securityMgr, mapOutputTracker) - val list1 = List(new Array[Byte](200), new Array[Byte](200)) - val list2 = List(new Array[Byte](200), new Array[Byte](200)) - val list3 = List(new Array[Byte](200), new Array[Byte](200)) - store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) - store.put("list2", list2.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) - store.put("list3", list3.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store = makeBlockManager(12000) + val list1 = List(new Array[Byte](2000), new Array[Byte](2000)) + val list2 = List(new Array[Byte](2000), new Array[Byte](2000)) + val list3 = List(new Array[Byte](2000), new Array[Byte](2000)) + store.putIterator("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.putIterator("list2", list2.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.putIterator("list3", list3.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) assert(store.get("list2").isDefined, "list2 was not in store") assert(store.get("list2").get.data.size === 2) assert(store.get("list3").isDefined, "list3 was not in store") @@ -672,7 +655,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter assert(store.get("list2").isDefined, "list2 was not in store") assert(store.get("list2").get.data.size === 2) // At this point list2 was gotten last, so LRU will getSingle rid of list3 - store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.putIterator("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) assert(store.get("list1").isDefined, "list1 was not in store") assert(store.get("list1").get.data.size === 2) assert(store.get("list2").isDefined, "list2 was not in store") @@ -681,16 +664,15 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter } test("LRU with mixed storage levels and streams") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, - securityMgr, mapOutputTracker) - val list1 = List(new Array[Byte](200), new Array[Byte](200)) - val list2 = List(new Array[Byte](200), new Array[Byte](200)) - val list3 = List(new Array[Byte](200), new Array[Byte](200)) - val list4 = List(new Array[Byte](200), new Array[Byte](200)) + store = makeBlockManager(12000) + val list1 = List(new Array[Byte](2000), new Array[Byte](2000)) + val list2 = List(new Array[Byte](2000), new 
Array[Byte](2000)) + val list3 = List(new Array[Byte](2000), new Array[Byte](2000)) + val list4 = List(new Array[Byte](2000), new Array[Byte](2000)) // First store list1 and list2, both in memory, and list3, on disk only - store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY_SER, tellMaster = true) - store.put("list2", list2.iterator, StorageLevel.MEMORY_ONLY_SER, tellMaster = true) - store.put("list3", list3.iterator, StorageLevel.DISK_ONLY, tellMaster = true) + store.putIterator("list1", list1.iterator, StorageLevel.MEMORY_ONLY_SER, tellMaster = true) + store.putIterator("list2", list2.iterator, StorageLevel.MEMORY_ONLY_SER, tellMaster = true) + store.putIterator("list3", list3.iterator, StorageLevel.DISK_ONLY, tellMaster = true) val listForSizeEstimate = new ArrayBuffer[Any] listForSizeEstimate ++= list1.iterator val listSize = SizeEstimator.estimate(listForSizeEstimate) @@ -708,7 +690,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter assert(store.get("list3").isDefined, "list3 was not in store") assert(store.get("list3").get.data.size === 2) // Now let's add in list4, which uses both disk and memory; list1 should drop out - store.put("list4", list4.iterator, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true) + store.putIterator("list4", list4.iterator, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true) assert(store.get("list1") === None, "list1 was in store") assert(store.get("list2").isDefined, "list2 was not in store") assert(store.get("list2").get.data.size === 2) @@ -731,11 +713,10 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter } test("overly large block") { - store = new BlockManager("", actorSystem, master, serializer, 500, conf, - securityMgr, mapOutputTracker) - store.putSingle("a1", new Array[Byte](1000), StorageLevel.MEMORY_ONLY) + store = makeBlockManager(5000) + store.putSingle("a1", new Array[Byte](10000), StorageLevel.MEMORY_ONLY) assert(store.getSingle("a1") === None, "a1 was in store") - store.putSingle("a2", new Array[Byte](1000), StorageLevel.MEMORY_AND_DISK) + store.putSingle("a2", new Array[Byte](10000), StorageLevel.MEMORY_AND_DISK) assert(store.memoryStore.getValues("a2") === None, "a2 was in memory store") assert(store.getSingle("a2").isDefined, "a2 was not in store") } @@ -743,8 +724,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter test("block compression") { try { conf.set("spark.shuffle.compress", "true") - store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf, - securityMgr, mapOutputTracker) + store = makeBlockManager(20000, "exec1") store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) <= 100, "shuffle_0_0_0 was not compressed") @@ -752,52 +732,46 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter store = null conf.set("spark.shuffle.compress", "false") - store = new BlockManager("exec2", actorSystem, master, serializer, 2000, conf, - securityMgr, mapOutputTracker) - store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) - assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) >= 1000, + store = makeBlockManager(20000, "exec2") + store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](10000), StorageLevel.MEMORY_ONLY_SER) + assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) >= 10000, "shuffle_0_0_0 was compressed") store.stop() store = null 
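Editorial note (illustration only, not part of this patch): the compression assertions around this point toggle the per-block-type flags (spark.shuffle.compress, spark.broadcast.compress, spark.rdd.compress), and the CompressionCodecSuite changes earlier in this diff select a codec via CompressionCodec.createCodec. A minimal sketch of that configuration pattern, assuming Spark core on the classpath and using only the calls that appear in these suites:

    import org.apache.spark.SparkConf
    import org.apache.spark.io.{CompressionCodec, LZ4CompressionCodec}

    // Sketch: choose the LZ4 codec explicitly and toggle the shuffle/broadcast
    // compression flags exercised by the surrounding tests.
    object CodecConfigSketch {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf(false)
          .set("spark.shuffle.compress", "false")
          .set("spark.broadcast.compress", "true")
        val codec = CompressionCodec.createCodec(conf, classOf[LZ4CompressionCodec].getName)
        println(codec.getClass.getSimpleName) // LZ4CompressionCodec
      }
    }
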
conf.set("spark.broadcast.compress", "true") - store = new BlockManager("exec3", actorSystem, master, serializer, 2000, conf, - securityMgr, mapOutputTracker) - store.putSingle(BroadcastBlockId(0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) - assert(store.memoryStore.getSize(BroadcastBlockId(0)) <= 100, + store = makeBlockManager(20000, "exec3") + store.putSingle(BroadcastBlockId(0), new Array[Byte](10000), StorageLevel.MEMORY_ONLY_SER) + assert(store.memoryStore.getSize(BroadcastBlockId(0)) <= 1000, "broadcast_0 was not compressed") store.stop() store = null conf.set("spark.broadcast.compress", "false") - store = new BlockManager("exec4", actorSystem, master, serializer, 2000, conf, - securityMgr, mapOutputTracker) - store.putSingle(BroadcastBlockId(0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) - assert(store.memoryStore.getSize(BroadcastBlockId(0)) >= 1000, "broadcast_0 was compressed") + store = makeBlockManager(20000, "exec4") + store.putSingle(BroadcastBlockId(0), new Array[Byte](10000), StorageLevel.MEMORY_ONLY_SER) + assert(store.memoryStore.getSize(BroadcastBlockId(0)) >= 10000, "broadcast_0 was compressed") store.stop() store = null conf.set("spark.rdd.compress", "true") - store = new BlockManager("exec5", actorSystem, master, serializer, 2000, conf, - securityMgr, mapOutputTracker) - store.putSingle(rdd(0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) - assert(store.memoryStore.getSize(rdd(0, 0)) <= 100, "rdd_0_0 was not compressed") + store = makeBlockManager(20000, "exec5") + store.putSingle(rdd(0, 0), new Array[Byte](10000), StorageLevel.MEMORY_ONLY_SER) + assert(store.memoryStore.getSize(rdd(0, 0)) <= 1000, "rdd_0_0 was not compressed") store.stop() store = null conf.set("spark.rdd.compress", "false") - store = new BlockManager("exec6", actorSystem, master, serializer, 2000, conf, - securityMgr, mapOutputTracker) - store.putSingle(rdd(0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) - assert(store.memoryStore.getSize(rdd(0, 0)) >= 1000, "rdd_0_0 was compressed") + store = makeBlockManager(20000, "exec6") + store.putSingle(rdd(0, 0), new Array[Byte](10000), StorageLevel.MEMORY_ONLY_SER) + assert(store.memoryStore.getSize(rdd(0, 0)) >= 10000, "rdd_0_0 was compressed") store.stop() store = null // Check that any other block types are also kept uncompressed - store = new BlockManager("exec7", actorSystem, master, serializer, 2000, conf, - securityMgr, mapOutputTracker) - store.putSingle("other_block", new Array[Byte](1000), StorageLevel.MEMORY_ONLY) - assert(store.memoryStore.getSize("other_block") >= 1000, "other_block was compressed") + store = makeBlockManager(20000, "exec7") + store.putSingle("other_block", new Array[Byte](10000), StorageLevel.MEMORY_ONLY) + assert(store.memoryStore.getSize("other_block") >= 10000, "other_block was compressed") store.stop() store = null } finally { @@ -871,30 +845,29 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter assert(Arrays.equals(mappedAsArray, bytes)) assert(Arrays.equals(notMappedAsArray, bytes)) } - + test("updated block statuses") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, - securityMgr, mapOutputTracker) - val list = List.fill(2)(new Array[Byte](200)) - val bigList = List.fill(8)(new Array[Byte](200)) + store = makeBlockManager(12000) + val list = List.fill(2)(new Array[Byte](2000)) + val bigList = List.fill(8)(new Array[Byte](2000)) // 1 updated block (i.e. 
list1) val updatedBlocks1 = - store.put("list1", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.putIterator("list1", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) assert(updatedBlocks1.size === 1) assert(updatedBlocks1.head._1 === TestBlockId("list1")) assert(updatedBlocks1.head._2.storageLevel === StorageLevel.MEMORY_ONLY) // 1 updated block (i.e. list2) val updatedBlocks2 = - store.put("list2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) + store.putIterator("list2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) assert(updatedBlocks2.size === 1) assert(updatedBlocks2.head._1 === TestBlockId("list2")) assert(updatedBlocks2.head._2.storageLevel === StorageLevel.MEMORY_ONLY) // 2 updated blocks - list1 is kicked out of memory while list3 is added val updatedBlocks3 = - store.put("list3", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.putIterator("list3", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) assert(updatedBlocks3.size === 2) updatedBlocks3.foreach { case (id, status) => id match { @@ -903,11 +876,11 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter case _ => fail("Updated block is neither list1 nor list3") } } - assert(store.get("list3").isDefined, "list3 was not in store") + assert(store.memoryStore.contains("list3"), "list3 was not in memory store") // 2 updated blocks - list2 is kicked out of memory (but put on disk) while list4 is added val updatedBlocks4 = - store.put("list4", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.putIterator("list4", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) assert(updatedBlocks4.size === 2) updatedBlocks4.foreach { case (id, status) => id match { @@ -916,26 +889,37 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter case _ => fail("Updated block is neither list2 nor list4") } } - assert(store.get("list4").isDefined, "list4 was not in store") + assert(store.diskStore.contains("list2"), "list2 was not in disk store") + assert(store.memoryStore.contains("list4"), "list4 was not in memory store") - // No updated blocks - nothing is kicked out of memory because list5 is too big to be added + // No updated blocks - list5 is too big to fit in store and nothing is kicked out val updatedBlocks5 = - store.put("list5", bigList.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.putIterator("list5", bigList.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) assert(updatedBlocks5.size === 0) - assert(store.get("list2").isDefined, "list2 was not in store") - assert(store.get("list4").isDefined, "list4 was not in store") - assert(!store.get("list5").isDefined, "list5 was in store") + + // memory store contains only list3 and list4 + assert(!store.memoryStore.contains("list1"), "list1 was in memory store") + assert(!store.memoryStore.contains("list2"), "list2 was in memory store") + assert(store.memoryStore.contains("list3"), "list3 was not in memory store") + assert(store.memoryStore.contains("list4"), "list4 was not in memory store") + assert(!store.memoryStore.contains("list5"), "list5 was in memory store") + + // disk store contains only list2 + assert(!store.diskStore.contains("list1"), "list1 was in disk store") + assert(store.diskStore.contains("list2"), "list2 was not in disk store") + assert(!store.diskStore.contains("list3"), "list3 was in disk store") + assert(!store.diskStore.contains("list4"), "list4 was in disk store") + 
assert(!store.diskStore.contains("list5"), "list5 was in disk store") } test("query block statuses") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, - securityMgr, mapOutputTracker) - val list = List.fill(2)(new Array[Byte](200)) + store = makeBlockManager(12000) + val list = List.fill(2)(new Array[Byte](2000)) // Tell master. By LRU, only list2 and list3 remains. - store.put("list1", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) - store.put("list2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) - store.put("list3", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.putIterator("list1", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.putIterator("list2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) + store.putIterator("list3", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) // getLocations and getBlockStatus should yield the same locations assert(store.master.getLocations("list1").size === 0) @@ -949,9 +933,9 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter assert(store.master.getBlockStatus("list3", askSlaves = true).size === 1) // This time don't tell master and see what happens. By LRU, only list5 and list6 remains. - store.put("list4", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = false) - store.put("list5", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false) - store.put("list6", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = false) + store.putIterator("list4", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = false) + store.putIterator("list5", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false) + store.putIterator("list6", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = false) // getLocations should return nothing because the master is not informed // getBlockStatus without asking slaves should have the same result @@ -968,23 +952,22 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter } test("get matching blocks") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, - securityMgr, mapOutputTracker) - val list = List.fill(2)(new Array[Byte](10)) + store = makeBlockManager(12000) + val list = List.fill(2)(new Array[Byte](100)) // insert some blocks - store.put("list1", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) - store.put("list2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) - store.put("list3", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) + store.putIterator("list1", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) + store.putIterator("list2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) + store.putIterator("list3", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) // getLocations and getBlockStatus should yield the same locations assert(store.master.getMatchingBlockIds(_.toString.contains("list"), askSlaves = false).size === 3) assert(store.master.getMatchingBlockIds(_.toString.contains("list1"), askSlaves = false).size === 1) // insert some more blocks - store.put("newlist1", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) - store.put("newlist2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false) - store.put("newlist3", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false) + store.putIterator("newlist1", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) + 
store.putIterator("newlist2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false) + store.putIterator("newlist3", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false) // getLocations and getBlockStatus should yield the same locations assert(store.master.getMatchingBlockIds(_.toString.contains("newlist"), askSlaves = false).size === 1) @@ -992,7 +975,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter val blockIds = Seq(RDDBlockId(1, 0), RDDBlockId(1, 1), RDDBlockId(2, 0)) blockIds.foreach { blockId => - store.put(blockId, list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.putIterator(blockId, list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) } val matchedBlockIds = store.master.getMatchingBlockIds(_ match { case RDDBlockId(1, _) => true @@ -1002,17 +985,240 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter } test("SPARK-1194 regression: fix the same-RDD rule for cache replacement") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, - securityMgr, mapOutputTracker) - store.putSingle(rdd(0, 0), new Array[Byte](400), StorageLevel.MEMORY_ONLY) - store.putSingle(rdd(1, 0), new Array[Byte](400), StorageLevel.MEMORY_ONLY) + store = makeBlockManager(12000) + store.putSingle(rdd(0, 0), new Array[Byte](4000), StorageLevel.MEMORY_ONLY) + store.putSingle(rdd(1, 0), new Array[Byte](4000), StorageLevel.MEMORY_ONLY) // Access rdd_1_0 to ensure it's not least recently used. assert(store.getSingle(rdd(1, 0)).isDefined, "rdd_1_0 was not in store") // According to the same-RDD rule, rdd_1_0 should be replaced here. - store.putSingle(rdd(0, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY) + store.putSingle(rdd(0, 1), new Array[Byte](4000), StorageLevel.MEMORY_ONLY) // rdd_1_0 should have been replaced, even it's not least recently used. 
assert(store.memoryStore.contains(rdd(0, 0)), "rdd_0_0 was not in store") assert(store.memoryStore.contains(rdd(0, 1)), "rdd_0_1 was not in store") assert(!store.memoryStore.contains(rdd(1, 0)), "rdd_1_0 was in store") } + + test("reserve/release unroll memory") { + store = makeBlockManager(12000) + val memoryStore = store.memoryStore + assert(memoryStore.currentUnrollMemory === 0) + assert(memoryStore.currentUnrollMemoryForThisThread === 0) + + // Reserve + memoryStore.reserveUnrollMemoryForThisThread(100) + assert(memoryStore.currentUnrollMemoryForThisThread === 100) + memoryStore.reserveUnrollMemoryForThisThread(200) + assert(memoryStore.currentUnrollMemoryForThisThread === 300) + memoryStore.reserveUnrollMemoryForThisThread(500) + assert(memoryStore.currentUnrollMemoryForThisThread === 800) + memoryStore.reserveUnrollMemoryForThisThread(1000000) + assert(memoryStore.currentUnrollMemoryForThisThread === 800) // not granted + // Release + memoryStore.releaseUnrollMemoryForThisThread(100) + assert(memoryStore.currentUnrollMemoryForThisThread === 700) + memoryStore.releaseUnrollMemoryForThisThread(100) + assert(memoryStore.currentUnrollMemoryForThisThread === 600) + // Reserve again + memoryStore.reserveUnrollMemoryForThisThread(4400) + assert(memoryStore.currentUnrollMemoryForThisThread === 5000) + memoryStore.reserveUnrollMemoryForThisThread(20000) + assert(memoryStore.currentUnrollMemoryForThisThread === 5000) // not granted + // Release again + memoryStore.releaseUnrollMemoryForThisThread(1000) + assert(memoryStore.currentUnrollMemoryForThisThread === 4000) + memoryStore.releaseUnrollMemoryForThisThread() // release all + assert(memoryStore.currentUnrollMemoryForThisThread === 0) + } + + /** + * Verify the result of MemoryStore#unrollSafely is as expected. + */ + private def verifyUnroll( + expected: Iterator[Any], + result: Either[Array[Any], Iterator[Any]], + shouldBeArray: Boolean): Unit = { + val actual: Iterator[Any] = result match { + case Left(arr: Array[Any]) => + assert(shouldBeArray, "expected iterator from unroll!") + arr.iterator + case Right(it: Iterator[Any]) => + assert(!shouldBeArray, "expected array from unroll!") + it + case _ => + fail("unroll returned neither an iterator nor an array...") + } + expected.zip(actual).foreach { case (e, a) => + assert(e === a, "unroll did not return original values!") + } + } + + test("safely unroll blocks") { + store = makeBlockManager(12000) + val smallList = List.fill(40)(new Array[Byte](100)) + val bigList = List.fill(40)(new Array[Byte](1000)) + val memoryStore = store.memoryStore + val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] + assert(memoryStore.currentUnrollMemoryForThisThread === 0) + + // Unroll with all the space in the world. This should succeed and return an array. + var unrollResult = memoryStore.unrollSafely("unroll", smallList.iterator, droppedBlocks) + verifyUnroll(smallList.iterator, unrollResult, shouldBeArray = true) + assert(memoryStore.currentUnrollMemoryForThisThread === 0) + + // Unroll with not enough space. This should succeed after kicking out someBlock1. 
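+ // someBlock1 is inserted before someBlock2, so by LRU order it should be the first candidate for eviction.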
+ store.putIterator("someBlock1", smallList.iterator, StorageLevel.MEMORY_ONLY) + store.putIterator("someBlock2", smallList.iterator, StorageLevel.MEMORY_ONLY) + unrollResult = memoryStore.unrollSafely("unroll", smallList.iterator, droppedBlocks) + verifyUnroll(smallList.iterator, unrollResult, shouldBeArray = true) + assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(droppedBlocks.size === 1) + assert(droppedBlocks.head._1 === TestBlockId("someBlock1")) + droppedBlocks.clear() + + // Unroll huge block with not enough space. Even after ensuring free space of 12000 * 0.4 = + // 4800 bytes, there is still not enough room to unroll this block. This returns an iterator. + // In the mean time, however, we kicked out someBlock2 before giving up. + store.putIterator("someBlock3", smallList.iterator, StorageLevel.MEMORY_ONLY) + unrollResult = memoryStore.unrollSafely("unroll", bigList.iterator, droppedBlocks) + verifyUnroll(bigList.iterator, unrollResult, shouldBeArray = false) + assert(memoryStore.currentUnrollMemoryForThisThread > 0) // we returned an iterator + assert(droppedBlocks.size === 1) + assert(droppedBlocks.head._1 === TestBlockId("someBlock2")) + droppedBlocks.clear() + } + + test("safely unroll blocks through putIterator") { + store = makeBlockManager(12000) + val memOnly = StorageLevel.MEMORY_ONLY + val memoryStore = store.memoryStore + val smallList = List.fill(40)(new Array[Byte](100)) + val bigList = List.fill(40)(new Array[Byte](1000)) + def smallIterator = smallList.iterator.asInstanceOf[Iterator[Any]] + def bigIterator = bigList.iterator.asInstanceOf[Iterator[Any]] + assert(memoryStore.currentUnrollMemoryForThisThread === 0) + + // Unroll with plenty of space. This should succeed and cache both blocks. + val result1 = memoryStore.putIterator("b1", smallIterator, memOnly, returnValues = true) + val result2 = memoryStore.putIterator("b2", smallIterator, memOnly, returnValues = true) + assert(memoryStore.contains("b1")) + assert(memoryStore.contains("b2")) + assert(result1.size > 0) // unroll was successful + assert(result2.size > 0) + assert(result1.data.isLeft) // unroll did not drop this block to disk + assert(result2.data.isLeft) + assert(memoryStore.currentUnrollMemoryForThisThread === 0) + + // Re-put these two blocks so block manager knows about them too. Otherwise, block manager + // would not know how to drop them from memory later. + memoryStore.remove("b1") + memoryStore.remove("b2") + store.putIterator("b1", smallIterator, memOnly) + store.putIterator("b2", smallIterator, memOnly) + + // Unroll with not enough space. This should succeed but kick out b1 in the process. + val result3 = memoryStore.putIterator("b3", smallIterator, memOnly, returnValues = true) + assert(result3.size > 0) + assert(result3.data.isLeft) + assert(!memoryStore.contains("b1")) + assert(memoryStore.contains("b2")) + assert(memoryStore.contains("b3")) + assert(memoryStore.currentUnrollMemoryForThisThread === 0) + memoryStore.remove("b3") + store.putIterator("b3", smallIterator, memOnly) + + // Unroll huge block with not enough space. This should fail and kick out b2 in the process. 
+ val result4 = memoryStore.putIterator("b4", bigIterator, memOnly, returnValues = true) + assert(result4.size === 0) // unroll was unsuccessful + assert(result4.data.isLeft) + assert(!memoryStore.contains("b1")) + assert(!memoryStore.contains("b2")) + assert(memoryStore.contains("b3")) + assert(!memoryStore.contains("b4")) + assert(memoryStore.currentUnrollMemoryForThisThread > 0) // we returned an iterator + } + + /** + * This test is essentially identical to the preceding one, except that it uses MEMORY_AND_DISK. + */ + test("safely unroll blocks through putIterator (disk)") { + store = makeBlockManager(12000) + val memAndDisk = StorageLevel.MEMORY_AND_DISK + val memoryStore = store.memoryStore + val diskStore = store.diskStore + val smallList = List.fill(40)(new Array[Byte](100)) + val bigList = List.fill(40)(new Array[Byte](1000)) + def smallIterator = smallList.iterator.asInstanceOf[Iterator[Any]] + def bigIterator = bigList.iterator.asInstanceOf[Iterator[Any]] + assert(memoryStore.currentUnrollMemoryForThisThread === 0) + + store.putIterator("b1", smallIterator, memAndDisk) + store.putIterator("b2", smallIterator, memAndDisk) + + // Unroll with not enough space. This should succeed but kick out b1 in the process. + // Memory store should contain b2 and b3, while disk store should contain only b1 + val result3 = memoryStore.putIterator("b3", smallIterator, memAndDisk, returnValues = true) + assert(result3.size > 0) + assert(!memoryStore.contains("b1")) + assert(memoryStore.contains("b2")) + assert(memoryStore.contains("b3")) + assert(diskStore.contains("b1")) + assert(!diskStore.contains("b2")) + assert(!diskStore.contains("b3")) + memoryStore.remove("b3") + store.putIterator("b3", smallIterator, StorageLevel.MEMORY_ONLY) + assert(memoryStore.currentUnrollMemoryForThisThread === 0) + + // Unroll huge block with not enough space. This should fail and drop the new block to disk + // directly in addition to kicking out b2 in the process. Memory store should contain only + // b3, while disk store should contain b1, b2 and b4. 
+ val result4 = memoryStore.putIterator("b4", bigIterator, memAndDisk, returnValues = true) + assert(result4.size > 0) + assert(result4.data.isRight) // unroll returned bytes from disk + assert(!memoryStore.contains("b1")) + assert(!memoryStore.contains("b2")) + assert(memoryStore.contains("b3")) + assert(!memoryStore.contains("b4")) + assert(diskStore.contains("b1")) + assert(diskStore.contains("b2")) + assert(!diskStore.contains("b3")) + assert(diskStore.contains("b4")) + assert(memoryStore.currentUnrollMemoryForThisThread > 0) // we returned an iterator + } + + test("multiple unrolls by the same thread") { + store = makeBlockManager(12000) + val memOnly = StorageLevel.MEMORY_ONLY + val memoryStore = store.memoryStore + val smallList = List.fill(40)(new Array[Byte](100)) + def smallIterator = smallList.iterator.asInstanceOf[Iterator[Any]] + assert(memoryStore.currentUnrollMemoryForThisThread === 0) + + // All unroll memory used is released because unrollSafely returned an array + memoryStore.putIterator("b1", smallIterator, memOnly, returnValues = true) + assert(memoryStore.currentUnrollMemoryForThisThread === 0) + memoryStore.putIterator("b2", smallIterator, memOnly, returnValues = true) + assert(memoryStore.currentUnrollMemoryForThisThread === 0) + + // Unroll memory is not released because unrollSafely returned an iterator + // that still depends on the underlying vector used in the process + memoryStore.putIterator("b3", smallIterator, memOnly, returnValues = true) + val unrollMemoryAfterB3 = memoryStore.currentUnrollMemoryForThisThread + assert(unrollMemoryAfterB3 > 0) + + // The unroll memory owned by this thread builds on top of its value after the previous unrolls + memoryStore.putIterator("b4", smallIterator, memOnly, returnValues = true) + val unrollMemoryAfterB4 = memoryStore.currentUnrollMemoryForThisThread + assert(unrollMemoryAfterB4 > unrollMemoryAfterB3) + + // ... 
but only to a certain extent (until we run out of free space to grant new unroll memory) + memoryStore.putIterator("b5", smallIterator, memOnly, returnValues = true) + val unrollMemoryAfterB5 = memoryStore.currentUnrollMemoryForThisThread + memoryStore.putIterator("b6", smallIterator, memOnly, returnValues = true) + val unrollMemoryAfterB6 = memoryStore.currentUnrollMemoryForThisThread + memoryStore.putIterator("b7", smallIterator, memOnly, returnValues = true) + val unrollMemoryAfterB7 = memoryStore.currentUnrollMemoryForThisThread + assert(unrollMemoryAfterB5 === unrollMemoryAfterB4) + assert(unrollMemoryAfterB6 === unrollMemoryAfterB4) + assert(unrollMemoryAfterB7 === unrollMemoryAfterB4) + } } diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index fa43b66c6cb5a..b52f81877d557 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -47,11 +47,11 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc } listener.completedStages.size should be (5) - listener.completedStages.filter(_.stageId == 50).size should be (1) - listener.completedStages.filter(_.stageId == 49).size should be (1) - listener.completedStages.filter(_.stageId == 48).size should be (1) - listener.completedStages.filter(_.stageId == 47).size should be (1) - listener.completedStages.filter(_.stageId == 46).size should be (1) + listener.completedStages.count(_.stageId == 50) should be (1) + listener.completedStages.count(_.stageId == 49) should be (1) + listener.completedStages.count(_.stageId == 48) should be (1) + listener.completedStages.count(_.stageId == 47) should be (1) + listener.completedStages.count(_.stageId == 46) should be (1) } test("test executor id to summary") { @@ -59,20 +59,18 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc val listener = new JobProgressListener(conf) val taskMetrics = new TaskMetrics() val shuffleReadMetrics = new ShuffleReadMetrics() - - // nothing in it - assert(listener.stageIdToExecutorSummaries.size == 0) + assert(listener.stageIdToData.size === 0) // finish this task, should get updated shuffleRead shuffleReadMetrics.remoteBytesRead = 1000 - taskMetrics.shuffleReadMetrics = Some(shuffleReadMetrics) + taskMetrics.updateShuffleReadMetrics(shuffleReadMetrics) var taskInfo = new TaskInfo(1234L, 0, 1, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false) taskInfo.finishTime = 1 var task = new ShuffleMapTask(0, null, null, 0, null) val taskType = Utils.getFormattedClassName(task) listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics)) - assert(listener.stageIdToExecutorSummaries.getOrElse(0, fail()).getOrElse("exe-1", fail()) - .shuffleRead == 1000) + assert(listener.stageIdToData.getOrElse(0, fail()).executorSummary.getOrElse("exe-1", fail()) + .shuffleRead === 1000) // finish a task with unknown executor-id, nothing should happen taskInfo = @@ -80,27 +78,23 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc taskInfo.finishTime = 1 task = new ShuffleMapTask(0, null, null, 0, null) listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics)) - assert(listener.stageIdToExecutorSummaries.size == 1) + assert(listener.stageIdToData.size === 1) // finish this task, should get updated 
duration - shuffleReadMetrics.remoteBytesRead = 1000 - taskMetrics.shuffleReadMetrics = Some(shuffleReadMetrics) taskInfo = new TaskInfo(1235L, 0, 1, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false) taskInfo.finishTime = 1 task = new ShuffleMapTask(0, null, null, 0, null) listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics)) - assert(listener.stageIdToExecutorSummaries.getOrElse(0, fail()).getOrElse("exe-1", fail()) - .shuffleRead == 2000) + assert(listener.stageIdToData.getOrElse(0, fail()).executorSummary.getOrElse("exe-1", fail()) + .shuffleRead === 2000) // finish this task, should get updated duration - shuffleReadMetrics.remoteBytesRead = 1000 - taskMetrics.shuffleReadMetrics = Some(shuffleReadMetrics) taskInfo = new TaskInfo(1236L, 0, 2, 0L, "exe-2", "host1", TaskLocality.NODE_LOCAL, false) taskInfo.finishTime = 1 task = new ShuffleMapTask(0, null, null, 0, null) listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics)) - assert(listener.stageIdToExecutorSummaries.getOrElse(0, fail()).getOrElse("exe-2", fail()) - .shuffleRead == 1000) + assert(listener.stageIdToData.getOrElse(0, fail()).executorSummary.getOrElse("exe-2", fail()) + .shuffleRead === 1000) } test("test task success vs failure counting for different task end reasons") { @@ -121,13 +115,17 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc TaskKilled, ExecutorLostFailure, UnknownReason) + var failCount = 0 for (reason <- taskFailedReasons) { listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, reason, taskInfo, metrics)) - assert(listener.stageIdToTasksComplete.get(task.stageId) === None) + failCount += 1 + assert(listener.stageIdToData(task.stageId).numCompleteTasks === 0) + assert(listener.stageIdToData(task.stageId).numFailedTasks === failCount) } // Make sure we count success as success. 
listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, metrics)) - assert(listener.stageIdToTasksComplete.get(task.stageId) === Some(1)) + assert(listener.stageIdToData(task.stageId).numCompleteTasks === 1) + assert(listener.stageIdToData(task.stageId).numFailedTasks === failCount) } } diff --git a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala index ca37d707b06ca..d2bee448d4d3b 100644 --- a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala @@ -135,12 +135,11 @@ class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging { val testOutputStream = new PipedOutputStream() val testInputStream = new PipedInputStream(testOutputStream) val appender = FileAppender(testInputStream, testFile, conf) - assert(appender.isInstanceOf[ExpectedAppender]) + //assert(appender.getClass === classTag[ExpectedAppender].getClass) assert(appender.getClass.getSimpleName === classTag[ExpectedAppender].runtimeClass.getSimpleName) if (appender.isInstanceOf[RollingFileAppender]) { val rollingPolicy = appender.asInstanceOf[RollingFileAppender].rollingPolicy - rollingPolicy.isInstanceOf[ExpectedRollingPolicy] val policyParam = if (rollingPolicy.isInstanceOf[TimeBasedRollingPolicy]) { rollingPolicy.asInstanceOf[TimeBasedRollingPolicy].rolloverIntervalMillis } else { diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 058d31453081a..9305b6d9738e1 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -314,7 +314,6 @@ class JsonProtocolSuite extends FunSuite { private def assertEquals(metrics1: ShuffleReadMetrics, metrics2: ShuffleReadMetrics) { assert(metrics1.shuffleFinishTime === metrics2.shuffleFinishTime) - assert(metrics1.totalBlocksFetched === metrics2.totalBlocksFetched) assert(metrics1.remoteBlocksFetched === metrics2.remoteBlocksFetched) assert(metrics1.localBlocksFetched === metrics2.localBlocksFetched) assert(metrics1.fetchWaitTime === metrics2.fetchWaitTime) @@ -513,12 +512,11 @@ class JsonProtocolSuite extends FunSuite { } else { val sr = new ShuffleReadMetrics sr.shuffleFinishTime = b + c - sr.totalBlocksFetched = e + f sr.remoteBytesRead = b + d sr.localBlocksFetched = e sr.fetchWaitTime = a + d sr.remoteBlocksFetched = f - t.shuffleReadMetrics = Some(sr) + t.updateShuffleReadMetrics(sr) } sw.shuffleBytesWritten = a + b + c sw.shuffleWriteTime = b + c + d @@ -584,7 +582,6 @@ class JsonProtocolSuite extends FunSuite { | "Memory Bytes Spilled":800,"Disk Bytes Spilled":0, | "Shuffle Read Metrics":{ | "Shuffle Finish Time":900, - | "Total Blocks Fetched":1500, | "Remote Blocks Fetched":800, | "Local Blocks Fetched":700, | "Fetch Wait Time":900, diff --git a/core/src/test/scala/org/apache/spark/util/SizeTrackingAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeTrackingAppendOnlyMapSuite.scala deleted file mode 100644 index 93f0c6a8e6408..0000000000000 --- a/core/src/test/scala/org/apache/spark/util/SizeTrackingAppendOnlyMapSuite.scala +++ /dev/null @@ -1,120 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util - -import scala.util.Random - -import org.scalatest.{BeforeAndAfterAll, FunSuite} - -import org.apache.spark.util.SizeTrackingAppendOnlyMapSuite.LargeDummyClass -import org.apache.spark.util.collection.{AppendOnlyMap, SizeTrackingAppendOnlyMap} - -class SizeTrackingAppendOnlyMapSuite extends FunSuite with BeforeAndAfterAll { - val NORMAL_ERROR = 0.20 - val HIGH_ERROR = 0.30 - - test("fixed size insertions") { - testWith[Int, Long](10000, i => (i, i.toLong)) - testWith[Int, (Long, Long)](10000, i => (i, (i.toLong, i.toLong))) - testWith[Int, LargeDummyClass](10000, i => (i, new LargeDummyClass())) - } - - test("variable size insertions") { - val rand = new Random(123456789) - def randString(minLen: Int, maxLen: Int): String = { - "a" * (rand.nextInt(maxLen - minLen) + minLen) - } - testWith[Int, String](10000, i => (i, randString(0, 10))) - testWith[Int, String](10000, i => (i, randString(0, 100))) - testWith[Int, String](10000, i => (i, randString(90, 100))) - } - - test("updates") { - val rand = new Random(123456789) - def randString(minLen: Int, maxLen: Int): String = { - "a" * (rand.nextInt(maxLen - minLen) + minLen) - } - testWith[String, Int](10000, i => (randString(0, 10000), i)) - } - - def testWith[K, V](numElements: Int, makeElement: (Int) => (K, V)) { - val map = new SizeTrackingAppendOnlyMap[K, V]() - for (i <- 0 until numElements) { - val (k, v) = makeElement(i) - map(k) = v - expectWithinError(map, map.estimateSize(), if (i < 32) HIGH_ERROR else NORMAL_ERROR) - } - } - - def expectWithinError(obj: AnyRef, estimatedSize: Long, error: Double) { - val betterEstimatedSize = SizeEstimator.estimate(obj) - assert(betterEstimatedSize * (1 - error) < estimatedSize, - s"Estimated size $estimatedSize was less than expected size $betterEstimatedSize") - assert(betterEstimatedSize * (1 + 2 * error) > estimatedSize, - s"Estimated size $estimatedSize was greater than expected size $betterEstimatedSize") - } -} - -object SizeTrackingAppendOnlyMapSuite { - // Speed test, for reproducibility of results. - // These could be highly non-deterministic in general, however. 
- // Results: - // AppendOnlyMap: 31 ms - // SizeTracker: 54 ms - // SizeEstimator: 1500 ms - def main(args: Array[String]) { - val numElements = 100000 - - val baseTimes = for (i <- 0 until 10) yield time { - val map = new AppendOnlyMap[Int, LargeDummyClass]() - for (i <- 0 until numElements) { - map(i) = new LargeDummyClass() - } - } - - val sampledTimes = for (i <- 0 until 10) yield time { - val map = new SizeTrackingAppendOnlyMap[Int, LargeDummyClass]() - for (i <- 0 until numElements) { - map(i) = new LargeDummyClass() - map.estimateSize() - } - } - - val unsampledTimes = for (i <- 0 until 3) yield time { - val map = new AppendOnlyMap[Int, LargeDummyClass]() - for (i <- 0 until numElements) { - map(i) = new LargeDummyClass() - SizeEstimator.estimate(map) - } - } - - println("Base: " + baseTimes) - println("SizeTracker (sampled): " + sampledTimes) - println("SizeEstimator (unsampled): " + unsampledTimes) - } - - def time(f: => Unit): Long = { - val start = System.currentTimeMillis() - f - System.currentTimeMillis() - start - } - - private class LargeDummyClass { - val arr = new Array[Int](100) - } -} diff --git a/core/src/test/scala/org/apache/spark/util/VectorSuite.scala b/core/src/test/scala/org/apache/spark/util/VectorSuite.scala index 7006571ef0ef6..794a55d61750b 100644 --- a/core/src/test/scala/org/apache/spark/util/VectorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/VectorSuite.scala @@ -24,6 +24,7 @@ import org.scalatest.FunSuite /** * Tests org.apache.spark.util.Vector functionality */ +@deprecated("suppress compile time deprecation warning", "1.0.0") class VectorSuite extends FunSuite { def verifyVector(vector: Vector, expectedLength: Int) = { diff --git a/core/src/test/scala/org/apache/spark/util/collection/AppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/AppendOnlyMapSuite.scala index 52c7288e18b69..cb99d14b27af4 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/AppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/AppendOnlyMapSuite.scala @@ -170,10 +170,10 @@ class AppendOnlyMapSuite extends FunSuite { case e: IllegalStateException => fail() } - val it = map.destructiveSortedIterator(new Comparator[(String, String)] { - def compare(kv1: (String, String), kv2: (String, String)): Int = { - val x = if (kv1 != null && kv1._1 != null) kv1._1.toInt else Int.MinValue - val y = if (kv2 != null && kv2._1 != null) kv2._1.toInt else Int.MinValue + val it = map.destructiveSortedIterator(new Comparator[String] { + def compare(key1: String, key2: String): Int = { + val x = if (key1 != null) key1.toInt else Int.MinValue + val y = if (key2 != null) key2.toInt else Int.MinValue x.compareTo(y) } }) diff --git a/core/src/test/scala/org/apache/spark/util/collection/CompactBufferSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/CompactBufferSuite.scala new file mode 100644 index 0000000000000..6c956d93dc80d --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/collection/CompactBufferSuite.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import org.scalatest.FunSuite + +class CompactBufferSuite extends FunSuite { + test("empty buffer") { + val b = new CompactBuffer[Int] + assert(b.size === 0) + assert(b.iterator.toList === Nil) + assert(b.size === 0) + assert(b.iterator.toList === Nil) + intercept[IndexOutOfBoundsException] { b(0) } + intercept[IndexOutOfBoundsException] { b(1) } + intercept[IndexOutOfBoundsException] { b(2) } + intercept[IndexOutOfBoundsException] { b(-1) } + } + + test("basic inserts") { + val b = new CompactBuffer[Int] + assert(b.size === 0) + assert(b.iterator.toList === Nil) + for (i <- 0 until 1000) { + b += i + assert(b.size === i + 1) + assert(b(i) === i) + } + assert(b.iterator.toList === (0 until 1000).toList) + assert(b.iterator.toList === (0 until 1000).toList) + assert(b.size === 1000) + } + + test("adding sequences") { + val b = new CompactBuffer[Int] + assert(b.size === 0) + assert(b.iterator.toList === Nil) + + // Add some simple lists and iterators + b ++= List(0) + assert(b.size === 1) + assert(b.iterator.toList === List(0)) + b ++= Iterator(1) + assert(b.size === 2) + assert(b.iterator.toList === List(0, 1)) + b ++= List(2) + assert(b.size === 3) + assert(b.iterator.toList === List(0, 1, 2)) + b ++= Iterator(3, 4, 5, 6, 7, 8, 9) + assert(b.size === 10) + assert(b.iterator.toList === (0 until 10).toList) + + // Add CompactBuffers + val b2 = new CompactBuffer[Int] + b2 ++= 0 until 10 + b ++= b2 + assert(b.iterator.toList === (1 to 2).flatMap(i => 0 until 10).toList) + b ++= b2 + assert(b.iterator.toList === (1 to 3).flatMap(i => 0 until 10).toList) + b ++= b2 + assert(b.iterator.toList === (1 to 4).flatMap(i => 0 until 10).toList) + + // Add some small CompactBuffers as well + val b3 = new CompactBuffer[Int] + b ++= b3 + assert(b.iterator.toList === (1 to 4).flatMap(i => 0 until 10).toList) + b3 += 0 + b ++= b3 + assert(b.iterator.toList === (1 to 4).flatMap(i => 0 until 10).toList ++ List(0)) + b3 += 1 + b ++= b3 + assert(b.iterator.toList === (1 to 4).flatMap(i => 0 until 10).toList ++ List(0, 0, 1)) + b3 += 2 + b ++= b3 + assert(b.iterator.toList === (1 to 4).flatMap(i => 0 until 10).toList ++ List(0, 0, 1, 0, 1, 2)) + } + + test("adding the same buffer to itself") { + val b = new CompactBuffer[Int] + assert(b.size === 0) + assert(b.iterator.toList === Nil) + b += 1 + assert(b.toList === List(1)) + for (j <- 1 until 8) { + b ++= b + assert(b.size === (1 << j)) + assert(b.iterator.toList === (1 to (1 << j)).map(i => 1).toList) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index 428822949c085..0b7ad184a46d2 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -63,12 +63,13 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner, 
mergeValue, mergeCombiners) - map.insert(1, 10) - map.insert(2, 20) - map.insert(3, 30) - map.insert(1, 100) - map.insert(2, 200) - map.insert(1, 1000) + map.insertAll(Seq( + (1, 10), + (2, 20), + (3, 30), + (1, 100), + (2, 200), + (1, 1000))) val it = map.iterator assert(it.hasNext) val result = it.toSet[(Int, ArrayBuffer[Int])].map(kv => (kv._1, kv._2.toSet)) @@ -282,7 +283,7 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { assert(w1.hashCode === w2.hashCode) } - (1 to 100000).map(_.toString).foreach { i => map.insert(i, i) } + map.insertAll((1 to 100000).iterator.map(_.toString).map(i => (i, i))) collisionPairs.foreach { case (w1, w2) => map.insert(w1, w2) map.insert(w2, w1) @@ -355,7 +356,7 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]]( createCombiner, mergeValue, mergeCombiners) - (1 to 100000).foreach { i => map.insert(i, i) } + map.insertAll((1 to 100000).iterator.map(i => (i, i))) map.insert(null.asInstanceOf[Int], 1) map.insert(1, null.asInstanceOf[Int]) map.insert(null.asInstanceOf[Int], null.asInstanceOf[Int]) diff --git a/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala new file mode 100644 index 0000000000000..1f33967249654 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util.collection + +import scala.reflect.ClassTag +import scala.util.Random + +import org.scalatest.FunSuite + +import org.apache.spark.util.SizeEstimator + +class SizeTrackerSuite extends FunSuite { + val NORMAL_ERROR = 0.20 + val HIGH_ERROR = 0.30 + + import SizeTrackerSuite._ + + test("vector fixed size insertions") { + testVector[Long](10000, i => i.toLong) + testVector[(Long, Long)](10000, i => (i.toLong, i.toLong)) + testVector[LargeDummyClass](10000, i => new LargeDummyClass) + } + + test("vector variable size insertions") { + val rand = new Random(123456789) + def randString(minLen: Int, maxLen: Int): String = { + "a" * (rand.nextInt(maxLen - minLen) + minLen) + } + testVector[String](10000, i => randString(0, 10)) + testVector[String](10000, i => randString(0, 100)) + testVector[String](10000, i => randString(90, 100)) + } + + test("map fixed size insertions") { + testMap[Int, Long](10000, i => (i, i.toLong)) + testMap[Int, (Long, Long)](10000, i => (i, (i.toLong, i.toLong))) + testMap[Int, LargeDummyClass](10000, i => (i, new LargeDummyClass)) + } + + test("map variable size insertions") { + val rand = new Random(123456789) + def randString(minLen: Int, maxLen: Int): String = { + "a" * (rand.nextInt(maxLen - minLen) + minLen) + } + testMap[Int, String](10000, i => (i, randString(0, 10))) + testMap[Int, String](10000, i => (i, randString(0, 100))) + testMap[Int, String](10000, i => (i, randString(90, 100))) + } + + test("map updates") { + val rand = new Random(123456789) + def randString(minLen: Int, maxLen: Int): String = { + "a" * (rand.nextInt(maxLen - minLen) + minLen) + } + testMap[String, Int](10000, i => (randString(0, 10000), i)) + } + + def testVector[T: ClassTag](numElements: Int, makeElement: Int => T) { + val vector = new SizeTrackingVector[T] + for (i <- 0 until numElements) { + val item = makeElement(i) + vector += item + expectWithinError(vector, vector.estimateSize(), if (i < 32) HIGH_ERROR else NORMAL_ERROR) + } + } + + def testMap[K, V](numElements: Int, makeElement: (Int) => (K, V)) { + val map = new SizeTrackingAppendOnlyMap[K, V] + for (i <- 0 until numElements) { + val (k, v) = makeElement(i) + map(k) = v + expectWithinError(map, map.estimateSize(), if (i < 32) HIGH_ERROR else NORMAL_ERROR) + } + } + + def expectWithinError(obj: AnyRef, estimatedSize: Long, error: Double) { + val betterEstimatedSize = SizeEstimator.estimate(obj) + assert(betterEstimatedSize * (1 - error) < estimatedSize, + s"Estimated size $estimatedSize was less than expected size $betterEstimatedSize") + assert(betterEstimatedSize * (1 + 2 * error) > estimatedSize, + s"Estimated size $estimatedSize was greater than expected size $betterEstimatedSize") + } +} + +private object SizeTrackerSuite { + + /** + * Run speed tests for size tracking collections. + */ + def main(args: Array[String]): Unit = { + if (args.size < 1) { + println("Usage: SizeTrackerSuite [num elements]") + System.exit(1) + } + val numElements = args(0).toInt + vectorSpeedTest(numElements) + mapSpeedTest(numElements) + } + + /** + * Speed test for SizeTrackingVector. 
+ * + * Results for 100000 elements (possibly non-deterministic): + * PrimitiveVector 15 ms + * SizeTracker 51 ms + * SizeEstimator 2000 ms + */ + def vectorSpeedTest(numElements: Int): Unit = { + val baseTimes = for (i <- 0 until 10) yield time { + val vector = new PrimitiveVector[LargeDummyClass] + for (i <- 0 until numElements) { + vector += new LargeDummyClass + } + } + val sampledTimes = for (i <- 0 until 10) yield time { + val vector = new SizeTrackingVector[LargeDummyClass] + for (i <- 0 until numElements) { + vector += new LargeDummyClass + vector.estimateSize() + } + } + val unsampledTimes = for (i <- 0 until 3) yield time { + val vector = new PrimitiveVector[LargeDummyClass] + for (i <- 0 until numElements) { + vector += new LargeDummyClass + SizeEstimator.estimate(vector) + } + } + printSpeedTestResult("SizeTrackingVector", baseTimes, sampledTimes, unsampledTimes) + } + + /** + * Speed test for SizeTrackingAppendOnlyMap. + * + * Results for 100000 elements (possibly non-deterministic): + * AppendOnlyMap 30 ms + * SizeTracker 41 ms + * SizeEstimator 1666 ms + */ + def mapSpeedTest(numElements: Int): Unit = { + val baseTimes = for (i <- 0 until 10) yield time { + val map = new AppendOnlyMap[Int, LargeDummyClass] + for (i <- 0 until numElements) { + map(i) = new LargeDummyClass + } + } + val sampledTimes = for (i <- 0 until 10) yield time { + val map = new SizeTrackingAppendOnlyMap[Int, LargeDummyClass] + for (i <- 0 until numElements) { + map(i) = new LargeDummyClass + map.estimateSize() + } + } + val unsampledTimes = for (i <- 0 until 3) yield time { + val map = new AppendOnlyMap[Int, LargeDummyClass] + for (i <- 0 until numElements) { + map(i) = new LargeDummyClass + SizeEstimator.estimate(map) + } + } + printSpeedTestResult("SizeTrackingAppendOnlyMap", baseTimes, sampledTimes, unsampledTimes) + } + + def printSpeedTestResult( + testName: String, + baseTimes: Seq[Long], + sampledTimes: Seq[Long], + unsampledTimes: Seq[Long]): Unit = { + println(s"Average times for $testName (ms):") + println(" Base - " + averageTime(baseTimes)) + println(" SizeTracker (sampled) - " + averageTime(sampledTimes)) + println(" SizeEstimator (unsampled) - " + averageTime(unsampledTimes)) + println() + } + + def time(f: => Unit): Long = { + val start = System.currentTimeMillis() + f + System.currentTimeMillis() - start + } + + def averageTime(v: Seq[Long]): Long = { + v.sum / v.size + } + + private class LargeDummyClass { + val arr = new Array[Int](100) + } +} diff --git a/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala new file mode 100644 index 0000000000000..6fe1079c2719a --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import java.lang.{Float => JFloat} +import java.util.{Arrays, Comparator} + +import org.scalatest.FunSuite + +import org.apache.spark.util.random.XORShiftRandom + +class SorterSuite extends FunSuite { + + test("equivalent to Arrays.sort") { + val rand = new XORShiftRandom(123) + val data0 = Array.tabulate[Int](10000) { i => rand.nextInt() } + val data1 = data0.clone() + + Arrays.sort(data0) + new Sorter(new IntArraySortDataFormat).sort(data1, 0, data1.length, Ordering.Int) + + data0.zip(data1).foreach { case (x, y) => assert(x === y) } + } + + test("KVArraySorter") { + val rand = new XORShiftRandom(456) + + // Construct an array of keys (to Java sort) and an array where the keys and values + // alternate. Keys are random doubles, values are ordinals from 0 to length. + val keys = Array.tabulate[Double](5000) { i => rand.nextDouble() } + val keyValueArray = Array.tabulate[Number](10000) { i => + if (i % 2 == 0) keys(i / 2) else new Integer(i / 2) + } + + // Map from generated keys to values, to verify correctness later + val kvMap = + keyValueArray.grouped(2).map { case Array(k, v) => k.doubleValue() -> v.intValue() }.toMap + + Arrays.sort(keys) + new Sorter(new KVArraySortDataFormat[Double, Number]) + .sort(keyValueArray, 0, keys.length, Ordering.Double) + + keys.zipWithIndex.foreach { case (k, i) => + assert(k === keyValueArray(2 * i)) + assert(kvMap(k) === keyValueArray(2 * i + 1)) + } + } + + /** + * This provides a simple benchmark for comparing the Sorter with Java internal sorting. + * Ideally these would be executed one at a time, each in their own JVM, so their listing + * here is mainly to have the code. + * + * The goal of this code is to sort an array of key-value pairs, where the array physically + * has the keys and values alternating. The basic Java sorts work only on the keys, so the + * real Java solution is to make Tuple2s to store the keys and values and sort an array of + * those, while the Sorter approach can work directly on the input data format. + * + * Note that the Java implementation varies tremendously between Java 6 and Java 7, when + * the Java sort changed from merge sort to Timsort. + */ + ignore("Sorter benchmark") { + + /** Runs an experiment several times. 
*/ + def runExperiment(name: String)(f: => Unit): Unit = { + val firstTry = org.apache.spark.util.Utils.timeIt(1)(f) + System.gc() + + var i = 0 + var next10: Long = 0 + while (i < 10) { + val time = org.apache.spark.util.Utils.timeIt(1)(f) + next10 += time + println(s"$name: Took $time ms") + i += 1 + } + + println(s"$name: ($firstTry ms first try, ${next10 / 10} ms average)") + } + + val numElements = 25000000 // 25 mil + val rand = new XORShiftRandom(123) + + val keys = Array.tabulate[JFloat](numElements) { i => + new JFloat(rand.nextFloat()) + } + + // Test our key-value pairs where each element is a Tuple2[Float, Integer) + val kvTupleArray = Array.tabulate[AnyRef](numElements) { i => + (keys(i / 2): Float, i / 2: Int) + } + runExperiment("Tuple-sort using Arrays.sort()") { + Arrays.sort(kvTupleArray, new Comparator[AnyRef] { + override def compare(x: AnyRef, y: AnyRef): Int = + Ordering.Float.compare(x.asInstanceOf[(Float, _)]._1, y.asInstanceOf[(Float, _)]._1) + }) + } + + // Test our Sorter where each element alternates between Float and Integer, non-primitive + val keyValueArray = Array.tabulate[AnyRef](numElements * 2) { i => + if (i % 2 == 0) keys(i / 2) else new Integer(i / 2) + } + val sorter = new Sorter(new KVArraySortDataFormat[JFloat, AnyRef]) + runExperiment("KV-sort using Sorter") { + sorter.sort(keyValueArray, 0, keys.length, new Comparator[JFloat] { + override def compare(x: JFloat, y: JFloat): Int = Ordering.Float.compare(x, y) + }) + } + + // Test non-primitive sort on float array + runExperiment("Java Arrays.sort()") { + Arrays.sort(keys, new Comparator[JFloat] { + override def compare(x: JFloat, y: JFloat): Int = Ordering.Float.compare(x, y) + }) + } + + // Test primitive sort on float array + val primitiveKeys = Array.tabulate[Float](numElements) { i => rand.nextFloat() } + runExperiment("Java Arrays.sort() on primitive keys") { + Arrays.sort(primitiveKeys) + } + } +} + + +/** Format to sort a simple Array[Int]. Could be easily generified and specialized. */ +class IntArraySortDataFormat extends SortDataFormat[Int, Array[Int]] { + override protected def getKey(data: Array[Int], pos: Int): Int = { + data(pos) + } + + override protected def swap(data: Array[Int], pos0: Int, pos1: Int): Unit = { + val tmp = data(pos0) + data(pos0) = data(pos1) + data(pos1) = tmp + } + + override protected def copyElement(src: Array[Int], srcPos: Int, dst: Array[Int], dstPos: Int) { + dst(dstPos) = src(srcPos) + } + + /** Copy a range of elements starting at src(srcPos) to dest, starting at destPos. */ + override protected def copyRange(src: Array[Int], srcPos: Int, + dst: Array[Int], dstPos: Int, length: Int) { + System.arraycopy(src, srcPos, dst, dstPos, length) + } + + /** Allocates a new structure that can hold up to 'length' elements. 
*/ + override protected def allocate(length: Int): Array[Int] = { + new Array[Int](length) + } +} diff --git a/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala index accfe2e9b7f2a..73a9d029b0248 100644 --- a/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala @@ -17,11 +17,32 @@ package org.apache.spark.util.random +import scala.util.Random + import org.apache.commons.math3.distribution.{BinomialDistribution, PoissonDistribution} import org.scalatest.FunSuite class SamplingUtilsSuite extends FunSuite { + test("reservoirSampleAndCount") { + val input = Seq.fill(100)(Random.nextInt()) + + // input size < k + val (sample1, count1) = SamplingUtils.reservoirSampleAndCount(input.iterator, 150) + assert(count1 === 100) + assert(input === sample1.toSeq) + + // input size == k + val (sample2, count2) = SamplingUtils.reservoirSampleAndCount(input.iterator, 100) + assert(count2 === 100) + assert(input === sample2.toSeq) + + // input size > k + val (sample3, count3) = SamplingUtils.reservoirSampleAndCount(input.iterator, 10) + assert(count3 === 100) + assert(sample3.length === 10) + } + test("computeFraction") { // test that the computed fraction guarantees enough data points // in the sample with a failure rate <= 0.0001 diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index 49bf78f60763a..38830103d1e8d 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -95,7 +95,7 @@ make_binary_release() { cp -r spark spark-$RELEASE_VERSION-bin-$NAME cd spark-$RELEASE_VERSION-bin-$NAME - ./make-distribution.sh $FLAGS --name $NAME --tgz + ./make-distribution.sh --name $NAME --tgz $FLAGS cd .. cp spark-$RELEASE_VERSION-bin-$NAME/spark-$RELEASE_VERSION-bin-$NAME.tgz . rm -rf spark-$RELEASE_VERSION-bin-$NAME @@ -111,9 +111,10 @@ make_binary_release() { spark-$RELEASE_VERSION-bin-$NAME.tgz.sha } -make_binary_release "hadoop1" "--with-hive --hadoop 1.0.4" -make_binary_release "cdh4" "--with-hive --hadoop 2.0.0-mr1-cdh4.2.0" -make_binary_release "hadoop2" "--with-hive --with-yarn --hadoop 2.2.0" +make_binary_release "hadoop1" "-Phive -Dhadoop.version=1.0.4" +make_binary_release "cdh4" "-Phive -Dhadoop.version=2.0.0-mr1-cdh4.2.0" +make_binary_release "hadoop2" \ + "-Phive -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -Pyarn.version=2.2.0" # Copy data echo "Copying release tarballs" diff --git a/dev/github_jira_sync.py b/dev/github_jira_sync.py new file mode 100755 index 0000000000000..8051080117062 --- /dev/null +++ b/dev/github_jira_sync.py @@ -0,0 +1,146 @@ +#!/usr/bin/env python + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +# Utility for updating JIRA's with information about Github pull requests + +import json +import os +import re +import sys +import urllib2 + +try: + import jira.client +except ImportError: + print "This tool requires the jira-python library" + print "Install using 'sudo pip install jira-python'" + sys.exit(-1) + +# User facing configs +GITHUB_API_BASE = os.environ.get("GITHUB_API_BASE", "https://api.github.com/repos/apache/spark") +JIRA_API_BASE = os.environ.get("JIRA_API_BASE", "https://issues.apache.org/jira") +JIRA_USERNAME = os.environ.get("JIRA_USERNAME", "apachespark") +JIRA_PASSWORD = os.environ.get("JIRA_PASSWORD", "XXX") +# Maximum number of updates to perform in one run +MAX_UPDATES = int(os.environ.get("MAX_UPDATES", "100000")) +# Cut-off for oldest PR on which to comment. Useful for avoiding +# "notification overload" when running for the first time. +MIN_COMMENT_PR = int(os.environ.get("MIN_COMMENT_PR", "1496")) + +# File used as an optimization to store the maximum previously seen PR +# Used mostly because accessing ASF JIRA is slow, so we want to avoid checking +# the state of JIRA's that are tied to PR's we've already looked at. +MAX_FILE = ".github-jira-max" + +def get_url(url): + try: + return urllib2.urlopen(url) + except urllib2.HTTPError as e: + print "Unable to fetch URL, exiting: %s" % url + sys.exit(-1) + +def get_json(urllib_response): + return json.load(urllib_response) + +# Return a list of (JIRA id, JSON dict) tuples: +# e.g. [('SPARK-1234', {.. json ..}), ('SPARK-5687', {.. json ..})] +def get_jira_prs(): + result = [] + has_next_page = True + page_num = 0 + while has_next_page: + page = get_url(GITHUB_API_BASE + "/pulls?page=%s&per_page=100" % page_num) + page_json = get_json(page) + + for pull in page_json: + jiras = re.findall("SPARK-[0-9]{4,5}", pull['title']) + for jira in jiras: + result = result + [(jira, pull)] + + # Check if there is another page + link_header = filter(lambda k: k.startswith("Link"), page.info().headers)[0] + if not "next" in link_header: + has_next_page = False + else: + page_num = page_num + 1 + return result + +def set_max_pr(max_val): + f = open(MAX_FILE, 'w') + f.write("%s" % max_val) + f.close() + print "Writing largest PR number seen: %s" % max_val + +def get_max_pr(): + if os.path.exists(MAX_FILE): + result = int(open(MAX_FILE, 'r').read()) + print "Read largest PR number previously seen: %s" % result + return result + else: + return 0 + +jira_client = jira.client.JIRA({'server': JIRA_API_BASE}, + basic_auth=(JIRA_USERNAME, JIRA_PASSWORD)) + +jira_prs = get_jira_prs() + +previous_max = get_max_pr() +print "Retrieved %s JIRA PR's from Github" % len(jira_prs) +jira_prs = [(k, v) for k, v in jira_prs if int(v['number']) > previous_max] +print "%s PR's remain after excluding visited ones" % len(jira_prs) + +num_updates = 0 +considered = [] +for issue, pr in sorted(jira_prs, key=lambda (k, v): int(v['number'])): + if num_updates >= MAX_UPDATES: + break + pr_num = int(pr['number']) + + print "Checking issue %s" % issue + considered = considered + [pr_num] + + url = pr['html_url'] + title = "[Github] Pull Request #%s (%s)" % (pr['number'], pr['user']['login']) + try: + existing_links = map(lambda l: l.raw['object']['url'], jira_client.remote_links(issue)) + except: + print "Failure reading JIRA %s (does it exist?)" % issue + print sys.exc_info()[0] + continue + + if url in existing_links: + continue + + icon = {"title": "Pull 
request #%s" % pr['number'], + "url16x16": "https://assets-cdn.github.com/favicon.ico"} + destination = {"title": title, "url": url, "icon": icon} + # For all possible fields see: + # https://developer.atlassian.com/display/JIRADEV/Fields+in+Remote+Issue+Links + # application = {"name": "Github pull requests", "type": "org.apache.spark.jira.github"} + jira_client.add_remote_link(issue, destination) + + comment = "User '%s' has created a pull request for this issue:" % pr['user']['login'] + comment = comment + ("\n%s" % pr['html_url']) + if pr_num >= MIN_COMMENT_PR: + jira_client.add_comment(issue, comment) + + print "Added link %s <-> PR #%s" % (issue, pr['number']) + num_updates = num_updates + 1 + +if len(considered) > 0: + set_max_pr(max(considered)) diff --git a/dev/run-tests b/dev/run-tests index edd17b53b3d8c..51e4def0f835a 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -21,8 +21,7 @@ FWDIR="$(cd `dirname $0`/..; pwd)" cd $FWDIR -export SPARK_HADOOP_VERSION=2.3.0 -export SPARK_YARN=true +export SBT_MAVEN_PROFILES="-Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0" # Remove work directory rm -rf ./work @@ -66,8 +65,8 @@ echo "=========================================================================" # (either resolution or compilation) prompts the user for input either q, r, # etc to quit or retry. This echo is there to make it not block. if [ -n "$_RUN_SQL_TESTS" ]; then - echo -e "q\n" | SPARK_HIVE=true sbt/sbt clean package assembly/assembly test | \ - grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" + echo -e "q\n" | SBT_MAVEN_PROFILES="$SBT_MAVEN_PROFILES -Phive" sbt/sbt clean package \ + assembly/assembly test | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" else echo -e "q\n" | sbt/sbt clean package assembly/assembly test | \ grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" diff --git a/dev/scalastyle b/dev/scalastyle index 0e8fd5cc8d64c..a02d06912f238 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -17,12 +17,12 @@ # limitations under the License. # -echo -e "q\n" | SPARK_HIVE=true sbt/sbt scalastyle > scalastyle.txt +echo -e "q\n" | sbt/sbt -Phive scalastyle > scalastyle.txt # Check style with YARN alpha built too -echo -e "q\n" | SPARK_HADOOP_VERSION=0.23.9 SPARK_YARN=true sbt/sbt yarn-alpha/scalastyle \ +echo -e "q\n" | sbt/sbt -Pyarn -Phadoop-0.23 -Dhadoop.version=0.23.9 yarn-alpha/scalastyle \ >> scalastyle.txt # Check style with YARN built too -echo -e "q\n" | SPARK_HADOOP_VERSION=2.2.0 SPARK_YARN=true sbt/sbt yarn/scalastyle \ +echo -e "q\n" | sbt/sbt -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 yarn/scalastyle \ >> scalastyle.txt ERRORS=$(cat scalastyle.txt | grep -e "\") diff --git a/docs/configuration.md b/docs/configuration.md index 07aa4c035446b..2e6c85cc2bcca 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -42,13 +42,15 @@ val sc = new SparkContext(new SparkConf()) Then, you can supply configuration values at runtime: {% highlight bash %} -./bin/spark-submit --name "My fancy app" --master local[4] myApp.jar +./bin/spark-submit --name "My app" --master local[4] --conf spark.shuffle.spill=false + --conf "spark.executor.extraJavaOptions=-XX:+PrintGCDetails -XX:+PrintGCTimeStamps" myApp.jar {% endhighlight %} The Spark shell and [`spark-submit`](cluster-overview.html#launching-applications-with-spark-submit) tool support two ways to load configurations dynamically. The first are command line options, -such as `--master`, as shown above. 
Running `./bin/spark-submit --help` will show the entire list -of options. +such as `--master`, as shown above. `spark-submit` can accept any Spark property using the `--conf` +flag, but uses special flags for properties that play a part in launching the Spark application. +Running `./bin/spark-submit --help` will show the entire list of these options. `bin/spark-submit` will also read configuration options from `conf/spark-defaults.conf`, in which each line consists of a key and a value separated by whitespace. For example: @@ -195,6 +197,15 @@ Apart from these, the following properties are also available, and may be useful Spark's dependencies and user dependencies. It is currently an experimental feature. + + spark.python.worker.memory + 512m + + Amount of memory to use per python worker process during aggregation, in the same + format as JVM memory strings (e.g. 512m, 2g). If the memory + used during aggregation goes above this amount, it will spill the data into disks. + + #### Shuffle Behavior @@ -228,7 +239,7 @@ Apart from these, the following properties are also available, and may be useful spark.shuffle.memoryFraction - 0.3 + 0.2 Fraction of Java heap to use for aggregation and cogroups during shuffles, if spark.shuffle.spill is true. At any given time, the collective size of @@ -336,13 +347,12 @@ Apart from these, the following properties are also available, and may be useful spark.io.compression.codec - org.apache.spark.io.
LZFCompressionCodec + org.apache.spark.io.
SnappyCompressionCodec The codec used to compress internal data such as RDD partitions and shuffle outputs. - By default, Spark provides two codecs: org.apache.spark.io.LZFCompressionCodec - and org.apache.spark.io.SnappyCompressionCodec. Of these two choices, - Snappy offers faster compression and decompression, while LZF offers a better compression - ratio. + By default, Spark provides three codecs: org.apache.spark.io.LZ4CompressionCodec, + org.apache.spark.io.LZFCompressionCodec, + and org.apache.spark.io.SnappyCompressionCodec. @@ -350,7 +360,15 @@ Apart from these, the following properties are also available, and may be useful 32768 Block size (in bytes) used in Snappy compression, in the case when Snappy compression codec - is used. + is used. Lowering this block size will also lower shuffle memory usage when Snappy is used. + + + + spark.io.compression.lz4.block.size + 32768 + + Block size (in bytes) used in LZ4 compression, in the case when LZ4 compression codec + is used. Lowering this block size will also lower shuffle memory usage when LZ4 is used. @@ -362,13 +380,13 @@ Apart from these, the following properties are also available, and may be useful spark.serializer.objectStreamReset - 10000 + 100 When serializing using org.apache.spark.serializer.JavaSerializer, the serializer caches objects to prevent writing redundant data, however that stops garbage collection of those objects. By calling 'reset' you flush that info from the serializer, and allow old objects to be collected. To turn off this periodic reset set it to a value <= 0. - By default it will reset the serializer every 10,000 objects. + By default it will reset the serializer every 100 objects. @@ -381,6 +399,17 @@ Apart from these, the following properties are also available, and may be useful case. + + spark.kryo.registrationRequired + false + + Whether to require registration with Kryo. If set to 'true', Kryo will throw an exception + if an unregistered class is serialized. If set to false (the default), Kryo will write + unregistered class names along with each object. Writing class names can cause + significant performance overhead, so enabling this option can enforce strictly that a + user has not omitted classes from registration. + + spark.kryoserializer.buffer.mb 2 @@ -412,7 +441,7 @@ Apart from these, the following properties are also available, and may be useful spark.broadcast.factory - org.apache.spark.broadcast.
HttpBroadcastFactory + org.apache.spark.broadcast.
TorrentBroadcastFactory Which broadcast implementation to use. @@ -451,6 +480,15 @@ Apart from these, the following properties are also available, and may be useful increase it if you configure your own old generation size. + + spark.storage.unrollFraction + 0.2 + + Fraction of spark.storage.memoryFraction to use for unrolling blocks in memory. + This is dynamically allocated by dropping existing blocks when there is not enough free + storage space to unroll the new block in its entirety. + + spark.tachyonStore.baseDir System.getProperty("java.io.tmpdir") @@ -490,9 +528,9 @@ Apart from these, the following properties are also available, and may be useful spark.hadoop.validateOutputSpecs true - If set to true, validates the output specification (e.g. checking if the output directory already exists) - used in saveAsHadoopFile and other variants. This can be disabled to silence exceptions due to pre-existing - output directories. We recommend that users do not disable this except if trying to achieve compatibility with + If set to true, validates the output specification (e.g. checking if the output directory already exists) + used in saveAsHadoopFile and other variants. This can be disabled to silence exceptions due to pre-existing + output directories. We recommend that users do not disable this except if trying to achieve compatibility with previous versions of Spark. Simply use Hadoop's FileSystem API to delete output directories by hand. @@ -854,7 +892,7 @@ Apart from these, the following properties are also available, and may be useful #### Cluster Managers -Each cluster manager in Spark has additional configuration options. Configurations +Each cluster manager in Spark has additional configuration options. Configurations can be found on the pages for each mode: * [YARN](running-on-yarn.html#configuration) diff --git a/docs/hadoop-third-party-distributions.md b/docs/hadoop-third-party-distributions.md index 32403bc6957a2..ab1023b8f1842 100644 --- a/docs/hadoop-third-party-distributions.md +++ b/docs/hadoop-third-party-distributions.md @@ -48,9 +48,9 @@ the _exact_ Hadoop version you are running to avoid any compatibility errors. -In SBT, the equivalent can be achieved by setting the SPARK_HADOOP_VERSION flag: +In SBT, the equivalent can be achieved by setting the the `hadoop.version` property: - SPARK_HADOOP_VERSION=1.0.4 sbt/sbt assembly + sbt/sbt -Dhadoop.version=1.0.4 assembly # Linking Applications to the Hadoop Version diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index c76ac010d3f81..561de48910132 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -69,7 +69,54 @@ println("Within Set Sum of Squared Errors = " + WSSSE) All of MLlib's methods use Java-friendly types, so you can import and call them there the same way you do in Scala. The only caveat is that the methods take Scala RDD objects, while the Spark Java API uses a separate `JavaRDD` class. You can convert a Java RDD to a Scala one by -calling `.rdd()` on your `JavaRDD` object. +calling `.rdd()` on your `JavaRDD` object. 
A standalone application example
+that is equivalent to the provided example in Scala is given below:
+
+{% highlight java %}
+import org.apache.spark.api.java.*;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.mllib.clustering.KMeans;
+import org.apache.spark.mllib.clustering.KMeansModel;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.SparkConf;
+
+public class KMeansExample {
+  public static void main(String[] args) {
+    SparkConf conf = new SparkConf().setAppName("K-means Example");
+    JavaSparkContext sc = new JavaSparkContext(conf);
+
+    // Load and parse data
+    String path = "data/mllib/kmeans_data.txt";
+    JavaRDD<String> data = sc.textFile(path);
+    JavaRDD<Vector> parsedData = data.map(
+      new Function<String, Vector>() {
+        public Vector call(String s) {
+          String[] sarray = s.split(" ");
+          double[] values = new double[sarray.length];
+          for (int i = 0; i < sarray.length; i++)
+            values[i] = Double.parseDouble(sarray[i]);
+          return Vectors.dense(values);
+        }
+      }
+    );
+
+    // Cluster the data into two classes using KMeans
+    int numClusters = 2;
+    int numIterations = 20;
+    KMeansModel clusters = KMeans.train(parsedData.rdd(), numClusters, numIterations);
+
+    // Evaluate clustering by computing Within Set Sum of Squared Errors
+    double WSSSE = clusters.computeCost(parsedData.rdd());
+    System.out.println("Within Set Sum of Squared Errors = " + WSSSE);
+  }
+}
+{% endhighlight %}
+
+In order to run the above standalone application, follow the instructions provided
+in the [Standalone Applications](quick-start.html) section of the quick-start guide.
+Be sure to also include *spark-mllib* as a dependency in your build file.
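+The dependency can be declared in whichever build tool you use. The snippet below is a
+minimal sbt sketch, not part of the original example; the version number is a placeholder
+and should match the Spark version deployed on your cluster:
+
+{% highlight scala %}
+// build.sbt -- the artifact version is an assumption; use your cluster's Spark version.
+// "provided" is typical when the application is launched through spark-submit.
+libraryDependencies += "org.apache.spark" %% "spark-mllib" % "1.0.1" % "provided"
+{% endhighlight %}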
diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md
index 5cd71738722a9..0d28b5f7c89b3 100644
--- a/docs/mllib-collaborative-filtering.md
+++ b/docs/mllib-collaborative-filtering.md
@@ -99,7 +99,85 @@ val model = ALS.trainImplicit(ratings, rank, numIterations, alpha)
 All of MLlib's methods use Java-friendly types, so you can import and call them there the same way
 you do in Scala. The only caveat is that the methods take Scala RDD objects, while the Spark Java
 API uses a separate `JavaRDD` class. You can convert a Java RDD to a Scala one by
-calling `.rdd()` on your `JavaRDD` object.
+calling `.rdd()` on your `JavaRDD` object. A standalone application example
+that is equivalent to the provided example in Scala is given below:
+
+{% highlight java %}
+import scala.Tuple2;
+
+import org.apache.spark.api.java.*;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.mllib.recommendation.ALS;
+import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
+import org.apache.spark.mllib.recommendation.Rating;
+import org.apache.spark.SparkConf;
+
+public class CollaborativeFiltering {
+  public static void main(String[] args) {
+    SparkConf conf = new SparkConf().setAppName("Collaborative Filtering Example");
+    JavaSparkContext sc = new JavaSparkContext(conf);
+
+    // Load and parse the data
+    String path = "data/mllib/als/test.data";
+    JavaRDD<String> data = sc.textFile(path);
+    JavaRDD<Rating> ratings = data.map(
+      new Function<String, Rating>() {
+        public Rating call(String s) {
+          String[] sarray = s.split(",");
+          return new Rating(Integer.parseInt(sarray[0]), Integer.parseInt(sarray[1]),
+            Double.parseDouble(sarray[2]));
+        }
+      }
+    );
+
+    // Build the recommendation model using ALS
+    int rank = 10;
+    int numIterations = 20;
+    MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), rank, numIterations, 0.01);
+
+    // Evaluate the model on rating data
+    JavaRDD<Tuple2<Object, Object>> userProducts = ratings.map(
+      new Function<Rating, Tuple2<Object, Object>>() {
+        public Tuple2<Object, Object> call(Rating r) {
+          return new Tuple2<Object, Object>(r.user(), r.product());
+        }
+      }
+    );
+    JavaPairRDD<Tuple2<Integer, Integer>, Double> predictions = JavaPairRDD.fromJavaRDD(
+      model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD().map(
+        new Function<Rating, Tuple2<Tuple2<Integer, Integer>, Double>>() {
+          public Tuple2<Tuple2<Integer, Integer>, Double> call(Rating r){
+            return new Tuple2<Tuple2<Integer, Integer>, Double>(
+              new Tuple2<Integer, Integer>(r.user(), r.product()), r.rating());
+          }
+        }
+      ));
+    JavaRDD<Tuple2<Double, Double>> ratesAndPreds =
+      JavaPairRDD.fromJavaRDD(ratings.map(
+        new Function<Rating, Tuple2<Tuple2<Integer, Integer>, Double>>() {
+          public Tuple2<Tuple2<Integer, Integer>, Double> call(Rating r){
+            return new Tuple2<Tuple2<Integer, Integer>, Double>(
+              new Tuple2<Integer, Integer>(r.user(), r.product()), r.rating());
+          }
+        }
+      )).join(predictions).values();
+    double MSE = JavaDoubleRDD.fromRDD(ratesAndPreds.map(
+      new Function<Tuple2<Double, Double>, Object>() {
+        public Object call(Tuple2<Double, Double> pair) {
+          Double err = pair._1() - pair._2();
+          return err * err;
+        }
+      }
+    ).rdd()).mean();
+    System.out.println("Mean Squared Error = " + MSE);
+  }
+}
+{% endhighlight %}
+
+In order to run the above standalone application, follow the instructions provided
+in the [Standalone Applications](quick-start.html) section of the quick-start guide.
+Be sure to also include *spark-mllib* as a dependency in your build file.
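+
+Once trained, the model can also be queried for a single (user, product) pair. The snippet
+below is a minimal Scala sketch, not part of the original example; it assumes the `model`
+produced by the Scala example above:
+
+{% highlight scala %}
+// Predict the rating user 1 would give to product 2 (illustrative IDs).
+val predictedRating = model.predict(1, 2)
+println("Predicted rating for (1, 2) = " + predictedRating)
+{% endhighlight %}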
diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md index 9cd768599e529..9cbd880897578 100644 --- a/docs/mllib-decision-tree.md +++ b/docs/mllib-decision-tree.md @@ -77,15 +77,17 @@ bins if the condition is not satisfied. **Categorical features** -For `$M$` categorical features, one could come up with `$2^M-1$` split candidates. However, for -binary classification, the number of split candidates can be reduced to `$M-1$` by ordering the +For `$M$` categorical feature values, one could come up with `$2^(M-1)-1$` split candidates. For +binary classification, we can reduce the number of split candidates to `$M-1$` by ordering the categorical feature values by the proportion of labels falling in one of the two classes (see Section 9.2.4 in [Elements of Statistical Machine Learning](http://statweb.stanford.edu/~tibs/ElemStatLearn/) for details). For example, for a binary classification problem with one categorical feature with three categories A, B and C with corresponding proportion of label 1 as 0.2, 0.6 and 0.4, the categorical features are ordered as A followed by C followed B or A, B, C. The two split candidates are A \| C, B -and A , B \| C where \| denotes the split. +and A , B \| C where \| denotes the split. A similar heuristic is used for multiclass classification +when `$2^(M-1)-1$` is greater than the number of bins -- the impurity for each categorical feature value +is used for ordering. ### Stopping rule diff --git a/docs/mllib-dimensionality-reduction.md b/docs/mllib-dimensionality-reduction.md index e3608075fbb13..8e434998c15ea 100644 --- a/docs/mllib-dimensionality-reduction.md +++ b/docs/mllib-dimensionality-reduction.md @@ -57,10 +57,57 @@ val U: RowMatrix = svd.U // The U factor is a RowMatrix. val s: Vector = svd.s // The singular values are stored in a local dense vector. val V: Matrix = svd.V // The V factor is a local dense matrix. {% endhighlight %} + +Same code applies to `IndexedRowMatrix`. +The only difference that the `U` matrix becomes an `IndexedRowMatrix`.
+
+In order to run the following standalone application, follow the instructions provided
+in the [Standalone Applications](quick-start.html) section of the quick-start guide.
+Be sure to also include *spark-mllib* as a dependency in your build file.
+
+{% highlight java %}
+import java.util.LinkedList;
+
+import org.apache.spark.api.java.*;
+import org.apache.spark.mllib.linalg.distributed.RowMatrix;
+import org.apache.spark.mllib.linalg.Matrix;
+import org.apache.spark.mllib.linalg.SingularValueDecomposition;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.rdd.RDD;
+import org.apache.spark.SparkConf;
+import org.apache.spark.SparkContext;
+
+public class SVD {
+  public static void main(String[] args) {
+    SparkConf conf = new SparkConf().setAppName("SVD Example");
+    SparkContext sc = new SparkContext(conf);
+
+    double[][] array = ...
+    LinkedList<Vector> rowsList = new LinkedList<Vector>();
+    for (int i = 0; i < array.length; i++) {
+      Vector currentRow = Vectors.dense(array[i]);
+      rowsList.add(currentRow);
+    }
+    JavaRDD<Vector> rows = JavaSparkContext.fromSparkContext(sc).parallelize(rowsList);
+
+    // Create a RowMatrix from JavaRDD<Vector>.
+    RowMatrix mat = new RowMatrix(rows.rdd());
+
+    // Compute the top 4 singular values and corresponding singular vectors.
+    SingularValueDecomposition<RowMatrix, Matrix> svd = mat.computeSVD(4, true, 1.0E-9d);
+    RowMatrix U = svd.U();
+    Vector s = svd.s();
+    Matrix V = svd.V();
+  }
+}
+{% endhighlight %}
Same code applies to `IndexedRowMatrix`.
The only difference that the `U` matrix becomes an `IndexedRowMatrix`.
+
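+As a rough illustration of that note, the snippet below is a minimal Scala sketch, not part
+of the original example; it assumes a hypothetical `rows: RDD[Vector]` like the one used to
+build the `RowMatrix` above:
+
+{% highlight scala %}
+import org.apache.spark.mllib.linalg.distributed.{IndexedRow, IndexedRowMatrix}
+
+// Attach an explicit index to each row and build an IndexedRowMatrix.
+val indexedRows = rows.zipWithIndex.map { case (row, i) => IndexedRow(i, row) }
+val indexedMat = new IndexedRowMatrix(indexedRows)
+
+// Same call as before; the U factor is now an IndexedRowMatrix instead of a RowMatrix.
+val svd = indexedMat.computeSVD(20, computeU = true)
+val U: IndexedRowMatrix = svd.U
+{% endhighlight %}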
## Principal component analysis (PCA) @@ -91,4 +138,51 @@ val pc: Matrix = mat.computePrincipalComponents(10) // Principal components are val projected: RowMatrix = mat.multiply(pc) {% endhighlight %} + +
+
+The following code demonstrates how to compute principal components on a tall-and-skinny `RowMatrix`
+and use them to project the vectors into a low-dimensional space.
+The number of columns should be small, e.g., less than 1000.
+
+{% highlight java %}
+import java.util.LinkedList;
+
+import org.apache.spark.api.java.*;
+import org.apache.spark.mllib.linalg.distributed.RowMatrix;
+import org.apache.spark.mllib.linalg.Matrix;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.rdd.RDD;
+import org.apache.spark.SparkConf;
+import org.apache.spark.SparkContext;
+
+public class PCA {
+  public static void main(String[] args) {
+    SparkConf conf = new SparkConf().setAppName("PCA Example");
+    SparkContext sc = new SparkContext(conf);
+
+    double[][] array = ...
+    LinkedList<Vector> rowsList = new LinkedList<Vector>();
+    for (int i = 0; i < array.length; i++) {
+      Vector currentRow = Vectors.dense(array[i]);
+      rowsList.add(currentRow);
+    }
+    JavaRDD<Vector> rows = JavaSparkContext.fromSparkContext(sc).parallelize(rowsList);
+
+    // Create a RowMatrix from JavaRDD<Vector>.
+    RowMatrix mat = new RowMatrix(rows.rdd());
+
+    // Compute the top 3 principal components.
+    Matrix pc = mat.computePrincipalComponents(3);
+    RowMatrix projected = mat.multiply(pc);
+  }
+}
+{% endhighlight %}
+
+In order to run the above standalone application, follow the instructions provided
+in the [Standalone Applications](quick-start.html) section of the quick-start guide.
+Be sure to also include *spark-mllib* as a dependency in your build file.
+
diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index b4d22e0df5a85..254201147edc1 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -151,10 +151,10 @@ L(\wv;\x,y) := \log(1+\exp( -y \wv^T \x)). Logistic regression algorithm outputs a logistic regression model, which makes predictions by applying the logistic function `\[ -\mathrm{logit}(z) = \frac{1}{1 + e^{-z}} +\mathrm{f}(z) = \frac{1}{1 + e^{-z}} \]` -$\wv^T \x$. -By default, if $\mathrm{logit}(\wv^T x) > 0.5$, the outcome is positive, or negative otherwise. +where $z = \wv^T \x$. +By default, if $\mathrm{f}(\wv^T x) > 0.5$, the outcome is positive, or negative otherwise. For the same reason mentioned above, quite often in practice, this default threshold is not a good choice. The threshold should be determined via model evaluation. @@ -242,7 +242,86 @@ Similarly, you can use replace `SVMWithSGD` by All of MLlib's methods use Java-friendly types, so you can import and call them there the same way you do in Scala. The only caveat is that the methods take Scala RDD objects, while the Spark Java API uses a separate `JavaRDD` class. You can convert a Java RDD to a Scala one by -calling `.rdd()` on your `JavaRDD` object. +calling `.rdd()` on your `JavaRDD` object. A standalone application example +that is equivalent to the provided example in Scala is given bellow: + +{% highlight java %} +import java.util.Random; + +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.classification.*; +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; + +public class SVMClassifier { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("SVM Classifier Example"); + SparkContext sc = new SparkContext(conf); + String path = "data/mllib/sample_libsvm_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); + + // Split initial RDD into two... [60% training data, 40% testing data]. + JavaRDD training = data.sample(false, 0.6, 11L); + training.cache(); + JavaRDD test = data.subtract(training); + + // Run training algorithm to build the model. + int numIterations = 100; + final SVMModel model = SVMWithSGD.train(training.rdd(), numIterations); + + // Clear the default threshold. + model.clearThreshold(); + + // Compute raw scores on the test set. + JavaRDD> scoreAndLabels = test.map( + new Function>() { + public Tuple2 call(LabeledPoint p) { + Double score = model.predict(p.features()); + return new Tuple2(score, p.label()); + } + } + ); + + // Get evaluation metrics. + BinaryClassificationMetrics metrics = + new BinaryClassificationMetrics(JavaRDD.toRDD(scoreAndLabels)); + double auROC = metrics.areaUnderROC(); + + System.out.println("Area under ROC = " + auROC); + } +} +{% endhighlight %} + +The `SVMWithSGD.train()` method by default performs L2 regularization with the +regularization parameter set to 1.0. If we want to configure this algorithm, we +can customize `SVMWithSGD` further by creating a new object directly and +calling setter methods. All other MLlib algorithms support customization in +this way as well. 
For example, the following code produces an L1 regularized
+variant of SVMs with the regularization parameter set to 0.1, and runs the training
+algorithm for 200 iterations.
+
+{% highlight java %}
+import org.apache.spark.mllib.optimization.L1Updater;
+
+SVMWithSGD svmAlg = new SVMWithSGD();
+svmAlg.optimizer()
+  .setNumIterations(200)
+  .setRegParam(0.1)
+  .setUpdater(new L1Updater());
+final SVMModel modelL1 = svmAlg.run(training.rdd());
+{% endhighlight %}
+
+In order to run the above standalone application, follow the instructions provided
+in the [Standalone Applications](quick-start.html) section of the quick-start guide.
+Be sure to also include *spark-mllib* as a dependency in your build file.
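+
+As noted above, the other MLlib algorithms expose the same `optimizer` handle. The snippet
+below is a minimal Scala sketch, not part of the original example; it assumes a
+`training: RDD[LabeledPoint]` like the one used earlier:
+
+{% highlight scala %}
+import org.apache.spark.mllib.classification.LogisticRegressionWithSGD
+import org.apache.spark.mllib.optimization.L1Updater
+
+// Tune logistic regression through its optimizer, mirroring the SVMWithSGD example above.
+val lrAlg = new LogisticRegressionWithSGD()
+lrAlg.optimizer
+  .setNumIterations(200)
+  .setRegParam(0.1)
+  .setUpdater(new L1Updater())
+val lrModelL1 = lrAlg.run(training)
+{% endhighlight %}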
@@ -338,7 +417,72 @@ and [`LassoWithSGD`](api/scala/index.html#org.apache.spark.mllib.regression.Lass
 All of MLlib's methods use Java-friendly types, so you can import and call them there the same way
 you do in Scala. The only caveat is that the methods take Scala RDD objects, while the Spark Java
 API uses a separate `JavaRDD` class. You can convert a Java RDD to a Scala one by
-calling `.rdd()` on your `JavaRDD` object.
+calling `.rdd()` on your `JavaRDD` object. The Java example corresponding to
+the Scala snippet above is presented below:
+
+{% highlight java %}
+import scala.Tuple2;
+
+import org.apache.spark.api.java.*;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.regression.LinearRegressionModel;
+import org.apache.spark.mllib.regression.LinearRegressionWithSGD;
+import org.apache.spark.SparkConf;
+
+public class LinearRegression {
+  public static void main(String[] args) {
+    SparkConf conf = new SparkConf().setAppName("Linear Regression Example");
+    JavaSparkContext sc = new JavaSparkContext(conf);
+
+    // Load and parse the data
+    String path = "data/mllib/ridge-data/lpsa.data";
+    JavaRDD<String> data = sc.textFile(path);
+    JavaRDD<LabeledPoint> parsedData = data.map(
+      new Function<String, LabeledPoint>() {
+        public LabeledPoint call(String line) {
+          String[] parts = line.split(",");
+          String[] features = parts[1].split(" ");
+          double[] v = new double[features.length];
+          for (int i = 0; i < features.length; i++)
+            v[i] = Double.parseDouble(features[i]);
+          return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v));
+        }
+      }
+    );
+
+    // Building the model
+    int numIterations = 100;
+    final LinearRegressionModel model =
+      LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), numIterations);
+
+    // Evaluate model on training examples and compute training error
+    JavaRDD<Tuple2<Double, Double>> valuesAndPreds = parsedData.map(
+      new Function<LabeledPoint, Tuple2<Double, Double>>() {
+        public Tuple2<Double, Double> call(LabeledPoint point) {
+          double prediction = model.predict(point.features());
+          return new Tuple2<Double, Double>(prediction, point.label());
+        }
+      }
+    );
+    double MSE = new JavaDoubleRDD(valuesAndPreds.map(
+      new Function<Tuple2<Double, Double>, Object>() {
+        public Object call(Tuple2<Double, Double> pair) {
+          return Math.pow(pair._1() - pair._2(), 2.0);
+        }
+      }
+    ).rdd()).mean();
+    System.out.println("training Mean Squared Error = " + MSE);
+  }
+}
+{% endhighlight %}
+
+In order to run the above standalone application, follow the instructions provided
+in the [Standalone Applications](quick-start.html) section of the quick-start guide.
+Be sure to also include *spark-mllib* as a dependency in your build file.
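+
+As with the Scala API, ridge regression and the Lasso differ only in the training call. The
+snippet below is a minimal Scala sketch, not part of the original example; it assumes a
+`parsedData: RDD[LabeledPoint]` and `numIterations` as defined above:
+
+{% highlight scala %}
+import org.apache.spark.mllib.regression.{LassoWithSGD, RidgeRegressionWithSGD}
+
+// Swap in L1-regularized (Lasso) or L2-regularized (ridge) linear regression.
+val lassoModel = LassoWithSGD.train(parsedData, numIterations)
+val ridgeModel = RidgeRegressionWithSGD.train(parsedData, numIterations)
+{% endhighlight %}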
diff --git a/docs/mllib-optimization.md b/docs/mllib-optimization.md index 651958c7812f2..26ce5f3c501ff 100644 --- a/docs/mllib-optimization.md +++ b/docs/mllib-optimization.md @@ -207,6 +207,10 @@ the loss computed for every iteration. Here is an example to train binary logistic regression with L2 regularization using L-BFGS optimizer. + +
+ +
{% highlight scala %} import org.apache.spark.SparkContext import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics @@ -263,7 +267,97 @@ println("Loss of each step in training process") loss.foreach(println) println("Area under ROC = " + auROC) {% endhighlight %} - +
+ +
+{% highlight java %} +import java.util.Arrays; +import java.util.Random; + +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.classification.LogisticRegressionModel; +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.optimization.*; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; + +public class LBFGSExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("L-BFGS Example"); + SparkContext sc = new SparkContext(conf); + String path = "data/mllib/sample_libsvm_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); + int numFeatures = data.take(1).get(0).features().size(); + + // Split initial RDD into two... [60% training data, 40% testing data]. + JavaRDD trainingInit = data.sample(false, 0.6, 11L); + JavaRDD test = data.subtract(trainingInit); + + // Append 1 into the training data as intercept. + JavaRDD> training = data.map( + new Function>() { + public Tuple2 call(LabeledPoint p) { + return new Tuple2(p.label(), MLUtils.appendBias(p.features())); + } + }); + training.cache(); + + // Run training algorithm to build the model. + int numCorrections = 10; + double convergenceTol = 1e-4; + int maxNumIterations = 20; + double regParam = 0.1; + Vector initialWeightsWithIntercept = Vectors.dense(new double[numFeatures + 1]); + + Tuple2 result = LBFGS.runLBFGS( + training.rdd(), + new LogisticGradient(), + new SquaredL2Updater(), + numCorrections, + convergenceTol, + maxNumIterations, + regParam, + initialWeightsWithIntercept); + Vector weightsWithIntercept = result._1(); + double[] loss = result._2(); + + final LogisticRegressionModel model = new LogisticRegressionModel( + Vectors.dense(Arrays.copyOf(weightsWithIntercept.toArray(), weightsWithIntercept.size() - 1)), + (weightsWithIntercept.toArray())[weightsWithIntercept.size() - 1]); + + // Clear the default threshold. + model.clearThreshold(); + + // Compute raw scores on the test set. + JavaRDD> scoreAndLabels = test.map( + new Function>() { + public Tuple2 call(LabeledPoint p) { + Double score = model.predict(p.features()); + return new Tuple2(score, p.label()); + } + }); + + // Get evaluation metrics. + BinaryClassificationMetrics metrics = + new BinaryClassificationMetrics(scoreAndLabels.rdd()); + double auROC = metrics.areaUnderROC(); + + System.out.println("Loss of each step in training process"); + for (double l : loss) + System.out.println(l); + System.out.println("Area under ROC = " + auROC); + } +} +{% endhighlight %} +
+
#### Developer's note Since the Hessian is constructed approximately from previous gradient evaluations, the objective function can not be changed during the optimization process. diff --git a/docs/programming-guide.md b/docs/programming-guide.md index b09d6347cd1b2..90c69713019f2 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -739,7 +739,7 @@ def doStuff(self, rdd): While most Spark operations work on RDDs containing any type of objects, a few special operations are only available on RDDs of key-value pairs. -The most common ones are distibuted "shuffle" operations, such as grouping or aggregating the elements +The most common ones are distributed "shuffle" operations, such as grouping or aggregating the elements by a key. In Scala, these operations are automatically available on RDDs containing @@ -773,7 +773,7 @@ documentation](http://docs.oracle.com/javase/7/docs/api/java/lang/Object.html#ha While most Spark operations work on RDDs containing any type of objects, a few special operations are only available on RDDs of key-value pairs. -The most common ones are distibuted "shuffle" operations, such as grouping or aggregating the elements +The most common ones are distributed "shuffle" operations, such as grouping or aggregating the elements by a key. In Java, key-value pairs are represented using the @@ -810,7 +810,7 @@ documentation](http://docs.oracle.com/javase/7/docs/api/java/lang/Object.html#ha While most Spark operations work on RDDs containing any type of objects, a few special operations are only available on RDDs of key-value pairs. -The most common ones are distibuted "shuffle" operations, such as grouping or aggregating the elements +The most common ones are distributed "shuffle" operations, such as grouping or aggregating the elements by a key. In Python, these operations work on RDDs containing built-in Python tuples such as `(1, 2)`. diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index f5c0f7cef83d2..ad8b6c0e51a78 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -156,6 +156,20 @@ SPARK_MASTER_OPTS supports the following system properties: + + + + + + + + + + diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 522c83884ef42..38728534a46e0 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -474,7 +474,7 @@ anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD) Spark SQL also supports reading and writing data stored in [Apache Hive](http://hive.apache.org/). However, since Hive has a large number of dependencies, it is not included in the default Spark assembly. -In order to use Hive you must first run '`SPARK_HIVE=true sbt/sbt assembly/assembly`' (or use `-Phive` for maven). +In order to use Hive you must first run '`sbt/sbt -Phive assembly/assembly`' (or use `-Phive` for maven). This command builds a new assembly jar that includes Hive. Note that this Hive assembly jar must also be present on all of the worker nodes, as they will need access to the Hive serialization and deserialization libraries (SerDes) in order to acccess data stored in Hive. diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index e05883072bfa8..45b70b1a5457a 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -33,6 +33,7 @@ dependencies, and can support different cluster managers and deploy modes that S --class --master \ --deploy-mode \ + --conf = \ ... 
# other options \ [application-arguments] @@ -43,6 +44,7 @@ Some of the commonly used options are: * `--class`: The entry point for your application (e.g. `org.apache.spark.examples.SparkPi`) * `--master`: The [master URL](#master-urls) for the cluster (e.g. `spark://23.195.26.187:7077`) * `--deploy-mode`: Whether to deploy your driver on the worker nodes (`cluster`) or locally as an external client (`client`) (default: `client`)* +* `--conf`: Arbitrary Spark configuration property in key=value format. For values that contain spaces wrap "key=value" in quotes (as shown). * `application-jar`: Path to a bundled jar including your application and all dependencies. The URL must be globally visible inside of your cluster, for instance, an `hdfs://` path or a `file://` path that is present on all nodes. * `application-arguments`: Arguments passed to the main method of your main class, if any diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 44775ea479ece..02cfe4ec39c7d 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -240,7 +240,10 @@ def get_spark_ami(opts): "r3.xlarge": "hvm", "r3.2xlarge": "hvm", "r3.4xlarge": "hvm", - "r3.8xlarge": "hvm" + "r3.8xlarge": "hvm", + "t2.micro": "hvm", + "t2.small": "hvm", + "t2.medium": "hvm" } if opts.instance_type in instance_types: instance_type = instance_types[opts.instance_type] diff --git a/examples/src/main/python/als.py b/examples/src/main/python/als.py index 1a7c4c51f48cd..c862650b0aa1d 100755 --- a/examples/src/main/python/als.py +++ b/examples/src/main/python/als.py @@ -16,6 +16,9 @@ # """ +This is an example implementation of ALS for learning how to use Spark. Please refer to +ALS in pyspark.mllib.recommendation for more conventional use. + This example requires numpy (http://www.numpy.org/) """ from os.path import realpath @@ -49,9 +52,15 @@ def update(i, vec, mat, ratings): if __name__ == "__main__": + """ Usage: als [M] [U] [F] [iterations] [slices]" """ + + print >> sys.stderr, """WARN: This is a naive implementation of ALS and is given as an + example. Please use the ALS method found in pyspark.mllib.recommendation for more + conventional use.""" + sc = SparkContext(appName="PythonALS") M = int(sys.argv[1]) if len(sys.argv) > 1 else 100 U = int(sys.argv[2]) if len(sys.argv) > 2 else 500 diff --git a/examples/src/main/python/kmeans.py b/examples/src/main/python/kmeans.py index 988fc45baf3bc..036bdf4c4f999 100755 --- a/examples/src/main/python/kmeans.py +++ b/examples/src/main/python/kmeans.py @@ -45,9 +45,15 @@ def closestPoint(p, centers): if __name__ == "__main__": + if len(sys.argv) != 4: print >> sys.stderr, "Usage: kmeans " exit(-1) + + print >> sys.stderr, """WARN: This is a naive implementation of KMeans Clustering and is given + as an example! Please refer to examples/src/main/python/mllib/kmeans.py for an example on + how to use MLlib's KMeans implementation.""" + sc = SparkContext(appName="PythonKMeans") lines = sc.textFile(sys.argv[1]) data = lines.map(parseVector).cache() diff --git a/examples/src/main/python/logistic_regression.py b/examples/src/main/python/logistic_regression.py index 6c33deabfd6ea..8456b272f9c05 100755 --- a/examples/src/main/python/logistic_regression.py +++ b/examples/src/main/python/logistic_regression.py @@ -47,9 +47,15 @@ def readPointBatch(iterator): return [matrix] if __name__ == "__main__": + if len(sys.argv) != 3: print >> sys.stderr, "Usage: logistic_regression " exit(-1) + + print >> sys.stderr, """WARN: This is a naive implementation of Logistic Regression and is + given as an example! 
Please refer to examples/src/main/python/mllib/logistic_regression.py + to see how MLlib's implementation is used.""" + sc = SparkContext(appName="PythonLR") points = sc.textFile(sys.argv[1]).mapPartitions(readPointBatch).cache() iterations = int(sys.argv[2]) diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala index 658f73d96a86a..1f576319b3ca8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala @@ -25,6 +25,9 @@ import cern.jet.math._ /** * Alternating least squares matrix factorization. + * + * This is an example implementation for learning how to use Spark. For more conventional use, + * please refer to org.apache.spark.mllib.recommendation.ALS */ object LocalALS { // Parameters set through command line arguments @@ -107,7 +110,16 @@ object LocalALS { solved2D.viewColumn(0) } + def showWarning() { + System.err.println( + """WARN: This is a naive implementation of ALS and is given as an example! + |Please use the ALS method found in org.apache.spark.mllib.recommendation + |for more conventional use. + """.stripMargin) + } + def main(args: Array[String]) { + args match { case Array(m, u, f, iters) => { M = m.toInt @@ -120,6 +132,9 @@ object LocalALS { System.exit(1) } } + + showWarning() + printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS) val R = generateR() diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala index 0ef3001ca4ccd..931faac5463c4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala @@ -21,6 +21,12 @@ import java.util.Random import breeze.linalg.{Vector, DenseVector} +/** + * Logistic regression based classification. + * + * This is an example implementation for learning how to use Spark. For more conventional use, + * please refer to org.apache.spark.mllib.classification.LogisticRegression + */ object LocalFileLR { val D = 10 // Numer of dimensions val rand = new Random(42) @@ -32,7 +38,18 @@ object LocalFileLR { DataPoint(new DenseVector(nums.slice(1, D + 1)), nums(0)) } + def showWarning() { + System.err.println( + """WARN: This is a naive implementation of Logistic Regression and is given as an example! + |Please use the LogisticRegression method found in org.apache.spark.mllib.classification + |for more conventional use. + """.stripMargin) + } + def main(args: Array[String]) { + + showWarning() + val lines = scala.io.Source.fromFile(args(0)).getLines().toArray val points = lines.map(parsePoint _) val ITERATIONS = args(1).toInt diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala index e33a1b336d163..17624c20cff3d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala @@ -28,6 +28,9 @@ import org.apache.spark.SparkContext._ /** * K-means clustering. + * + * This is an example implementation for learning how to use Spark. 
For more conventional use, + * please refer to org.apache.spark.mllib.clustering.KMeans */ object LocalKMeans { val N = 1000 @@ -61,7 +64,18 @@ object LocalKMeans { bestIndex } + def showWarning() { + System.err.println( + """WARN: This is a naive implementation of KMeans Clustering and is given as an example! + |Please use the KMeans method found in org.apache.spark.mllib.clustering + |for more conventional use. + """.stripMargin) + } + def main(args: Array[String]) { + + showWarning() + val data = generateData var points = new HashSet[Vector[Double]] var kPoints = new HashMap[Int, Vector[Double]] diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala index 385b48089d572..2d75b9d2590f8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala @@ -23,6 +23,9 @@ import breeze.linalg.{Vector, DenseVector} /** * Logistic regression based classification. + * + * This is an example implementation for learning how to use Spark. For more conventional use, + * please refer to org.apache.spark.mllib.classification.LogisticRegression */ object LocalLR { val N = 10000 // Number of data points @@ -42,9 +45,19 @@ object LocalLR { Array.tabulate(N)(generatePoint) } + def showWarning() { + System.err.println( + """WARN: This is a naive implementation of Logistic Regression and is given as an example! + |Please use the LogisticRegression method found in org.apache.spark.mllib.classification + |for more conventional use. + """.stripMargin) + } + def main(args: Array[String]) { - val data = generateData + showWarning() + + val data = generateData // Initialize w to a random value var w = DenseVector.fill(D){2 * rand.nextDouble - 1} println("Initial w: " + w) diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala index 5cbc966bf06ca..fde8ffeedf8b4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala @@ -27,6 +27,9 @@ import org.apache.spark._ /** * Alternating least squares matrix factorization. + * + * This is an example implementation for learning how to use Spark. For more conventional use, + * please refer to org.apache.spark.mllib.recommendation.ALS */ object SparkALS { // Parameters set through command line arguments @@ -87,7 +90,16 @@ object SparkALS { solved2D.viewColumn(0) } + def showWarning() { + System.err.println( + """WARN: This is a naive implementation of ALS and is given as an example! + |Please use the ALS method found in org.apache.spark.mllib.recommendation + |for more conventional use. 
+ """.stripMargin) + } + def main(args: Array[String]) { + var slices = 0 val options = (0 to 4).map(i => if (i < args.length) Some(args(i)) else None) @@ -103,7 +115,11 @@ object SparkALS { System.err.println("Usage: SparkALS [M] [U] [F] [iters] [slices]") System.exit(1) } + + showWarning() + printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS) + val sparkConf = new SparkConf().setAppName("SparkALS") val sc = new SparkContext(sparkConf) diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala index 4906a696e90a7..d583cf421ed23 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala @@ -30,6 +30,9 @@ import org.apache.spark.scheduler.InputFormatInfo /** * Logistic regression based classification. + * + * This is an example implementation for learning how to use Spark. For more conventional use, + * please refer to org.apache.spark.mllib.classification.LogisticRegression */ object SparkHdfsLR { val D = 10 // Numer of dimensions @@ -48,12 +51,23 @@ object SparkHdfsLR { DataPoint(new DenseVector(x), y) } + def showWarning() { + System.err.println( + """WARN: This is a naive implementation of Logistic Regression and is given as an example! + |Please use the LogisticRegression method found in org.apache.spark.mllib.classification + |for more conventional use. + """.stripMargin) + } + def main(args: Array[String]) { + if (args.length < 2) { System.err.println("Usage: SparkHdfsLR ") System.exit(1) } + showWarning() + val sparkConf = new SparkConf().setAppName("SparkHdfsLR") val inputPath = args(0) val conf = SparkHadoopUtil.get.newConfiguration() diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala index 79cfedf332436..48e8d11cdf95b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala @@ -24,6 +24,9 @@ import org.apache.spark.SparkContext._ /** * K-means clustering. + * + * This is an example implementation for learning how to use Spark. For more conventional use, + * please refer to org.apache.spark.mllib.clustering.KMeans */ object SparkKMeans { @@ -46,11 +49,23 @@ object SparkKMeans { bestIndex } + def showWarning() { + System.err.println( + """WARN: This is a naive implementation of KMeans Clustering and is given as an example! + |Please use the KMeans method found in org.apache.spark.mllib.clustering + |for more conventional use. + """.stripMargin) + } + def main(args: Array[String]) { + if (args.length < 3) { System.err.println("Usage: SparkKMeans ") System.exit(1) } + + showWarning() + val sparkConf = new SparkConf().setAppName("SparkKMeans") val sc = new SparkContext(sparkConf) val lines = sc.textFile(args(0)) diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala index 99ceb3089e9fe..fc23308fc4adf 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala @@ -28,6 +28,9 @@ import org.apache.spark._ /** * Logistic regression based classification. * Usage: SparkLR [slices] + * + * This is an example implementation for learning how to use Spark. 
For more conventional use, + * please refer to org.apache.spark.mllib.classification.LogisticRegression */ object SparkLR { val N = 10000 // Number of data points @@ -47,7 +50,18 @@ object SparkLR { Array.tabulate(N)(generatePoint) } + def showWarning() { + System.err.println( + """WARN: This is a naive implementation of Logistic Regression and is given as an example! + |Please use the LogisticRegression method found in org.apache.spark.mllib.classification + |for more conventional use. + """.stripMargin) + } + def main(args: Array[String]) { + + showWarning() + val sparkConf = new SparkConf().setAppName("SparkLR") val sc = new SparkContext(sparkConf) val numSlices = if (args.length > 0) args(0).toInt else 2 @@ -66,6 +80,7 @@ object SparkLR { } println("Final w: " + w) + sc.stop() } } diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index b3cc361154198..43f13fe24f0d0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -49,6 +49,7 @@ object DecisionTreeRunner { case class Params( input: String = null, algo: Algo = Classification, + numClassesForClassification: Int = 2, maxDepth: Int = 5, impurity: ImpurityType = Gini, maxBins: Int = 100) @@ -68,6 +69,10 @@ object DecisionTreeRunner { opt[Int]("maxDepth") .text(s"max depth of the tree, default: ${defaultParams.maxDepth}") .action((x, c) => c.copy(maxDepth = x)) + opt[Int]("numClassesForClassification") + .text(s"number of classes for classification, " + + s"default: ${defaultParams.numClassesForClassification}") + .action((x, c) => c.copy(numClassesForClassification = x)) opt[Int]("maxBins") .text(s"max number of bins, default: ${defaultParams.maxBins}") .action((x, c) => c.copy(maxBins = x)) @@ -118,7 +123,13 @@ object DecisionTreeRunner { case Variance => impurity.Variance } - val strategy = new Strategy(params.algo, impurityCalculator, params.maxDepth, params.maxBins) + val strategy + = new Strategy( + algo = params.algo, + impurity = impurityCalculator, + maxDepth = params.maxDepth, + maxBins = params.maxBins, + numClassesForClassification = params.numClassesForClassification) val model = DecisionTree.train(training, strategy) if (params.algo == Classification) { @@ -139,12 +150,8 @@ object DecisionTreeRunner { */ private def accuracyScore( model: DecisionTreeModel, - data: RDD[LabeledPoint], - threshold: Double = 0.5): Double = { - def predictedValue(features: Vector): Double = { - if (model.predict(features) < threshold) 0.0 else 1.0 - } - val correctCount = data.filter(y => predictedValue(y.features) == y.label).count() + data: RDD[LabeledPoint]): Double = { + val correctCount = data.filter(y => model.predict(y.features) == y.label).count() val count = data.count() correctCount.toDouble / count } diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala index b262fabbe0e0d..66a23fac39999 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala @@ -42,7 +42,7 @@ object HiveFromSpark { hql("SELECT * FROM src").collect.foreach(println) // Aggregation queries are also supported. 
- val count = hql("SELECT COUNT(*) FROM src").collect().head.getInt(0) + val count = hql("SELECT COUNT(*) FROM src").collect().head.getLong(0) println(s"COUNT(*): $count") // The results of SQL queries are themselves RDDs and support all normal RDD functions. The diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala index 07ae88febf916..56d2886b26878 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala @@ -153,15 +153,15 @@ class FlumeReceiver( private def initServer() = { if (enableDecompression) { - val channelFactory = new NioServerSocketChannelFactory - (Executors.newCachedThreadPool(), Executors.newCachedThreadPool()); - val channelPipelieFactory = new CompressionChannelPipelineFactory() + val channelFactory = new NioServerSocketChannelFactory(Executors.newCachedThreadPool(), + Executors.newCachedThreadPool()) + val channelPipelineFactory = new CompressionChannelPipelineFactory() new NettyServer( responder, new InetSocketAddress(host, port), - channelFactory, - channelPipelieFactory, + channelFactory, + channelPipelineFactory, null) } else { new NettyServer(responder, new InetSocketAddress(host, port)) diff --git a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala index 5ea2e5549d7df..4eacc47da5699 100644 --- a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala +++ b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala @@ -63,7 +63,8 @@ class TwitterReceiver( storageLevel: StorageLevel ) extends Receiver[Status](storageLevel) with Logging { - private var twitterStream: TwitterStream = _ + @volatile private var twitterStream: TwitterStream = _ + @volatile private var stopped = false def onStart() { try { @@ -78,7 +79,9 @@ class TwitterReceiver( def onScrubGeo(l: Long, l1: Long) {} def onStallWarning(stallWarning: StallWarning) {} def onException(e: Exception) { - restart("Error receiving tweets", e) + if (!stopped) { + restart("Error receiving tweets", e) + } } }) @@ -91,12 +94,14 @@ class TwitterReceiver( } setTwitterStream(newTwitterStream) logInfo("Twitter receiver started") + stopped = false } catch { case e: Exception => restart("Error starting Twitter stream", e) } } def onStop() { + stopped = true setTwitterStream(null) logInfo("Twitter receiver stopped") } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala index 3507f358bfb40..fa4b891754c40 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala @@ -344,7 +344,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * * {{{ * val rawGraph: Graph[_, _] = Graph.textFile("webgraph") - * val outDeg: RDD[(VertexId, Int)] = rawGraph.outDegrees() + * val outDeg: RDD[(VertexId, Int)] = rawGraph.outDegrees * val graph = rawGraph.outerJoinVertices(outDeg) { * (vid, data, optDeg) => optDeg.getOrElse(0) * } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala 
index f97f329c0e832..1948c978c30bf 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala @@ -35,9 +35,6 @@ class GraphKryoRegistrator extends KryoRegistrator { def registerClasses(kryo: Kryo) { kryo.register(classOf[Edge[Object]]) - kryo.register(classOf[MessageToPartition[Object]]) - kryo.register(classOf[VertexBroadcastMsg[Object]]) - kryo.register(classOf[RoutingTableMessage]) kryo.register(classOf[(VertexId, Object)]) kryo.register(classOf[EdgePartition[Object, Object]]) kryo.register(classOf[BitSet]) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala index edd5b79da1522..02afaa987d40d 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala @@ -198,10 +198,10 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali * * {{{ * val rawGraph: Graph[Int, Int] = GraphLoader.edgeListFile(sc, "webgraph") - * .mapVertices(v => 0) - * val outDeg: RDD[(Int, Int)] = rawGraph.outDegrees - * val graph = rawGraph.leftJoinVertices[Int,Int](outDeg, - * (v, deg) => deg ) + * .mapVertices((_, _) => 0) + * val outDeg = rawGraph.outDegrees + * val graph = rawGraph.joinVertices[Int](outDeg) + * ((_, _, outDeg) => outDeg) * }}} * */ diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala index ccdaa82eb9162..33f35cfb69a26 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala @@ -26,7 +26,6 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.graphx._ import org.apache.spark.graphx.impl.GraphImpl._ -import org.apache.spark.graphx.impl.MsgRDDFunctions._ import org.apache.spark.graphx.util.BytecodeUtils @@ -83,15 +82,13 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( val vdTag = classTag[VD] val newEdges = edges.withPartitionsRDD(edges.map { e => val part: PartitionID = partitionStrategy.getPartition(e.srcId, e.dstId, numPartitions) - - // Should we be using 3-tuple or an optimized class - new MessageToPartition(part, (e.srcId, e.dstId, e.attr)) + (part, (e.srcId, e.dstId, e.attr)) } .partitionBy(new HashPartitioner(numPartitions)) .mapPartitionsWithIndex( { (pid, iter) => val builder = new EdgePartitionBuilder[ED, VD]()(edTag, vdTag) iter.foreach { message => - val data = message.data + val data = message._2 builder.add(data._1, data._2, data._3) } val edgePartition = builder.toEdgePartition diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala index d85afa45b1264..5318b8da6412a 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala @@ -25,82 +25,6 @@ import org.apache.spark.graphx.{PartitionID, VertexId} import org.apache.spark.rdd.{ShuffledRDD, RDD} -private[graphx] -class VertexBroadcastMsg[@specialized(Int, Long, Double, Boolean) T]( - @transient var partition: PartitionID, - var vid: VertexId, - var data: T) - extends Product2[PartitionID, (VertexId, T)] with Serializable { - - override def _1 = partition - - override def _2 = (vid, data) - - override def 
canEqual(that: Any): Boolean = that.isInstanceOf[VertexBroadcastMsg[_]] -} - - -/** - * A message used to send a specific value to a partition. - * @param partition index of the target partition. - * @param data value to send - */ -private[graphx] -class MessageToPartition[@specialized(Int, Long, Double, Char, Boolean/* , AnyRef */) T]( - @transient var partition: PartitionID, - var data: T) - extends Product2[PartitionID, T] with Serializable { - - override def _1 = partition - - override def _2 = data - - override def canEqual(that: Any): Boolean = that.isInstanceOf[MessageToPartition[_]] -} - - -private[graphx] -class VertexBroadcastMsgRDDFunctions[T: ClassTag](self: RDD[VertexBroadcastMsg[T]]) { - def partitionBy(partitioner: Partitioner): RDD[VertexBroadcastMsg[T]] = { - val rdd = new ShuffledRDD[PartitionID, (VertexId, T), (VertexId, T), VertexBroadcastMsg[T]]( - self, partitioner) - - // Set a custom serializer if the data is of int or double type. - if (classTag[T] == ClassTag.Int) { - rdd.setSerializer(new IntVertexBroadcastMsgSerializer) - } else if (classTag[T] == ClassTag.Long) { - rdd.setSerializer(new LongVertexBroadcastMsgSerializer) - } else if (classTag[T] == ClassTag.Double) { - rdd.setSerializer(new DoubleVertexBroadcastMsgSerializer) - } - rdd - } -} - - -private[graphx] -class MsgRDDFunctions[T: ClassTag](self: RDD[MessageToPartition[T]]) { - - /** - * Return a copy of the RDD partitioned using the specified partitioner. - */ - def partitionBy(partitioner: Partitioner): RDD[MessageToPartition[T]] = { - new ShuffledRDD[PartitionID, T, T, MessageToPartition[T]](self, partitioner) - } - -} - -private[graphx] -object MsgRDDFunctions { - implicit def rdd2PartitionRDDFunctions[T: ClassTag](rdd: RDD[MessageToPartition[T]]) = { - new MsgRDDFunctions(rdd) - } - - implicit def rdd2vertexMessageRDDFunctions[T: ClassTag](rdd: RDD[VertexBroadcastMsg[T]]) = { - new VertexBroadcastMsgRDDFunctions(rdd) - } -} - private[graphx] class VertexRDDFunctions[VD: ClassTag](self: RDD[(VertexId, VD)]) { def copartitionWithVertices(partitioner: Partitioner): RDD[(VertexId, VD)] = { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala index 502b112d31c2e..a565d3b28bf52 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala @@ -27,26 +27,13 @@ import org.apache.spark.util.collection.{BitSet, PrimitiveVector} import org.apache.spark.graphx._ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap -/** - * A message from the edge partition `pid` to the vertex partition containing `vid` specifying that - * the edge partition references `vid` in the specified `position` (src, dst, or both). -*/ -private[graphx] -class RoutingTableMessage( - var vid: VertexId, - var pid: PartitionID, - var position: Byte) - extends Product2[VertexId, (PartitionID, Byte)] with Serializable { - override def _1 = vid - override def _2 = (pid, position) - override def canEqual(that: Any): Boolean = that.isInstanceOf[RoutingTableMessage] -} +import org.apache.spark.graphx.impl.RoutingTablePartition.RoutingTableMessage private[graphx] class RoutingTableMessageRDDFunctions(self: RDD[RoutingTableMessage]) { /** Copartition an `RDD[RoutingTableMessage]` with the vertex RDD with the given `partitioner`. 
*/ def copartitionWithVertices(partitioner: Partitioner): RDD[RoutingTableMessage] = { - new ShuffledRDD[VertexId, (PartitionID, Byte), (PartitionID, Byte), RoutingTableMessage]( + new ShuffledRDD[VertexId, Int, Int, RoutingTableMessage]( self, partitioner).setSerializer(new RoutingTableMessageSerializer) } } @@ -62,6 +49,23 @@ object RoutingTableMessageRDDFunctions { private[graphx] object RoutingTablePartition { + /** + * A message from an edge partition to a vertex specifying the position in which the edge + * partition references the vertex (src, dst, or both). The edge partition is encoded in the lower + * 30 bits of the Int, and the position is encoded in the upper 2 bits of the Int. + */ + type RoutingTableMessage = (VertexId, Int) + + private def toMessage(vid: VertexId, pid: PartitionID, position: Byte): RoutingTableMessage = { + val positionUpper2 = position << 30 + val pidLower30 = pid & 0x3FFFFFFF + (vid, positionUpper2 | pidLower30) + } + + private def vidFromMessage(msg: RoutingTableMessage): VertexId = msg._1 + private def pidFromMessage(msg: RoutingTableMessage): PartitionID = msg._2 & 0x3FFFFFFF + private def positionFromMessage(msg: RoutingTableMessage): Byte = (msg._2 >> 30).toByte + val empty: RoutingTablePartition = new RoutingTablePartition(Array.empty) /** Generate a `RoutingTableMessage` for each vertex referenced in `edgePartition`. */ @@ -77,7 +81,9 @@ object RoutingTablePartition { map.changeValue(dstId, 0x2, (b: Byte) => (b | 0x2).toByte) } map.iterator.map { vidAndPosition => - new RoutingTableMessage(vidAndPosition._1, pid, vidAndPosition._2) + val vid = vidAndPosition._1 + val position = vidAndPosition._2 + toMessage(vid, pid, position) } } @@ -88,9 +94,12 @@ object RoutingTablePartition { val srcFlags = Array.fill(numEdgePartitions)(new PrimitiveVector[Boolean]) val dstFlags = Array.fill(numEdgePartitions)(new PrimitiveVector[Boolean]) for (msg <- iter) { - pid2vid(msg.pid) += msg.vid - srcFlags(msg.pid) += (msg.position & 0x1) != 0 - dstFlags(msg.pid) += (msg.position & 0x2) != 0 + val vid = vidFromMessage(msg) + val pid = pidFromMessage(msg) + val position = positionFromMessage(msg) + pid2vid(pid) += vid + srcFlags(pid) += (position & 0x1) != 0 + dstFlags(pid) += (position & 0x2) != 0 } new RoutingTablePartition(pid2vid.zipWithIndex.map { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala index 033237f597216..3909efcdfc993 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala @@ -24,9 +24,11 @@ import java.nio.ByteBuffer import scala.reflect.ClassTag -import org.apache.spark.graphx._ import org.apache.spark.serializer._ +import org.apache.spark.graphx._ +import org.apache.spark.graphx.impl.RoutingTablePartition.RoutingTableMessage + private[graphx] class RoutingTableMessageSerializer extends Serializer with Serializable { override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { @@ -35,10 +37,8 @@ class RoutingTableMessageSerializer extends Serializer with Serializable { new ShuffleSerializationStream(s) { def writeObject[T: ClassTag](t: T): SerializationStream = { val msg = t.asInstanceOf[RoutingTableMessage] - writeVarLong(msg.vid, optimizePositive = false) - writeUnsignedVarInt(msg.pid) - // TODO: Write only the bottom two bits of msg.position - s.write(msg.position) + writeVarLong(msg._1, optimizePositive = false) +
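A routing-table message is now just a (VertexId, Int) pair: the src/dst position flag sits in the top two bits of the Int and the partition id in the lower 30 bits, which is also why the PartitionID comment later in this patch notes the 2^30 limit. A standalone sketch of the packing; the real toMessage/pidFromMessage/positionFromMessage helpers are private, and this version uses an unsigned shift so the flag byte round-trips exactly:

```scala
import org.apache.spark.graphx.{PartitionID, VertexId}

// Position flags are 0x1 (referenced as src), 0x2 (as dst), 0x3 (both).
def pack(vid: VertexId, pid: PartitionID, position: Byte): (VertexId, Int) =
  (vid, (position << 30) | (pid & 0x3FFFFFFF))   // flag in bits 31-30, pid in bits 29-0

def unpack(msg: (VertexId, Int)): (VertexId, PartitionID, Byte) =
  (msg._1, msg._2 & 0x3FFFFFFF, (msg._2 >>> 30).toByte)

// pack(42L, 7, 0x3) == (42L, 0xC0000007); unpack(pack(42L, 7, 0x3)) == (42L, 7, 3)
```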
writeInt(msg._2) this } } @@ -47,10 +47,8 @@ class RoutingTableMessageSerializer extends Serializer with Serializable { new ShuffleDeserializationStream(s) { override def readObject[T: ClassTag](): T = { val a = readVarLong(optimizePositive = false) - val b = readUnsignedVarInt() - val c = s.read() - if (c == -1) throw new EOFException - new RoutingTableMessage(a, b, c.toByte).asInstanceOf[T] + val b = readInt() + (a, b).asInstanceOf[T] } } } @@ -76,78 +74,6 @@ class VertexIdMsgSerializer extends Serializer with Serializable { } } -/** A special shuffle serializer for VertexBroadcastMessage[Int]. */ -private[graphx] -class IntVertexBroadcastMsgSerializer extends Serializer with Serializable { - override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { - - override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { - def writeObject[T: ClassTag](t: T) = { - val msg = t.asInstanceOf[VertexBroadcastMsg[Int]] - writeVarLong(msg.vid, optimizePositive = false) - writeInt(msg.data) - this - } - } - - override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) { - override def readObject[T: ClassTag](): T = { - val a = readVarLong(optimizePositive = false) - val b = readInt() - new VertexBroadcastMsg[Int](0, a, b).asInstanceOf[T] - } - } - } -} - -/** A special shuffle serializer for VertexBroadcastMessage[Long]. */ -private[graphx] -class LongVertexBroadcastMsgSerializer extends Serializer with Serializable { - override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { - - override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { - def writeObject[T: ClassTag](t: T) = { - val msg = t.asInstanceOf[VertexBroadcastMsg[Long]] - writeVarLong(msg.vid, optimizePositive = false) - writeLong(msg.data) - this - } - } - - override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) { - override def readObject[T: ClassTag](): T = { - val a = readVarLong(optimizePositive = false) - val b = readLong() - new VertexBroadcastMsg[Long](0, a, b).asInstanceOf[T] - } - } - } -} - -/** A special shuffle serializer for VertexBroadcastMessage[Double]. */ -private[graphx] -class DoubleVertexBroadcastMsgSerializer extends Serializer with Serializable { - override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { - - override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { - def writeObject[T: ClassTag](t: T) = { - val msg = t.asInstanceOf[VertexBroadcastMsg[Double]] - writeVarLong(msg.vid, optimizePositive = false) - writeDouble(msg.data) - this - } - } - - override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) { - def readObject[T: ClassTag](): T = { - val a = readVarLong(optimizePositive = false) - val b = readDouble() - new VertexBroadcastMsg[Double](0, a, b).asInstanceOf[T] - } - } - } -} - /** A special shuffle serializer for AggregationMessage[Int]. */ private[graphx] class IntAggMsgSerializer extends Serializer with Serializable { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/package.scala b/graphx/src/main/scala/org/apache/spark/graphx/package.scala index ff17edeaf8f16..6aab28ff05355 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/package.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/package.scala @@ -30,7 +30,7 @@ package object graphx { */ type VertexId = Long - /** Integer identifer of a graph partition. */ + /** Integer identifer of a graph partition. 
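After the change the serializer writes the variable-length vertex id followed by a fixed 4-byte Int for the packed pid/position. writeVarLong/readVarLong themselves live in the graphx shuffle streams and are not shown in this diff; purely as an illustration of the technique, a typical zigzag, 7-bits-per-byte variable-length encoding looks like the sketch below (an assumption about the scheme, not the actual Spark implementation):

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

// Zigzag maps small negative values to small unsigned values; each output byte
// carries 7 payload bits plus a continuation bit, so small ids take 1-2 bytes.
def writeVarLong(out: ByteArrayOutputStream, value: Long): Unit = {
  var v = (value << 1) ^ (value >> 63)        // zigzag: move the sign into bit 0
  while ((v & ~0x7FL) != 0L) {
    out.write(((v & 0x7F) | 0x80).toInt)      // 7 payload bits + continuation bit
    v >>>= 7
  }
  out.write(v.toInt)
}

def readVarLong(in: ByteArrayInputStream): Long = {
  var result = 0L
  var shift = 0
  var b = in.read()
  while ((b & 0x80) != 0) {
    result |= (b & 0x7FL) << shift
    shift += 7
    b = in.read()
  }
  result |= (b & 0x7FL) << shift
  (result >>> 1) ^ -(result & 1L)             // undo the zigzag mapping
}
```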
Must be less than 2^30. */ // TODO: Consider using Char. type PartitionID = Int diff --git a/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala index 91caa6b605a1e..864cb1fdf0022 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala @@ -26,75 +26,11 @@ import org.scalatest.FunSuite import org.apache.spark._ import org.apache.spark.graphx.impl._ -import org.apache.spark.graphx.impl.MsgRDDFunctions._ import org.apache.spark.serializer.SerializationStream class SerializerSuite extends FunSuite with LocalSparkContext { - test("IntVertexBroadcastMsgSerializer") { - val outMsg = new VertexBroadcastMsg[Int](3, 4, 5) - val bout = new ByteArrayOutputStream - val outStrm = new IntVertexBroadcastMsgSerializer().newInstance().serializeStream(bout) - outStrm.writeObject(outMsg) - outStrm.writeObject(outMsg) - bout.flush() - val bin = new ByteArrayInputStream(bout.toByteArray) - val inStrm = new IntVertexBroadcastMsgSerializer().newInstance().deserializeStream(bin) - val inMsg1: VertexBroadcastMsg[Int] = inStrm.readObject() - val inMsg2: VertexBroadcastMsg[Int] = inStrm.readObject() - assert(outMsg.vid === inMsg1.vid) - assert(outMsg.vid === inMsg2.vid) - assert(outMsg.data === inMsg1.data) - assert(outMsg.data === inMsg2.data) - - intercept[EOFException] { - inStrm.readObject() - } - } - - test("LongVertexBroadcastMsgSerializer") { - val outMsg = new VertexBroadcastMsg[Long](3, 4, 5) - val bout = new ByteArrayOutputStream - val outStrm = new LongVertexBroadcastMsgSerializer().newInstance().serializeStream(bout) - outStrm.writeObject(outMsg) - outStrm.writeObject(outMsg) - bout.flush() - val bin = new ByteArrayInputStream(bout.toByteArray) - val inStrm = new LongVertexBroadcastMsgSerializer().newInstance().deserializeStream(bin) - val inMsg1: VertexBroadcastMsg[Long] = inStrm.readObject() - val inMsg2: VertexBroadcastMsg[Long] = inStrm.readObject() - assert(outMsg.vid === inMsg1.vid) - assert(outMsg.vid === inMsg2.vid) - assert(outMsg.data === inMsg1.data) - assert(outMsg.data === inMsg2.data) - - intercept[EOFException] { - inStrm.readObject() - } - } - - test("DoubleVertexBroadcastMsgSerializer") { - val outMsg = new VertexBroadcastMsg[Double](3, 4, 5.0) - val bout = new ByteArrayOutputStream - val outStrm = new DoubleVertexBroadcastMsgSerializer().newInstance().serializeStream(bout) - outStrm.writeObject(outMsg) - outStrm.writeObject(outMsg) - bout.flush() - val bin = new ByteArrayInputStream(bout.toByteArray) - val inStrm = new DoubleVertexBroadcastMsgSerializer().newInstance().deserializeStream(bin) - val inMsg1: VertexBroadcastMsg[Double] = inStrm.readObject() - val inMsg2: VertexBroadcastMsg[Double] = inStrm.readObject() - assert(outMsg.vid === inMsg1.vid) - assert(outMsg.vid === inMsg2.vid) - assert(outMsg.data === inMsg1.data) - assert(outMsg.data === inMsg2.data) - - intercept[EOFException] { - inStrm.readObject() - } - } - test("IntAggMsgSerializer") { val outMsg = (4: VertexId, 5) val bout = new ByteArrayOutputStream @@ -152,15 +88,6 @@ class SerializerSuite extends FunSuite with LocalSparkContext { } } - test("TestShuffleVertexBroadcastMsg") { - withSpark { sc => - val bmsgs = sc.parallelize(0 until 100, 10).map { pid => - new VertexBroadcastMsg[Int](pid, pid, pid) - } - bmsgs.partitionBy(new HashPartitioner(3)).collect() - } - } - test("variable long encoding") { def testVarLongEncoding(v: Long, 
optimizePositive: Boolean) { val bout = new ByteArrayOutputStream diff --git a/make-distribution.sh b/make-distribution.sh index 94b473bf91cd3..c08093f46b61f 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -23,21 +23,6 @@ # The distribution contains fat (assembly) jars that include the Scala library, # so it is completely self contained. # It does not contain source or *.class files. -# -# Optional Arguments -# --tgz: Additionally creates spark-$VERSION-bin.tar.gz -# --hadoop VERSION: Builds against specified version of Hadoop. -# --with-yarn: Enables support for Hadoop YARN. -# --with-hive: Enable support for reading Hive tables. -# --name: A moniker for the release target. Defaults to the Hadoop verison. -# -# Recommended deploy/testing procedure (standalone mode): -# 1) Rsync / deploy the dist/ dir to one host -# 2) cd to deploy dir; ./sbin/start-master.sh -# 3) Verify master is up by visiting web page, ie http://master-ip:8080. Note the spark:// URL. -# 4) ./sbin/start-slave.sh 1 <> -# 5) ./bin/spark-shell --master spark://my-master-ip:7077 -# set -o pipefail set -e @@ -46,26 +31,35 @@ set -e FWDIR="$(cd `dirname $0`; pwd)" DISTDIR="$FWDIR/dist" -# Initialize defaults -SPARK_HADOOP_VERSION=1.0.4 -SPARK_YARN=false -SPARK_HIVE=false SPARK_TACHYON=false MAKE_TGZ=false NAME=none +function exit_with_usage { + echo "make-distribution.sh - tool for making binary distributions of Spark" + echo "" + echo "usage:" + echo "./make-distribution.sh [--name] [--tgz] [--with-tachyon] " + echo "See Spark's \"Building with Maven\" doc for correct Maven options." + echo "" + exit 1 +} + # Parse arguments while (( "$#" )); do case $1 in --hadoop) - SPARK_HADOOP_VERSION="$2" - shift + echo "Error: '--hadoop' is no longer supported:" + echo "Error: use Maven options -Phadoop.version and -Pyarn.version" + exit_with_usage ;; --with-yarn) - SPARK_YARN=true + echo "Error: '--with-yarn' is no longer supported, use Maven option -Pyarn" + exit_with_usage ;; --with-hive) - SPARK_HIVE=true + echo "Error: '--with-hive' is no longer supported, use Maven option -Phive" + exit_with_usage ;; --skip-java-test) SKIP_JAVA_TEST=true @@ -80,6 +74,12 @@ while (( "$#" )); do NAME="$2" shift ;; + --help) + exit_with_usage + ;; + *) + break + ;; esac shift done @@ -143,14 +143,6 @@ else echo "Making distribution for Spark $VERSION in $DISTDIR..." fi -echo "Hadoop version set to $SPARK_HADOOP_VERSION" -echo "Release name set to $NAME" -if [ "$SPARK_YARN" == "true" ]; then - echo "YARN enabled" -else - echo "YARN disabled" -fi - if [ "$SPARK_TACHYON" == "true" ]; then echo "Tachyon Enabled" else @@ -162,33 +154,12 @@ cd $FWDIR export MAVEN_OPTS="-Xmx2g -XX:MaxPermSize=512M -XX:ReservedCodeCacheSize=512m" -BUILD_COMMAND="mvn clean package" - -# Use special profiles for hadoop versions 0.23.x, 2.2.x, 2.3.x, 2.4.x -if [[ "$SPARK_HADOOP_VERSION" =~ ^0\.23\. ]]; then BUILD_COMMAND="$BUILD_COMMAND -Phadoop-0.23"; fi -if [[ "$SPARK_HADOOP_VERSION" =~ ^2\.2\. ]]; then BUILD_COMMAND="$BUILD_COMMAND -Phadoop-2.2"; fi -if [[ "$SPARK_HADOOP_VERSION" =~ ^2\.3\. ]]; then BUILD_COMMAND="$BUILD_COMMAND -Phadoop-2.3"; fi -if [[ "$SPARK_HADOOP_VERSION" =~ ^2\.4\. ]]; then BUILD_COMMAND="$BUILD_COMMAND -Phadoop-2.4"; fi -if [[ "$SPARK_HIVE" == "true" ]]; then BUILD_COMMAND="$BUILD_COMMAND -Phive"; fi -if [[ "$SPARK_YARN" == "true" ]]; then - # For hadoop versions 0.23.x to 2.1.x, use the yarn-alpha profile - if [[ "$SPARK_HADOOP_VERSION" =~ ^0\.2[3-9]\. ]] || - [[ "$SPARK_HADOOP_VERSION" =~ ^0\.[3-9][0-9]\. 
]] || - [[ "$SPARK_HADOOP_VERSION" =~ ^1\.[0-9]\. ]] || - [[ "$SPARK_HADOOP_VERSION" =~ ^2\.[0-1]\. ]]; then - BUILD_COMMAND="$BUILD_COMMAND -Pyarn-alpha" - # For hadoop versions 2.2+, use the yarn profile - elif [[ "$SPARK_HADOOP_VERSION" =~ ^2.[2-9]. ]]; then - BUILD_COMMAND="$BUILD_COMMAND -Pyarn" - fi - BUILD_COMMAND="$BUILD_COMMAND -Dyarn.version=$SPARK_HADOOP_VERSION" -fi -BUILD_COMMAND="$BUILD_COMMAND -Dhadoop.version=$SPARK_HADOOP_VERSION" -BUILD_COMMAND="$BUILD_COMMAND -DskipTests" +BUILD_COMMAND="mvn clean package -DskipTests $@" # Actually build the jar echo -e "\nBuilding with..." echo -e "\$ $BUILD_COMMAND\n" + ${BUILD_COMMAND} # Make directories diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index c44173793b39a..954621ee8b933 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -54,6 +54,13 @@ class PythonMLLibAPI extends Serializable { } } + private[python] def deserializeDouble(bytes: Array[Byte], offset: Int = 0): Double = { + require(bytes.length - offset == 8, "Wrong size byte array for Double") + val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset) + bb.order(ByteOrder.nativeOrder()) + bb.getDouble + } + private def deserializeDenseVector(bytes: Array[Byte], offset: Int = 0): Vector = { val packetLength = bytes.length - offset require(packetLength >= 5, "Byte array too short") @@ -89,6 +96,22 @@ class PythonMLLibAPI extends Serializable { Vectors.sparse(size, indices, values) } + /** + * Returns an 8-byte array for the input Double. + * + * Note: we currently do not use a magic byte for double for storage efficiency. + * This should be reconsidered when we add Ser/De for other 8-byte types (e.g. Long), for safety. + * The corresponding deserializer, deserializeDouble, needs to be modified as well if the + * serialization scheme changes. + */ + private[python] def serializeDouble(double: Double): Array[Byte] = { + val bytes = new Array[Byte](8) + val bb = ByteBuffer.wrap(bytes) + bb.order(ByteOrder.nativeOrder()) + bb.putDouble(double) + bytes + } + private def serializeDenseVector(doubles: Array[Double]): Array[Byte] = { val len = doubles.length val bytes = new Array[Byte](5 + 8 * len) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 90aa8ac998ba9..2242329b7918e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -30,7 +30,7 @@ import org.apache.spark.rdd.RDD * @param weights Weights computed for every feature. * @param intercept Intercept computed for this model. 
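The new Double pickling in PythonMLLibAPI is simply eight bytes in native byte order with no magic byte. A standalone round trip showing the scheme, written against ByteBuffer directly since the real helpers are private[python]:

```scala
import java.nio.{ByteBuffer, ByteOrder}

// Mirrors serializeDouble/deserializeDouble: 8 bytes, native order, no type tag.
def toBytes(d: Double): Array[Byte] = {
  val bytes = new Array[Byte](8)
  ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder()).putDouble(d)
  bytes
}

def fromBytes(bytes: Array[Byte]): Double = {
  require(bytes.length == 8, "Wrong size byte array for Double")
  ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder()).getDouble
}

// fromBytes(toBytes(math.Pi)) == math.Pi
```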
*/ -class LogisticRegressionModel private[mllib] ( +class LogisticRegressionModel ( override val weights: Vector, override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index b6e0c4a80e27b..6c7be0a4f1dcb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -54,7 +54,13 @@ class NaiveBayesModel private[mllib] ( } } - override def predict(testData: RDD[Vector]): RDD[Double] = testData.map(predict) + override def predict(testData: RDD[Vector]): RDD[Double] = { + val bcModel = testData.context.broadcast(this) + testData.mapPartitions { iter => + val model = bcModel.value + iter.map(model.predict) + } + } override def predict(testData: Vector): Double = { labels(brzArgmax(brzPi + brzTheta * testData.toBreeze)) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index 316ecd713b715..80f8a1b2f1e84 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -30,7 +30,7 @@ import org.apache.spark.rdd.RDD * @param weights Weights computed for every feature. * @param intercept Intercept computed for this model. */ -class SVMModel private[mllib] ( +class SVMModel ( override val weights: Vector, override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index de22fbb6ffc10..db425d866bbad 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -165,18 +165,21 @@ class KMeans private ( val activeCenters = activeRuns.map(r => centers(r)).toArray val costAccums = activeRuns.map(_ => sc.accumulator(0.0)) + val bcActiveCenters = sc.broadcast(activeCenters) + // Find the sum and count of points mapping to each center val totalContribs = data.mapPartitions { points => - val runs = activeCenters.length - val k = activeCenters(0).length - val dims = activeCenters(0)(0).vector.length + val thisActiveCenters = bcActiveCenters.value + val runs = thisActiveCenters.length + val k = thisActiveCenters(0).length + val dims = thisActiveCenters(0)(0).vector.length val sums = Array.fill(runs, k)(BDV.zeros[Double](dims).asInstanceOf[BV[Double]]) val counts = Array.fill(runs, k)(0L) points.foreach { point => (0 until runs).foreach { i => - val (bestCenter, cost) = KMeans.findClosest(activeCenters(i), point) + val (bestCenter, cost) = KMeans.findClosest(thisActiveCenters(i), point) costAccums(i) += cost sums(i)(bestCenter) += point.vector counts(i)(bestCenter) += 1 @@ -264,16 +267,17 @@ class KMeans private ( // to their squared distance from that run's current centers var step = 0 while (step < initializationSteps) { + val bcCenters = data.context.broadcast(centers) val sumCosts = data.flatMap { point => (0 until runs).map { r => - (r, KMeans.pointCost(centers(r), point)) + (r, KMeans.pointCost(bcCenters.value(r), point)) } }.reduceByKey(_ + 
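The NaiveBayes and KMeans changes apply the same technique: broadcast the model state once per job and have tasks read it through the broadcast variable instead of capturing it in every task closure. A generic sketch of the pattern with a placeholder model type; nothing below is the actual MLlib model, and sc is an existing SparkContext:

```scala
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD

// Placeholder model; in the diff this role is played by NaiveBayesModel or the
// array of current cluster centers.
class MyModel extends Serializable {
  def predict(features: Array[Double]): Double = if (features.sum > 0) 1.0 else 0.0
}

def bulkPredict(sc: SparkContext, model: MyModel, data: RDD[Array[Double]]): RDD[Double] = {
  val bcModel = sc.broadcast(model)          // shipped to each executor once
  data.mapPartitions { iter =>
    val m = bcModel.value                    // dereferenced once per partition
    iter.map(m.predict)
  }
}
```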
_).collectAsMap() val chosen = data.mapPartitionsWithIndex { (index, points) => val rand = new XORShiftRandom(seed ^ (step << 16) ^ index) points.flatMap { p => (0 until runs).filter { r => - rand.nextDouble() < 2.0 * KMeans.pointCost(centers(r), p) * k / sumCosts(r) + rand.nextDouble() < 2.0 * KMeans.pointCost(bcCenters.value(r), p) * k / sumCosts(r) }.map((_, p)) } }.collect() @@ -286,9 +290,10 @@ class KMeans private ( // Finally, we might have a set of more than k candidate centers for each run; weigh each // candidate by the number of points in the dataset mapping to it and run a local k-means++ // on the weighted centers to pick just k of them + val bcCenters = data.context.broadcast(centers) val weightMap = data.flatMap { p => (0 until runs).map { r => - ((r, KMeans.findClosest(centers(r), p)._1), 1.0) + ((r, KMeans.findClosest(bcCenters.value(r), p)._1), 1.0) } }.reduceByKey(_ + _).collectAsMap() val finalCenters = (0 until runs).map { r => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index fba21aefaaacd..5823cb6e52e7f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -38,7 +38,8 @@ class KMeansModel private[mllib] (val clusterCenters: Array[Vector]) extends Ser /** Maps given points to their cluster indices. */ def predict(points: RDD[Vector]): RDD[Int] = { val centersWithNorm = clusterCentersWithNorm - points.map(p => KMeans.findClosest(centersWithNorm, new BreezeVectorWithNorm(p))._1) + val bcCentersWithNorm = points.context.broadcast(centersWithNorm) + points.map(p => KMeans.findClosest(bcCentersWithNorm.value, new BreezeVectorWithNorm(p))._1) } /** Maps given points to their cluster indices. */ @@ -51,7 +52,8 @@ class KMeansModel private[mllib] (val clusterCenters: Array[Vector]) extends Ser */ def computeCost(data: RDD[Vector]): Double = { val centersWithNorm = clusterCentersWithNorm - data.map(p => KMeans.pointCost(centersWithNorm, new BreezeVectorWithNorm(p))).sum() + val bcCentersWithNorm = data.context.broadcast(centersWithNorm) + data.map(p => KMeans.pointCost(bcCentersWithNorm.value, new BreezeVectorWithNorm(p))).sum() } private def clusterCentersWithNorm: Iterable[BreezeVectorWithNorm] = diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala index 2e3a4ce783de7..f0722d7c14a46 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala @@ -59,7 +59,13 @@ private[mllib] object LocalKMeans extends Logging { cumulativeScore += weights(j) * KMeans.pointCost(curCenters, points(j)) j += 1 } - centers(i) = points(j-1).toDense + if (j == 0) { + logWarning("kMeansPlusPlus initialization ran out of distinct points for centers." 
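The new LocalKMeans branch guards the case where the cumulative-cost scan never advances (for instance when every remaining point duplicates an already-chosen center), so j stays 0 and points(0) is reused instead of indexing points(-1). A standalone sketch of that select-by-cumulative-weight-with-fallback idea; the helper is hypothetical, not the MLlib code:

```scala
import scala.util.Random

// Pick index j with probability proportional to weights(j); if the total weight
// is zero the scan cannot advance, so fall back to index 0, mirroring the
// duplicate-center guard added above.
def pickWeighted(rand: Random, weights: Array[Double]): Int = {
  val total = weights.sum
  if (total <= 0.0) return 0
  val cutoff = rand.nextDouble() * total
  var cumulative = 0.0
  var j = 0
  while (j < weights.length && cumulative < cutoff) {
    cumulative += weights(j)
    j += 1
  }
  math.max(j - 1, 0)
}
```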
+ + s" Using duplicate point for center k = $i.") + centers(i) = points(0).toDense + } else { + centers(i) = points(j - 1).toDense + } } // Run up to maxIterations iterations of Lloyd's algorithm diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala index 079743742d86d..1af40de2c7fcf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala @@ -103,11 +103,11 @@ class BinaryClassificationMetrics(scoreAndLabels: RDD[(Double, Double)]) extends mergeValue = (c: BinaryLabelCounter, label: Double) => c += label, mergeCombiners = (c1: BinaryLabelCounter, c2: BinaryLabelCounter) => c1 += c2 ).sortByKey(ascending = false) - val agg = counts.values.mapPartitions({ iter => + val agg = counts.values.mapPartitions { iter => val agg = new BinaryLabelCounter() iter.foreach(agg += _) Iterator(agg) - }, preservesPartitioning = true).collect() + }.collect() val partitionwiseCumulativeCounts = agg.scanLeft(new BinaryLabelCounter())( (agg: BinaryLabelCounter, c: BinaryLabelCounter) => agg.clone() += c) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala new file mode 100644 index 0000000000000..666362ae6739a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.evaluation + +import scala.collection.Map + +import org.apache.spark.SparkContext._ +import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.linalg.{Matrices, Matrix} +import org.apache.spark.rdd.RDD + +/** + * ::Experimental:: + * Evaluator for multiclass classification. + * + * @param predictionAndLabels an RDD of (prediction, label) pairs. 
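The surrounding BinaryClassificationMetrics code builds cumulative counts by reducing each partition to a single counter, collecting that small array, and prefix-summing it on the driver with scanLeft. The same idea in a stripped-down form, using plain Long totals instead of BinaryLabelCounter:

```scala
import org.apache.spark.rdd.RDD

// One local total per partition, collected to the driver; scanLeft then turns
// the small array into running offsets (offset(i) = sum of partitions before i).
def partitionOffsets(data: RDD[Long]): Array[Long] = {
  val perPartition = data.mapPartitions(iter => Iterator(iter.sum)).collect()
  perPartition.scanLeft(0L)(_ + _).dropRight(1)
}
```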
+ */ +@Experimental +class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { + + private lazy val labelCountByClass: Map[Double, Long] = predictionAndLabels.values.countByValue() + private lazy val labelCount: Long = labelCountByClass.values.sum + private lazy val tpByClass: Map[Double, Int] = predictionAndLabels + .map { case (prediction, label) => + (label, if (label == prediction) 1 else 0) + }.reduceByKey(_ + _) + .collectAsMap() + private lazy val fpByClass: Map[Double, Int] = predictionAndLabels + .map { case (prediction, label) => + (prediction, if (prediction != label) 1 else 0) + }.reduceByKey(_ + _) + .collectAsMap() + private lazy val confusions = predictionAndLabels + .map { case (prediction, label) => + ((label, prediction), 1) + }.reduceByKey(_ + _) + .collectAsMap() + + /** + * Returns confusion matrix: + * predicted classes are in columns, + * they are ordered by class label ascending, + * as in "labels" + */ + def confusionMatrix: Matrix = { + val n = labels.size + val values = Array.ofDim[Double](n * n) + var i = 0 + while (i < n) { + var j = 0 + while (j < n) { + values(i + j * n) = confusions.getOrElse((labels(i), labels(j)), 0).toDouble + j += 1 + } + i += 1 + } + Matrices.dense(n, n, values) + } + + /** + * Returns true positive rate for a given label (category) + * @param label the label. + */ + def truePositiveRate(label: Double): Double = recall(label) + + /** + * Returns false positive rate for a given label (category) + * @param label the label. + */ + def falsePositiveRate(label: Double): Double = { + val fp = fpByClass.getOrElse(label, 0) + fp.toDouble / (labelCount - labelCountByClass(label)) + } + + /** + * Returns precision for a given label (category) + * @param label the label. + */ + def precision(label: Double): Double = { + val tp = tpByClass(label) + val fp = fpByClass.getOrElse(label, 0) + if (tp + fp == 0) 0 else tp.toDouble / (tp + fp) + } + + /** + * Returns recall for a given label (category) + * @param label the label. + */ + def recall(label: Double): Double = tpByClass(label).toDouble / labelCountByClass(label) + + /** + * Returns f-measure for a given label (category) + * @param label the label. + * @param beta the beta parameter. + */ + def fMeasure(label: Double, beta: Double): Double = { + val p = precision(label) + val r = recall(label) + val betaSqrd = beta * beta + if (p + r == 0) 0 else (1 + betaSqrd) * p * r / (betaSqrd * p + r) + } + + /** + * Returns f1-measure for a given label (category) + * @param label the label. 
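For orientation, the new evaluator is driven from (prediction, label) pairs like this; the numbers below are made up and sc is an existing SparkContext:

```scala
import org.apache.spark.mllib.evaluation.MulticlassMetrics

val predictionAndLabels = sc.parallelize(Seq(
  (0.0, 0.0), (1.0, 1.0), (1.0, 0.0), (2.0, 2.0), (2.0, 1.0)))  // (prediction, label)
val metrics = new MulticlassMetrics(predictionAndLabels)

println(metrics.confusionMatrix)                       // labels in ascending order
println(s"precision(1.0) = ${metrics.precision(1.0)}")
println(s"weighted F1    = ${metrics.weightedFMeasure}")
```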
+ */ + def fMeasure(label: Double): Double = fMeasure(label, 1.0) + + /** + * Returns precision + */ + lazy val precision: Double = tpByClass.values.sum.toDouble / labelCount + + /** + * Returns recall + * (equals to precision for multiclass classifier + * because sum of all false positives is equal to sum + * of all false negatives) + */ + lazy val recall: Double = precision + + /** + * Returns f-measure + * (equals to precision and recall because precision equals recall) + */ + lazy val fMeasure: Double = precision + + /** + * Returns weighted true positive rate + * (equals to precision, recall and f-measure) + */ + lazy val weightedTruePositiveRate: Double = weightedRecall + + /** + * Returns weighted false positive rate + */ + lazy val weightedFalsePositiveRate: Double = labelCountByClass.map { case (category, count) => + falsePositiveRate(category) * count.toDouble / labelCount + }.sum + + /** + * Returns weighted averaged recall + * (equals to precision, recall and f-measure) + */ + lazy val weightedRecall: Double = labelCountByClass.map { case (category, count) => + recall(category) * count.toDouble / labelCount + }.sum + + /** + * Returns weighted averaged precision + */ + lazy val weightedPrecision: Double = labelCountByClass.map { case (category, count) => + precision(category) * count.toDouble / labelCount + }.sum + + /** + * Returns weighted averaged f-measure + * @param beta the beta parameter. + */ + def weightedFMeasure(beta: Double): Double = labelCountByClass.map { case (category, count) => + fMeasure(category, beta) * count.toDouble / labelCount + }.sum + + /** + * Returns weighted averaged f1-measure + */ + lazy val weightedFMeasure: Double = labelCountByClass.map { case (category, count) => + fMeasure(category, 1.0) * count.toDouble / labelCount + }.sum + + /** + * Returns the sequence of labels in ascending order + */ + lazy val labels: Array[Double] = tpByClass.keys.toArray.sorted +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index f4c403bc7861c..8c2b044ea73f2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -377,9 +377,9 @@ class RowMatrix( s"Only support dense matrix at this time but found ${B.getClass.getName}.") val Bb = rows.context.broadcast(B.toBreeze.asInstanceOf[BDM[Double]].toDenseVector.toArray) - val AB = rows.mapPartitions({ iter => + val AB = rows.mapPartitions { iter => val Bi = Bb.value - iter.map(row => { + iter.map { row => val v = BDV.zeros[Double](k) var i = 0 while (i < k) { @@ -387,8 +387,8 @@ class RowMatrix( i += 1 } Vectors.fromBreeze(v) - }) - }, preservesPartitioning = true) + } + } new RowMatrix(AB, nRows, B.numCols) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index 7030eeabe400a..9fd760bf78083 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -163,6 +163,7 @@ object GradientDescent extends Logging { // Initialize weights as a column vector var weights = Vectors.dense(initialWeights.toArray) + val n = weights.size /** * For the first iteration, the regVal will be initialized as sum of weight squares @@ -172,12 +173,13 @@ object 
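RowMatrix.multiply uses the broadcast-and-map pattern as well: the small right-hand matrix is shipped once as a flat column-major array and each output row is computed locally. A self-contained sketch with made-up dimensions and plain arrays instead of Breeze vectors (sc assumed):

```scala
val n = 3                                    // columns of A == rows of B
val k = 2                                    // columns of B
val bValues = Array(1.0, 0.0, 2.0, 0.0, 1.0, 3.0)   // B, column-major, n x k
val bcB = sc.broadcast(bValues)

val rowsA = sc.parallelize(Seq(Array(1.0, 2.0, 3.0), Array(4.0, 5.0, 6.0)))
val rowsAB = rowsA.map { row =>
  val b = bcB.value                          // read the broadcast once per row
  Array.tabulate(k) { j =>
    var sum = 0.0
    var i = 0
    while (i < n) { sum += row(i) * b(i + j * n); i += 1 }
    sum
  }
}
```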
GradientDescent extends Logging { weights, Vectors.dense(new Array[Double](weights.size)), 0, 1, regParam)._2 for (i <- 1 to numIterations) { + val bcWeights = data.context.broadcast(weights) // Sample a subset (fraction miniBatchFraction) of the total data // compute and sum up the subgradients on this subset (this is one map-reduce) val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42 + i) - .aggregate((BDV.zeros[Double](weights.size), 0.0))( + .aggregate((BDV.zeros[Double](n), 0.0))( seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) => - val l = gradient.compute(features, label, weights, Vectors.fromBreeze(grad)) + val l = gradient.compute(features, label, bcWeights.value, Vectors.fromBreeze(grad)) (grad, loss + l) }, combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala index 7bbed9c8fdbef..179cd4a3f1625 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala @@ -195,13 +195,14 @@ object LBFGS extends Logging { override def calculate(weights: BDV[Double]) = { // Have a local copy to avoid the serialization of CostFun object which is not serializable. - val localData = data val localGradient = gradient + val n = weights.length + val bcWeights = data.context.broadcast(weights) - val (gradientSum, lossSum) = localData.aggregate((BDV.zeros[Double](weights.size), 0.0))( + val (gradientSum, lossSum) = data.aggregate((BDV.zeros[Double](n), 0.0))( seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) => val l = localGradient.compute( - features, label, Vectors.fromBreeze(weights), Vectors.fromBreeze(grad)) + features, label, Vectors.fromBreeze(bcWeights.value), Vectors.fromBreeze(grad)) (grad, loss + l) }, combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/DistributionGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/DistributionGenerator.scala new file mode 100644 index 0000000000000..7ecb409c4a91a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/DistributionGenerator.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.random + +import cern.jet.random.Poisson +import cern.jet.random.engine.DRand + +import org.apache.spark.annotation.Experimental +import org.apache.spark.util.random.{XORShiftRandom, Pseudorandom} + +/** + * :: Experimental :: + * Trait for random number generators that generate i.i.d. values from a distribution. 
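Both GradientDescent and LBFGS now broadcast the weight vector each iteration, so the aggregate tasks read bcWeights.value rather than serializing a possibly large dense vector into every closure. A stripped-down sketch of that mini-batch aggregation shape with the gradient math elided; sc is assumed, and data and n are illustrative:

```scala
val n = 2                                                        // number of features
val data = sc.parallelize(Seq((1.0, Array(0.5, 1.2)), (0.0, Array(-0.3, 0.8))))  // (label, features)
val weights = Array.fill(n)(0.0)
val bcWeights = sc.broadcast(weights)

val (gradientSum, lossSum) = data
  .sample(withReplacement = false, fraction = 0.5, seed = 42L)
  .aggregate((Array.fill(n)(0.0), 0.0))(
    seqOp = { case ((grad, loss), (label, features)) =>
      // here the real code calls gradient.compute(features, label, bcWeights.value, grad)
      (grad, loss)
    },
    combOp = { case ((g1, l1), (g2, l2)) =>
      var i = 0
      while (i < n) { g1(i) += g2(i); i += 1 }
      (g1, l1 + l2)
    })
```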
+ */ +@Experimental +trait DistributionGenerator extends Pseudorandom with Serializable { + + /** + * Returns an i.i.d. sample as a Double from an underlying distribution. + */ + def nextValue(): Double + + /** + * Returns a copy of the DistributionGenerator with a new instance of the rng object used in the + * class when applicable for non-locking concurrent usage. + */ + def copy(): DistributionGenerator +} + +/** + * :: Experimental :: + * Generates i.i.d. samples from U[0.0, 1.0] + */ +@Experimental +class UniformGenerator extends DistributionGenerator { + + // XORShiftRandom for better performance. Thread safety isn't necessary here. + private val random = new XORShiftRandom() + + override def nextValue(): Double = { + random.nextDouble() + } + + override def setSeed(seed: Long) = random.setSeed(seed) + + override def copy(): UniformGenerator = new UniformGenerator() +} + +/** + * :: Experimental :: + * Generates i.i.d. samples from the standard normal distribution. + */ +@Experimental +class StandardNormalGenerator extends DistributionGenerator { + + // XORShiftRandom for better performance. Thread safety isn't necessary here. + private val random = new XORShiftRandom() + + override def nextValue(): Double = { + random.nextGaussian() + } + + override def setSeed(seed: Long) = random.setSeed(seed) + + override def copy(): StandardNormalGenerator = new StandardNormalGenerator() +} + +/** + * :: Experimental :: + * Generates i.i.d. samples from the Poisson distribution with the given mean. + * + * @param mean mean for the Poisson distribution. + */ +@Experimental +class PoissonGenerator(val mean: Double) extends DistributionGenerator { + + private var rng = new Poisson(mean, new DRand) + + override def nextValue(): Double = rng.nextDouble() + + override def setSeed(seed: Long) { + rng = new Poisson(mean, new DRand(seed.toInt)) + } + + override def copy(): PoissonGenerator = new PoissonGenerator(mean) +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala new file mode 100644 index 0000000000000..d7ee2d3f46846 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala @@ -0,0 +1,473 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.random + +import org.apache.spark.SparkContext +import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.rdd.{RandomVectorRDD, RandomRDD} +import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils + +/** + * :: Experimental :: + * Generator methods for creating RDDs comprised of i.i.d samples from some distribution. 
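Implementing the new trait for another distribution only requires nextValue, setSeed, and copy; copy exists so each RDD partition can work with its own generator instance. A hypothetical example, not part of this patch, using inverse-transform sampling for an exponential distribution:

```scala
import java.util.Random
import org.apache.spark.mllib.random.DistributionGenerator

// Hypothetical: i.i.d. samples from Exp(rate) via inverse-transform sampling.
class ExponentialGenerator(val rate: Double) extends DistributionGenerator {

  private val random = new Random()

  override def nextValue(): Double = -math.log(1.0 - random.nextDouble()) / rate

  override def setSeed(seed: Long): Unit = random.setSeed(seed)

  override def copy(): ExponentialGenerator = new ExponentialGenerator(rate)
}
```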
+ */ +@Experimental +object RandomRDDGenerators { + + /** + * :: Experimental :: + * Generates an RDD comprised of i.i.d samples from the uniform distribution on [0.0, 1.0]. + * + * @param sc SparkContext used to create the RDD. + * @param size Size of the RDD. + * @param numPartitions Number of partitions in the RDD. + * @param seed Seed for the RNG that generates the seed for the generator in each partition. + * @return RDD[Double] comprised of i.i.d. samples ~ U[0.0, 1.0]. + */ + @Experimental + def uniformRDD(sc: SparkContext, size: Long, numPartitions: Int, seed: Long): RDD[Double] = { + val uniform = new UniformGenerator() + randomRDD(sc, uniform, size, numPartitions, seed) + } + + /** + * :: Experimental :: + * Generates an RDD comprised of i.i.d samples from the uniform distribution on [0.0, 1.0]. + * + * @param sc SparkContext used to create the RDD. + * @param size Size of the RDD. + * @param numPartitions Number of partitions in the RDD. + * @return RDD[Double] comprised of i.i.d. samples ~ U[0.0, 1.0]. + */ + @Experimental + def uniformRDD(sc: SparkContext, size: Long, numPartitions: Int): RDD[Double] = { + uniformRDD(sc, size, numPartitions, Utils.random.nextLong) + } + + /** + * :: Experimental :: + * Generates an RDD comprised of i.i.d samples from the uniform distribution on [0.0, 1.0]. + * sc.defaultParallelism used for the number of partitions in the RDD. + * + * @param sc SparkContext used to create the RDD. + * @param size Size of the RDD. + * @return RDD[Double] comprised of i.i.d. samples ~ U[0.0, 1.0]. + */ + @Experimental + def uniformRDD(sc: SparkContext, size: Long): RDD[Double] = { + uniformRDD(sc, size, sc.defaultParallelism, Utils.random.nextLong) + } + + /** + * :: Experimental :: + * Generates an RDD comprised of i.i.d samples from the standard normal distribution. + * + * @param sc SparkContext used to create the RDD. + * @param size Size of the RDD. + * @param numPartitions Number of partitions in the RDD. + * @param seed Seed for the RNG that generates the seed for the generator in each partition. + * @return RDD[Double] comprised of i.i.d. samples ~ N(0.0, 1.0). + */ + @Experimental + def normalRDD(sc: SparkContext, size: Long, numPartitions: Int, seed: Long): RDD[Double] = { + val normal = new StandardNormalGenerator() + randomRDD(sc, normal, size, numPartitions, seed) + } + + /** + * :: Experimental :: + * Generates an RDD comprised of i.i.d samples from the standard normal distribution. + * + * @param sc SparkContext used to create the RDD. + * @param size Size of the RDD. + * @param numPartitions Number of partitions in the RDD. + * @return RDD[Double] comprised of i.i.d. samples ~ N(0.0, 1.0). + */ + @Experimental + def normalRDD(sc: SparkContext, size: Long, numPartitions: Int): RDD[Double] = { + normalRDD(sc, size, numPartitions, Utils.random.nextLong) + } + + /** + * :: Experimental :: + * Generates an RDD comprised of i.i.d samples from the standard normal distribution. + * sc.defaultParallelism used for the number of partitions in the RDD. + * + * @param sc SparkContext used to create the RDD. + * @param size Size of the RDD. + * @return RDD[Double] comprised of i.i.d. samples ~ N(0.0, 1.0). + */ + @Experimental + def normalRDD(sc: SparkContext, size: Long): RDD[Double] = { + normalRDD(sc, size, sc.defaultParallelism, Utils.random.nextLong) + } + + /** + * :: Experimental :: + * Generates an RDD comprised of i.i.d samples from the Poisson distribution with the input mean. + * + * @param sc SparkContext used to create the RDD. 
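Typical driver-side use of the scalar factories (sc is an existing SparkContext); samples from other normal distributions can be obtained by rescaling the standard normal RDD with map:

```scala
import org.apache.spark.mllib.random.RandomRDDGenerators._

// One million standard normal samples in 10 partitions, then shifted/scaled
// to N(5, 2^2) on the fly.
val standard = normalRDD(sc, 1000000L, 10)
val shifted = standard.map(x => 5.0 + 2.0 * x)

// Uniform samples on [0.0, 1.0]; a fixed seed makes the RDD reproducible.
val u = uniformRDD(sc, 1000000L, 10, seed = 17L)
```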
+ * @param mean Mean, or lambda, for the Poisson distribution. + * @param size Size of the RDD. + * @param numPartitions Number of partitions in the RDD. + * @param seed Seed for the RNG that generates the seed for the generator in each partition. + * @return RDD[Double] comprised of i.i.d. samples ~ Pois(mean). + */ + @Experimental + def poissonRDD(sc: SparkContext, + mean: Double, + size: Long, + numPartitions: Int, + seed: Long): RDD[Double] = { + val poisson = new PoissonGenerator(mean) + randomRDD(sc, poisson, size, numPartitions, seed) + } + + /** + * :: Experimental :: + * Generates an RDD comprised of i.i.d samples from the Poisson distribution with the input mean. + * + * @param sc SparkContext used to create the RDD. + * @param mean Mean, or lambda, for the Poisson distribution. + * @param size Size of the RDD. + * @param numPartitions Number of partitions in the RDD. + * @return RDD[Double] comprised of i.i.d. samples ~ Pois(mean). + */ + @Experimental + def poissonRDD(sc: SparkContext, mean: Double, size: Long, numPartitions: Int): RDD[Double] = { + poissonRDD(sc, mean, size, numPartitions, Utils.random.nextLong) + } + + /** + * :: Experimental :: + * Generates an RDD comprised of i.i.d samples from the Poisson distribution with the input mean. + * sc.defaultParallelism used for the number of partitions in the RDD. + * + * @param sc SparkContext used to create the RDD. + * @param mean Mean, or lambda, for the Poisson distribution. + * @param size Size of the RDD. + * @return RDD[Double] comprised of i.i.d. samples ~ Pois(mean). + */ + @Experimental + def poissonRDD(sc: SparkContext, mean: Double, size: Long): RDD[Double] = { + poissonRDD(sc, mean, size, sc.defaultParallelism, Utils.random.nextLong) + } + + /** + * :: Experimental :: + * Generates an RDD comprised of i.i.d samples produced by the input DistributionGenerator. + * + * @param sc SparkContext used to create the RDD. + * @param generator DistributionGenerator used to populate the RDD. + * @param size Size of the RDD. + * @param numPartitions Number of partitions in the RDD. + * @param seed Seed for the RNG that generates the seed for the generator in each partition. + * @return RDD[Double] comprised of i.i.d. samples produced by generator. + */ + @Experimental + def randomRDD(sc: SparkContext, + generator: DistributionGenerator, + size: Long, + numPartitions: Int, + seed: Long): RDD[Double] = { + new RandomRDD(sc, size, numPartitions, generator, seed) + } + + /** + * :: Experimental :: + * Generates an RDD comprised of i.i.d samples produced by the input DistributionGenerator. + * + * @param sc SparkContext used to create the RDD. + * @param generator DistributionGenerator used to populate the RDD. + * @param size Size of the RDD. + * @param numPartitions Number of partitions in the RDD. + * @return RDD[Double] comprised of i.i.d. samples produced by generator. + */ + @Experimental + def randomRDD(sc: SparkContext, + generator: DistributionGenerator, + size: Long, + numPartitions: Int): RDD[Double] = { + randomRDD(sc, generator, size, numPartitions, Utils.random.nextLong) + } + + /** + * :: Experimental :: + * Generates an RDD comprised of i.i.d samples produced by the input DistributionGenerator. + * sc.defaultParallelism used for the number of partitions in the RDD. + * + * @param sc SparkContext used to create the RDD. + * @param generator DistributionGenerator used to populate the RDD. + * @param size Size of the RDD. + * @return RDD[Double] comprised of i.i.d. samples produced by generator. 
+ */ + @Experimental + def randomRDD(sc: SparkContext, + generator: DistributionGenerator, + size: Long): RDD[Double] = { + randomRDD(sc, generator, size, sc.defaultParallelism, Utils.random.nextLong) + } + + // TODO Generate RDD[Vector] from multivariate distributions. + + /** + * :: Experimental :: + * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the + * uniform distribution on [0.0 1.0]. + * + * @param sc SparkContext used to create the RDD. + * @param numRows Number of Vectors in the RDD. + * @param numCols Number of elements in each Vector. + * @param numPartitions Number of partitions in the RDD. + * @param seed Seed for the RNG that generates the seed for the generator in each partition. + * @return RDD[Vector] with vectors containing i.i.d samples ~ U[0.0, 1.0]. + */ + @Experimental + def uniformVectorRDD(sc: SparkContext, + numRows: Long, + numCols: Int, + numPartitions: Int, + seed: Long): RDD[Vector] = { + val uniform = new UniformGenerator() + randomVectorRDD(sc, uniform, numRows, numCols, numPartitions, seed) + } + + /** + * :: Experimental :: + * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the + * uniform distribution on [0.0 1.0]. + * + * @param sc SparkContext used to create the RDD. + * @param numRows Number of Vectors in the RDD. + * @param numCols Number of elements in each Vector. + * @param numPartitions Number of partitions in the RDD. + * @return RDD[Vector] with vectors containing i.i.d samples ~ U[0.0, 1.0]. + */ + @Experimental + def uniformVectorRDD(sc: SparkContext, + numRows: Long, + numCols: Int, + numPartitions: Int): RDD[Vector] = { + uniformVectorRDD(sc, numRows, numCols, numPartitions, Utils.random.nextLong) + } + + /** + * :: Experimental :: + * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the + * uniform distribution on [0.0 1.0]. + * sc.defaultParallelism used for the number of partitions in the RDD. + * + * @param sc SparkContext used to create the RDD. + * @param numRows Number of Vectors in the RDD. + * @param numCols Number of elements in each Vector. + * @return RDD[Vector] with vectors containing i.i.d samples ~ U[0.0, 1.0]. + */ + @Experimental + def uniformVectorRDD(sc: SparkContext, numRows: Long, numCols: Int): RDD[Vector] = { + uniformVectorRDD(sc, numRows, numCols, sc.defaultParallelism, Utils.random.nextLong) + } + + /** + * :: Experimental :: + * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the + * standard normal distribution. + * + * @param sc SparkContext used to create the RDD. + * @param numRows Number of Vectors in the RDD. + * @param numCols Number of elements in each Vector. + * @param numPartitions Number of partitions in the RDD. + * @param seed Seed for the RNG that generates the seed for the generator in each partition. + * @return RDD[Vector] with vectors containing i.i.d samples ~ N(0.0, 1.0). + */ + @Experimental + def normalVectorRDD(sc: SparkContext, + numRows: Long, + numCols: Int, + numPartitions: Int, + seed: Long): RDD[Vector] = { + val uniform = new StandardNormalGenerator() + randomVectorRDD(sc, uniform, numRows, numCols, numPartitions, seed) + } + + /** + * :: Experimental :: + * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the + * standard normal distribution. + * + * @param sc SparkContext used to create the RDD. + * @param numRows Number of Vectors in the RDD. + * @param numCols Number of elements in each Vector. + * @param numPartitions Number of partitions in the RDD. 
+ * @return RDD[Vector] with vectors containing i.i.d samples ~ N(0.0, 1.0). + */ + @Experimental + def normalVectorRDD(sc: SparkContext, + numRows: Long, + numCols: Int, + numPartitions: Int): RDD[Vector] = { + normalVectorRDD(sc, numRows, numCols, numPartitions, Utils.random.nextLong) + } + + /** + * :: Experimental :: + * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the + * standard normal distribution. + * sc.defaultParallelism used for the number of partitions in the RDD. + * + * @param sc SparkContext used to create the RDD. + * @param numRows Number of Vectors in the RDD. + * @param numCols Number of elements in each Vector. + * @return RDD[Vector] with vectors containing i.i.d samples ~ N(0.0, 1.0). + */ + @Experimental + def normalVectorRDD(sc: SparkContext, numRows: Long, numCols: Int): RDD[Vector] = { + normalVectorRDD(sc, numRows, numCols, sc.defaultParallelism, Utils.random.nextLong) + } + + /** + * :: Experimental :: + * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the + * Poisson distribution with the input mean. + * + * @param sc SparkContext used to create the RDD. + * @param mean Mean, or lambda, for the Poisson distribution. + * @param numRows Number of Vectors in the RDD. + * @param numCols Number of elements in each Vector. + * @param numPartitions Number of partitions in the RDD. + * @param seed Seed for the RNG that generates the seed for the generator in each partition. + * @return RDD[Vector] with vectors containing i.i.d samples ~ Pois(mean). + */ + @Experimental + def poissonVectorRDD(sc: SparkContext, + mean: Double, + numRows: Long, + numCols: Int, + numPartitions: Int, + seed: Long): RDD[Vector] = { + val poisson = new PoissonGenerator(mean) + randomVectorRDD(sc, poisson, numRows, numCols, numPartitions, seed) + } + + /** + * :: Experimental :: + * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the + * Poisson distribution with the input mean. + * + * @param sc SparkContext used to create the RDD. + * @param mean Mean, or lambda, for the Poisson distribution. + * @param numRows Number of Vectors in the RDD. + * @param numCols Number of elements in each Vector. + * @param numPartitions Number of partitions in the RDD. + * @return RDD[Vector] with vectors containing i.i.d samples ~ Pois(mean). + */ + @Experimental + def poissonVectorRDD(sc: SparkContext, + mean: Double, + numRows: Long, + numCols: Int, + numPartitions: Int): RDD[Vector] = { + poissonVectorRDD(sc, mean, numRows, numCols, numPartitions, Utils.random.nextLong) + } + + /** + * :: Experimental :: + * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the + * Poisson distribution with the input mean. + * sc.defaultParallelism used for the number of partitions in the RDD. + * + * @param sc SparkContext used to create the RDD. + * @param mean Mean, or lambda, for the Poisson distribution. + * @param numRows Number of Vectors in the RDD. + * @param numCols Number of elements in each Vector. + * @return RDD[Vector] with vectors containing i.i.d samples ~ Pois(mean). + */ + @Experimental + def poissonVectorRDD(sc: SparkContext, + mean: Double, + numRows: Long, + numCols: Int): RDD[Vector] = { + poissonVectorRDD(sc, mean, numRows, numCols, sc.defaultParallelism, Utils.random.nextLong) + } + + /** + * :: Experimental :: + * Generates an RDD[Vector] with vectors containing i.i.d samples produced by the + * input DistributionGenerator. + * + * @param sc SparkContext used to create the RDD. 
+ * @param generator DistributionGenerator used to populate the RDD. + * @param numRows Number of Vectors in the RDD. + * @param numCols Number of elements in each Vector. + * @param numPartitions Number of partitions in the RDD. + * @param seed Seed for the RNG that generates the seed for the generator in each partition. + * @return RDD[Vector] with vectors containing i.i.d samples produced by generator. + */ + @Experimental + def randomVectorRDD(sc: SparkContext, + generator: DistributionGenerator, + numRows: Long, + numCols: Int, + numPartitions: Int, + seed: Long): RDD[Vector] = { + new RandomVectorRDD(sc, numRows, numCols, numPartitions, generator, seed) + } + + /** + * :: Experimental :: + * Generates an RDD[Vector] with vectors containing i.i.d samples produced by the + * input DistributionGenerator. + * + * @param sc SparkContext used to create the RDD. + * @param generator DistributionGenerator used to populate the RDD. + * @param numRows Number of Vectors in the RDD. + * @param numCols Number of elements in each Vector. + * @param numPartitions Number of partitions in the RDD. + * @return RDD[Vector] with vectors containing i.i.d samples produced by generator. + */ + @Experimental + def randomVectorRDD(sc: SparkContext, + generator: DistributionGenerator, + numRows: Long, + numCols: Int, + numPartitions: Int): RDD[Vector] = { + randomVectorRDD(sc, generator, numRows, numCols, numPartitions, Utils.random.nextLong) + } + + /** + * :: Experimental :: + * Generates an RDD[Vector] with vectors containing i.i.d samples produced by the + * input DistributionGenerator. + * sc.defaultParallelism used for the number of partitions in the RDD. + * + * @param sc SparkContext used to create the RDD. + * @param generator DistributionGenerator used to populate the RDD. + * @param numRows Number of Vectors in the RDD. + * @param numCols Number of elements in each Vector. + * @return RDD[Vector] with vectors containing i.i.d samples produced by generator. + */ + @Experimental + def randomVectorRDD(sc: SparkContext, + generator: DistributionGenerator, + numRows: Long, + numCols: Int): RDD[Vector] = { + randomVectorRDD(sc, generator, numRows, numCols, + sc.defaultParallelism, Utils.random.nextLong) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala new file mode 100644 index 0000000000000..f13282d07ff92 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
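The vector factories produce an RDD[Vector] that works well as synthetic input for tests; the second line below reuses the hypothetical ExponentialGenerator sketched earlier (sc assumed):

```scala
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.random.RandomRDDGenerators._
import org.apache.spark.rdd.RDD

// A 10,000 x 5 matrix of Poisson(3.0) counts.
val counts: RDD[Vector] =
  poissonVectorRDD(sc, mean = 3.0, numRows = 10000L, numCols = 5, numPartitions = 4)

// Same shape, filled by a user-supplied generator via the generic factory.
val waits: RDD[Vector] = randomVectorRDD(sc, new ExponentialGenerator(0.5), 10000L, 5, 4)
```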
+ */ + +package org.apache.spark.mllib.rdd + +import org.apache.spark.{Partition, SparkContext, TaskContext} +import org.apache.spark.mllib.linalg.{DenseVector, Vector} +import org.apache.spark.mllib.random.DistributionGenerator +import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils + +import scala.util.Random + +private[mllib] class RandomRDDPartition(override val index: Int, + val size: Int, + val generator: DistributionGenerator, + val seed: Long) extends Partition { + + require(size >= 0, "Non-negative partition size required.") +} + +// These two classes are necessary since Range objects in Scala cannot have size > Int.MaxValue +private[mllib] class RandomRDD(@transient sc: SparkContext, + size: Long, + numPartitions: Int, + @transient rng: DistributionGenerator, + @transient seed: Long = Utils.random.nextLong) extends RDD[Double](sc, Nil) { + + require(size > 0, "Positive RDD size required.") + require(numPartitions > 0, "Positive number of partitions required") + require(math.ceil(size.toDouble / numPartitions) <= Int.MaxValue, + "Partition size cannot exceed Int.MaxValue") + + override def compute(splitIn: Partition, context: TaskContext): Iterator[Double] = { + val split = splitIn.asInstanceOf[RandomRDDPartition] + RandomRDD.getPointIterator(split) + } + + override def getPartitions: Array[Partition] = { + RandomRDD.getPartitions(size, numPartitions, rng, seed) + } +} + +private[mllib] class RandomVectorRDD(@transient sc: SparkContext, + size: Long, + vectorSize: Int, + numPartitions: Int, + @transient rng: DistributionGenerator, + @transient seed: Long = Utils.random.nextLong) extends RDD[Vector](sc, Nil) { + + require(size > 0, "Positive RDD size required.") + require(numPartitions > 0, "Positive number of partitions required") + require(vectorSize > 0, "Positive vector size required.") + require(math.ceil(size.toDouble / numPartitions) <= Int.MaxValue, + "Partition size cannot exceed Int.MaxValue") + + override def compute(splitIn: Partition, context: TaskContext): Iterator[Vector] = { + val split = splitIn.asInstanceOf[RandomRDDPartition] + RandomRDD.getVectorIterator(split, vectorSize) + } + + override protected def getPartitions: Array[Partition] = { + RandomRDD.getPartitions(size, numPartitions, rng, seed) + } +} + +private[mllib] object RandomRDD { + + def getPartitions(size: Long, + numPartitions: Int, + rng: DistributionGenerator, + seed: Long): Array[Partition] = { + + val partitions = new Array[RandomRDDPartition](numPartitions) + var i = 0 + var start: Long = 0 + var end: Long = 0 + val random = new Random(seed) + while (i < numPartitions) { + end = ((i + 1) * size) / numPartitions + partitions(i) = new RandomRDDPartition(i, (end - start).toInt, rng, random.nextLong()) + start = end + i += 1 + } + partitions.asInstanceOf[Array[Partition]] + } + + // The RNG has to be reset every time the iterator is requested to guarantee same data + // every time the content of the RDD is examined. + def getPointIterator(partition: RandomRDDPartition): Iterator[Double] = { + val generator = partition.generator.copy() + generator.setSeed(partition.seed) + Array.fill(partition.size)(generator.nextValue()).toIterator + } + + // The RNG has to be reset every time the iterator is requested to guarantee same data + // every time the content of the RDD is examined. 
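The partition boundaries in getPartitions are computed with integer arithmetic (end = ((i + 1) * size) / numPartitions) so the per-partition sizes always sum to the requested total. A standalone sketch of just that arithmetic, with made-up sizes:

// Sketch of the partition-size arithmetic used by RandomRDD.getPartitions above.
object PartitionSizing {
  def sizes(size: Long, numPartitions: Int): Seq[Long] = {
    var start = 0L
    (0 until numPartitions).map { i =>
      val end = ((i + 1) * size) / numPartitions
      val n = end - start
      start = end
      n
    }
  }

  def main(args: Array[String]): Unit = {
    // 10 elements over 3 partitions -> sizes 3, 3, 4; they always sum to 10.
    println(sizes(10L, 3))      // Vector(3, 3, 4)
    println(sizes(10L, 3).sum)  // 10
  }
}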
+ def getVectorIterator(partition: RandomRDDPartition, vectorSize: Int): Iterator[Vector] = { + val generator = partition.generator.copy() + generator.setSeed(partition.seed) + Array.fill(partition.size)(new DenseVector( + (0 until vectorSize).map { _ => generator.nextValue() }.toArray)).toIterator + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index cc56fd6ef28d6..5356790cb5339 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -252,14 +252,14 @@ class ALS private ( val YtY = Some(sc.broadcast(computeYtY(users))) val previousProducts = products products = updateFeatures(numProductBlocks, users, userOutLinks, productInLinks, - userPartitioner, rank, lambda, alpha, YtY) + rank, lambda, alpha, YtY) previousProducts.unpersist() logInfo("Re-computing U given I (Iteration %d/%d)".format(iter, iterations)) products.setName(s"products-$iter").persist() val XtX = Some(sc.broadcast(computeYtY(products))) val previousUsers = users users = updateFeatures(numUserBlocks, products, productOutLinks, userInLinks, - productPartitioner, rank, lambda, alpha, XtX) + rank, lambda, alpha, XtX) previousUsers.unpersist() } } else { @@ -267,11 +267,11 @@ class ALS private ( // perform ALS update logInfo("Re-computing I given U (Iteration %d/%d)".format(iter, iterations)) products = updateFeatures(numProductBlocks, users, userOutLinks, productInLinks, - userPartitioner, rank, lambda, alpha, YtY = None) + rank, lambda, alpha, YtY = None) products.setName(s"products-$iter") logInfo("Re-computing U given I (Iteration %d/%d)".format(iter, iterations)) users = updateFeatures(numUserBlocks, products, productOutLinks, userInLinks, - productPartitioner, rank, lambda, alpha, YtY = None) + rank, lambda, alpha, YtY = None) users.setName(s"users-$iter") } } @@ -430,7 +430,7 @@ class ALS private ( val inLinkBlock = makeInLinkBlock(numProductBlocks, ratings, productPartitioner) val outLinkBlock = makeOutLinkBlock(numProductBlocks, ratings, productPartitioner) Iterator.single((blockId, (inLinkBlock, outLinkBlock))) - }, true) + }, preservesPartitioning = true) val inLinks = links.mapValues(_._1) val outLinks = links.mapValues(_._2) inLinks.persist(StorageLevel.MEMORY_AND_DISK) @@ -464,7 +464,6 @@ class ALS private ( products: RDD[(Int, Array[Array[Double]])], productOutLinks: RDD[(Int, OutLinkBlock)], userInLinks: RDD[(Int, InLinkBlock)], - productPartitioner: Partitioner, rank: Int, lambda: Double, alpha: Double, @@ -477,7 +476,7 @@ class ALS private ( } } toSend.zipWithIndex.map{ case (buf, idx) => (idx, (bid, buf.toArray)) } - }.groupByKey(productPartitioner) + }.groupByKey(new HashPartitioner(numUserBlocks)) .join(userInLinks) .mapValues{ case (messages, inLinkBlock) => updateBlock(messages, inLinkBlock, rank, lambda, alpha, YtY) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 8cca926f1c92e..54854252d7477 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -17,13 +17,12 @@ package org.apache.spark.mllib.regression -import breeze.linalg.{DenseVector => BDV, SparseVector => BSV} - import 
org.apache.spark.annotation.DeveloperApi import org.apache.spark.{Logging, SparkException} import org.apache.spark.rdd.RDD import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.util.MLUtils._ /** * :: DeveloperApi :: @@ -57,9 +56,12 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double // A small optimization to avoid serializing the entire model. Only the weightsMatrix // and intercept is needed. val localWeights = weights + val bcWeights = testData.context.broadcast(localWeights) val localIntercept = intercept - - testData.map(v => predictPoint(v, localWeights, localIntercept)) + testData.mapPartitions { iter => + val w = bcWeights.value + iter.map(v => predictPoint(v, w, localIntercept)) + } } /** @@ -124,16 +126,6 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] run(input, initialWeights) } - /** Prepends one to the input vector. */ - private def prependOne(vector: Vector): Vector = { - val vector1 = vector.toBreeze match { - case dv: BDV[Double] => BDV.vertcat(BDV.ones[Double](1), dv) - case sv: BSV[Double] => BSV.vertcat(new BSV[Double](Array(0), Array(1.0), 1), sv) - case v: Any => throw new IllegalArgumentException("Do not support vector type " + v.getClass) - } - Vectors.fromBreeze(vector1) - } - /** * Run the algorithm with the configured parameters on an input RDD * of LabeledPoint entries starting from the initial weights provided. @@ -147,23 +139,23 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] // Prepend an extra variable consisting of all 1.0's for the intercept. val data = if (addIntercept) { - input.map(labeledPoint => (labeledPoint.label, prependOne(labeledPoint.features))) + input.map(labeledPoint => (labeledPoint.label, appendBias(labeledPoint.features))) } else { input.map(labeledPoint => (labeledPoint.label, labeledPoint.features)) } val initialWeightsWithIntercept = if (addIntercept) { - prependOne(initialWeights) + appendBias(initialWeights) } else { initialWeights } val weightsWithIntercept = optimizer.optimize(data, initialWeightsWithIntercept) - val intercept = if (addIntercept) weightsWithIntercept(0) else 0.0 + val intercept = if (addIntercept) weightsWithIntercept(weightsWithIntercept.size - 1) else 0.0 val weights = if (addIntercept) { - Vectors.dense(weightsWithIntercept.toArray.slice(1, weightsWithIntercept.size)) + Vectors.dense(weightsWithIntercept.toArray.slice(0, weightsWithIntercept.size - 1)) } else { weightsWithIntercept } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala index a05dfc045fb8e..cb0d39e759a9f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala @@ -28,7 +28,7 @@ import org.apache.spark.rdd.RDD * @param weights Weights computed for every feature. * @param intercept Intercept computed for this model. 
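With the switch from prependOne to appendBias, the optimizer sees a trailing bias column, so the intercept comes back as the last element of the solution vector rather than the first. A small sketch of that layout (the weight values are illustrative):

import org.apache.spark.mllib.linalg.Vectors

object InterceptLayout {
  def main(args: Array[String]): Unit = {
    // Suppose the optimizer returned [w0, w1, intercept] because the bias
    // column was appended (not prepended) to every feature vector.
    val weightsWithIntercept = Vectors.dense(0.5, -1.2, 3.0)

    val intercept = weightsWithIntercept(weightsWithIntercept.size - 1)       // 3.0
    val weights = Vectors.dense(
      weightsWithIntercept.toArray.slice(0, weightsWithIntercept.size - 1))   // [0.5, -1.2]

    println(s"weights = $weights, intercept = $intercept")
  }
}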
*/ -class LassoModel private[mllib] ( +class LassoModel ( override val weights: Vector, override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala index 0ebad4eb58d88..8c078ec9f66e9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala @@ -27,7 +27,7 @@ import org.apache.spark.mllib.optimization._ * @param weights Weights computed for every feature. * @param intercept Intercept computed for this model. */ -class LinearRegressionModel private[mllib] ( +class LinearRegressionModel ( override val weights: Vector, override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala index bd983bac001a0..a826deb695ee1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -28,7 +28,7 @@ import org.apache.spark.mllib.linalg.Vector * @param weights Weights computed for every feature. * @param intercept Intercept computed for this model. */ -class RidgeRegressionModel private[mllib] ( +class RidgeRegressionModel ( override val weights: Vector, override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala new file mode 100644 index 0000000000000..68f3867ba6c11 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat + +import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.linalg.{Matrix, Vector} +import org.apache.spark.mllib.stat.correlation.Correlations +import org.apache.spark.rdd.RDD + +/** + * API for statistical functions in MLlib + */ +@Experimental +object Statistics { + + /** + * Compute the Pearson correlation matrix for the input RDD of Vectors. + * Returns NaN if either vector has 0 variance. + * + * @param X an RDD[Vector] for which the correlation matrix is to be computed. + * @return Pearson correlation matrix comparing columns in X. 
+ */ + def corr(X: RDD[Vector]): Matrix = Correlations.corrMatrix(X) + + /** + * Compute the correlation matrix for the input RDD of Vectors using the specified method. + * Methods currently supported: `pearson` (default), `spearman` + * + * Note that for Spearman, a rank correlation, we need to create an RDD[Double] for each column + * and sort it in order to retrieve the ranks and then join the columns back into an RDD[Vector], + * which is fairly costly. Cache the input RDD before calling corr with `method = "spearman"` to + * avoid recomputing the common lineage. + * + * @param X an RDD[Vector] for which the correlation matrix is to be computed. + * @param method String specifying the method to use for computing correlation. + * Supported: `pearson` (default), `spearman` + * @return Correlation matrix comparing columns in X. + */ + def corr(X: RDD[Vector], method: String): Matrix = Correlations.corrMatrix(X, method) + + /** + * Compute the Pearson correlation for the input RDDs. + * Columns with 0 covariance produce NaN entries in the correlation matrix. + * + * @param x RDD[Double] of the same cardinality as y + * @param y RDD[Double] of the same cardinality as x + * @return A Double containing the Pearson correlation between the two input RDD[Double]s + */ + def corr(x: RDD[Double], y: RDD[Double]): Double = Correlations.corr(x, y) + + /** + * Compute the correlation for the input RDDs using the specified method. + * Methods currently supported: pearson (default), spearman + * + * @param x RDD[Double] of the same cardinality as y + * @param y RDD[Double] of the same cardinality as x + * @param method String specifying the method to use for computing correlation. + * Supported: `pearson` (default), `spearman` + *@return A Double containing the correlation between the two input RDD[Double]s using the + * specified method. + */ + def corr(x: RDD[Double], y: RDD[Double], method: String): Double = Correlations.corr(x, y, method) +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/Correlation.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/Correlation.scala new file mode 100644 index 0000000000000..f23393d3da257 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/Correlation.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat.correlation + +import org.apache.spark.mllib.linalg.{DenseVector, Matrix, Vector} +import org.apache.spark.rdd.RDD + +/** + * Trait for correlation algorithms. + */ +private[stat] trait Correlation { + + /** + * Compute correlation for two datasets. 
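A minimal sketch of calling the Statistics API introduced above; the input data is illustrative. Note the cache before the Spearman call, per the doc comment above.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.stat.Statistics

object CorrelationExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("CorrelationExample"))

    val x = sc.parallelize(Seq(1.0, 2.0, 3.0, 4.0))
    val y = sc.parallelize(Seq(2.0, 4.0, 6.0, 8.0))
    // Pairwise Pearson correlation; x and y must have the same cardinality.
    println(Statistics.corr(x, y))              // 1.0 for perfectly linear data

    val data = sc.parallelize(Seq(
      Vectors.dense(1.0, 10.0, 100.0),
      Vectors.dense(2.0, 20.0, 200.0),
      Vectors.dense(5.0, 33.0, 366.0)))
    // Cache before Spearman, since each column is sorted separately.
    data.cache()
    println(Statistics.corr(data, "spearman"))  // 3x3 correlation matrix

    sc.stop()
  }
}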
+ */ + def computeCorrelation(x: RDD[Double], y: RDD[Double]): Double + + /** + * Compute the correlation matrix S, for the input matrix, where S(i, j) is the correlation + * between column i and j. S(i, j) can be NaN if the correlation is undefined for column i and j. + */ + def computeCorrelationMatrix(X: RDD[Vector]): Matrix + + /** + * Combine the two input RDD[Double]s into an RDD[Vector] and compute the correlation using the + * correlation implementation for RDD[Vector]. Can be NaN if correlation is undefined for the + * input vectors. + */ + def computeCorrelationWithMatrixImpl(x: RDD[Double], y: RDD[Double]): Double = { + val mat: RDD[Vector] = x.zip(y).map { case (xi, yi) => new DenseVector(Array(xi, yi)) } + computeCorrelationMatrix(mat)(0, 1) + } + +} + +/** + * Delegates computation to the specific correlation object based on the input method name + * + * Currently supported correlations: pearson, spearman. + * After new correlation algorithms are added, please update the documentation here and in + * Statistics.scala for the correlation APIs. + * + * Maintains the default correlation type, pearson + */ +private[stat] object Correlations { + + // Note: after new types of correlations are implemented, please update this map + val nameToObjectMap = Map(("pearson", PearsonCorrelation), ("spearman", SpearmanCorrelation)) + val defaultCorrName: String = "pearson" + val defaultCorr: Correlation = nameToObjectMap(defaultCorrName) + + def corr(x: RDD[Double], y: RDD[Double], method: String = defaultCorrName): Double = { + val correlation = getCorrelationFromName(method) + correlation.computeCorrelation(x, y) + } + + def corrMatrix(X: RDD[Vector], method: String = defaultCorrName): Matrix = { + val correlation = getCorrelationFromName(method) + correlation.computeCorrelationMatrix(X) + } + + /** + * Match input correlation name with a known name via simple string matching + * + * private to stat for ease of unit testing + */ + private[stat] def getCorrelationFromName(method: String): Correlation = { + try { + nameToObjectMap(method) + } catch { + case nse: NoSuchElementException => + throw new IllegalArgumentException("Unrecognized method name. Supported correlations: " + + nameToObjectMap.keys.mkString(", ")) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/PearsonCorrelation.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/PearsonCorrelation.scala new file mode 100644 index 0000000000000..23b291eee070b --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/PearsonCorrelation.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.stat.correlation + +import breeze.linalg.{DenseMatrix => BDM} + +import org.apache.spark.Logging +import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector} +import org.apache.spark.mllib.linalg.distributed.RowMatrix +import org.apache.spark.rdd.RDD + +/** + * Compute Pearson correlation for two RDDs of the type RDD[Double] or the correlation matrix + * for an RDD of the type RDD[Vector]. + * + * Definition of Pearson correlation can be found at + * http://en.wikipedia.org/wiki/Pearson_product-moment_correlation_coefficient + */ +private[stat] object PearsonCorrelation extends Correlation with Logging { + + /** + * Compute the Pearson correlation for two datasets. NaN if either vector has 0 variance. + */ + override def computeCorrelation(x: RDD[Double], y: RDD[Double]): Double = { + computeCorrelationWithMatrixImpl(x, y) + } + + /** + * Compute the Pearson correlation matrix S, for the input matrix, where S(i, j) is the + * correlation between column i and j. 0 covariance results in a correlation value of Double.NaN. + */ + override def computeCorrelationMatrix(X: RDD[Vector]): Matrix = { + val rowMatrix = new RowMatrix(X) + val cov = rowMatrix.computeCovariance() + computeCorrelationMatrixFromCovariance(cov) + } + + /** + * Compute the Pearson correlation matrix from the covariance matrix. + * 0 covariance results in a correlation value of Double.NaN. + */ + def computeCorrelationMatrixFromCovariance(covarianceMatrix: Matrix): Matrix = { + val cov = covarianceMatrix.toBreeze.asInstanceOf[BDM[Double]] + val n = cov.cols + + // Compute the standard deviation on the diagonals first + var i = 0 + while (i < n) { + // TODO remove once covariance numerical issue resolved. + cov(i, i) = if (closeToZero(cov(i, i))) 0.0 else math.sqrt(cov(i, i)) + i +=1 + } + + // Loop through columns since cov is column major + var j = 0 + var sigma = 0.0 + var containNaN = false + while (j < n) { + sigma = cov(j, j) + i = 0 + while (i < j) { + val corr = if (sigma == 0.0 || cov(i, i) == 0.0) { + containNaN = true + Double.NaN + } else { + cov(i, j) / (sigma * cov(i, i)) + } + cov(i, j) = corr + cov(j, i) = corr + i += 1 + } + j += 1 + } + + // put 1.0 on the diagonals + i = 0 + while (i < n) { + cov(i, i) = 1.0 + i +=1 + } + + if (containNaN) { + logWarning("Pearson correlation matrix contains NaN values.") + } + + Matrices.fromBreeze(cov) + } + + private def closeToZero(value: Double, threshhold: Double = 1e-12): Boolean = { + math.abs(value) <= threshhold + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala new file mode 100644 index 0000000000000..1f7de630e778c --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
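The covariance-to-correlation conversion above divides each off-diagonal entry by the product of the column standard deviations, producing NaN when either variance is zero. A tiny self-contained sketch of that formula with made-up numbers:

// corr(i, j) = cov(i, j) / (sqrt(cov(i, i)) * sqrt(cov(j, j))),
// with NaN whenever either variance is 0, mirroring the logic above.
object CovToCorr {
  def corrFromCov(cov: Array[Array[Double]]): Array[Array[Double]] = {
    val n = cov.length
    val sd = cov.indices.map(i => math.sqrt(cov(i)(i)))
    Array.tabulate(n, n) { (i, j) =>
      if (i == j) 1.0
      else if (sd(i) == 0.0 || sd(j) == 0.0) Double.NaN
      else cov(i)(j) / (sd(i) * sd(j))
    }
  }

  def main(args: Array[String]): Unit = {
    val cov = Array(Array(4.0, 3.0), Array(3.0, 9.0))
    // corr(0, 1) = 3 / (2 * 3) = 0.5
    corrFromCov(cov).foreach(row => println(row.mkString(" ")))
  }
}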
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat.correlation + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{Logging, HashPartitioner} +import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.linalg.{DenseVector, Matrix, Vector} +import org.apache.spark.rdd.{CoGroupedRDD, RDD} + +/** + * Compute Spearman's correlation for two RDDs of the type RDD[Double] or the correlation matrix + * for an RDD of the type RDD[Vector]. + * + * Definition of Spearman's correlation can be found at + * http://en.wikipedia.org/wiki/Spearman's_rank_correlation_coefficient + */ +private[stat] object SpearmanCorrelation extends Correlation with Logging { + + /** + * Compute Spearman's correlation for two datasets. + */ + override def computeCorrelation(x: RDD[Double], y: RDD[Double]): Double = { + computeCorrelationWithMatrixImpl(x, y) + } + + /** + * Compute Spearman's correlation matrix S, for the input matrix, where S(i, j) is the + * correlation between column i and j. + * + * Input RDD[Vector] should be cached or checkpointed if possible since it would be split into + * numCol RDD[Double]s, each of which sorted, and the joined back into a single RDD[Vector]. + */ + override def computeCorrelationMatrix(X: RDD[Vector]): Matrix = { + val indexed = X.zipWithUniqueId() + + val numCols = X.first.size + if (numCols > 50) { + logWarning("Computing the Spearman correlation matrix can be slow for large RDDs with more" + + " than 50 columns.") + } + val ranks = new Array[RDD[(Long, Double)]](numCols) + + // Note: we use a for loop here instead of a while loop with a single index variable + // to avoid race condition caused by closure serialization + for (k <- 0 until numCols) { + val column = indexed.map { case (vector, index) => (vector(k), index) } + ranks(k) = getRanks(column) + } + + val ranksMat: RDD[Vector] = makeRankMatrix(ranks, X) + PearsonCorrelation.computeCorrelationMatrix(ranksMat) + } + + /** + * Compute the ranks for elements in the input RDD, using the average method for ties. + * + * With the average method, elements with the same value receive the same rank that's computed + * by taking the average of their positions in the sorted list. + * e.g. ranks([2, 1, 0, 2]) = [2.5, 1.0, 0.0, 2.5] + * Note that positions here are 0-indexed, instead of the 1-indexed as in the definition for + * ranks in the standard definition for Spearman's correlation. This does not affect the final + * results and is slightly more performant. + * + * @param indexed RDD[(Double, Long)] containing pairs of the format (originalValue, uniqueId) + * @return RDD[(Long, Double)] containing pairs of the format (uniqueId, rank), where uniqueId is + * copied from the input RDD. 
+ */ + private def getRanks(indexed: RDD[(Double, Long)]): RDD[(Long, Double)] = { + // Get elements' positions in the sorted list for computing average rank for duplicate values + val sorted = indexed.sortByKey().zipWithIndex() + + val ranks: RDD[(Long, Double)] = sorted.mapPartitions { iter => + // add an extra element to signify the end of the list so that flatMap can flush the last + // batch of duplicates + val padded = iter ++ + Iterator[((Double, Long), Long)](((Double.NaN, -1L), -1L)) + var lastVal = 0.0 + var firstRank = 0.0 + val idBuffer = new ArrayBuffer[Long]() + padded.flatMap { case ((v, id), rank) => + if (v == lastVal && id != Long.MinValue) { + idBuffer += id + Iterator.empty + } else { + val entries = if (idBuffer.size == 0) { + // edge case for the first value matching the initial value of lastVal + Iterator.empty + } else if (idBuffer.size == 1) { + Iterator((idBuffer(0), firstRank)) + } else { + val averageRank = firstRank + (idBuffer.size - 1.0) / 2.0 + idBuffer.map(id => (id, averageRank)) + } + lastVal = v + firstRank = rank + idBuffer.clear() + idBuffer += id + entries + } + } + } + ranks + } + + private def makeRankMatrix(ranks: Array[RDD[(Long, Double)]], input: RDD[Vector]): RDD[Vector] = { + val partitioner = new HashPartitioner(input.partitions.size) + val cogrouped = new CoGroupedRDD[Long](ranks, partitioner) + cogrouped.map { + case (_, values: Array[Iterable[_]]) => + val doubles = values.asInstanceOf[Array[Iterable[Double]]] + new DenseVector(doubles.flatten.toArray) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 74d5d7ba10960..ad32e3f4560fe 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -77,11 +77,9 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // Max memory usage for aggregates val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024 logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.") - val numElementsPerNode = - strategy.algo match { - case Classification => 2 * numBins * numFeatures - case Regression => 3 * numBins * numFeatures - } + val numElementsPerNode = DecisionTree.getElementsPerNode(numFeatures, numBins, + strategy.numClassesForClassification, strategy.isMulticlassWithCategoricalFeatures, + strategy.algo) logDebug("numElementsPerNode = " + numElementsPerNode) val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array @@ -109,8 +107,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo logDebug("#####################################") // Find best split for all nodes at a level. - val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy, - level, filters, splits, bins, maxLevelForSingleGroup) + val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, + strategy, level, filters, splits, bins, maxLevelForSingleGroup) for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { // Extract info for nodes at the current level. 
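A standalone sketch of the average-rank convention getRanks implements (0-indexed positions, ties receive the mean of their positions), reproducing the example from the doc comment: ranks([2, 1, 0, 2]) = [2.5, 1.0, 0.0, 2.5]. This is a local, non-distributed illustration only.

object AverageRanks {
  // Assigns each value the average of the 0-indexed positions it would occupy
  // in sorted order; tied values share the same averaged rank.
  def ranks(values: Seq[Double]): Seq[Double] = {
    val rankByOriginalIndex: Map[Int, Double] = values.zipWithIndex
      .sortBy(_._1)
      .zipWithIndex                               // ((value, originalIndex), sortedPosition)
      .groupBy { case ((v, _), _) => v }          // group ties by value
      .values
      .flatMap { group =>
        val avg = group.map(_._2.toDouble).sum / group.size
        group.map { case ((_, origIdx), _) => origIdx -> avg }
      }
      .toMap
    values.indices.map(rankByOriginalIndex)
  }

  def main(args: Array[String]): Unit = {
    println(ranks(Seq(2.0, 1.0, 0.0, 2.0)))  // Vector(2.5, 1.0, 0.0, 2.5)
  }
}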
@@ -212,7 +210,7 @@ object DecisionTree extends Serializable with Logging { * @return a DecisionTreeModel that can be used for prediction */ def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = { - new DecisionTree(strategy).train(input: RDD[LabeledPoint]) + new DecisionTree(strategy).train(input) } /** @@ -233,10 +231,33 @@ object DecisionTree extends Serializable with Logging { algo: Algo, impurity: Impurity, maxDepth: Int): DecisionTreeModel = { - val strategy = new Strategy(algo,impurity,maxDepth) - new DecisionTree(strategy).train(input: RDD[LabeledPoint]) + val strategy = new Strategy(algo, impurity, maxDepth) + new DecisionTree(strategy).train(input) } + /** + * Method to train a decision tree model where the instances are represented as an RDD of + * (label, features) pairs. The method supports binary classification and regression. For the + * binary classification, the label for each instance should either be 0 or 1 to denote the two + * classes. + * + * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as + * training data + * @param algo algorithm, classification or regression + * @param impurity impurity criterion used for information gain calculation + * @param maxDepth maxDepth maximum depth of the tree + * @param numClassesForClassification number of classes for classification. Default value of 2. + * @return a DecisionTreeModel that can be used for prediction + */ + def train( + input: RDD[LabeledPoint], + algo: Algo, + impurity: Impurity, + maxDepth: Int, + numClassesForClassification: Int): DecisionTreeModel = { + val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification) + new DecisionTree(strategy).train(input) + } /** * Method to train a decision tree model where the instances are represented as an RDD of @@ -250,6 +271,7 @@ object DecisionTree extends Serializable with Logging { * @param algo classification or regression * @param impurity criterion used for information gain calculation * @param maxDepth maximum depth of the tree + * @param numClassesForClassification number of classes for classification. Default value of 2. 
* @param maxBins maximum number of bins used for splitting features * @param quantileCalculationStrategy algorithm for calculating quantiles * @param categoricalFeaturesInfo A map storing information about the categorical variables and @@ -264,12 +286,13 @@ object DecisionTree extends Serializable with Logging { algo: Algo, impurity: Impurity, maxDepth: Int, + numClassesForClassification: Int, maxBins: Int, quantileCalculationStrategy: QuantileStrategy, categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = { - val strategy = new Strategy(algo, impurity, maxDepth, maxBins, quantileCalculationStrategy, - categoricalFeaturesInfo) - new DecisionTree(strategy).train(input: RDD[LabeledPoint]) + val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification, maxBins, + quantileCalculationStrategy, categoricalFeaturesInfo) + new DecisionTree(strategy).train(input) } private val InvalidBinIndex = -1 @@ -381,6 +404,14 @@ object DecisionTree extends Serializable with Logging { logDebug("numFeatures = " + numFeatures) val numBins = bins(0).length logDebug("numBins = " + numBins) + val numClasses = strategy.numClassesForClassification + logDebug("numClasses = " + numClasses) + val isMulticlassClassification = strategy.isMulticlassClassification + logDebug("isMulticlassClassification = " + isMulticlassClassification) + val isMulticlassClassificationWithCategoricalFeatures + = strategy.isMulticlassWithCategoricalFeatures + logDebug("isMultiClassWithCategoricalFeatures = " + + isMulticlassClassificationWithCategoricalFeatures) // shift when more than one group is used at deep tree level val groupShift = numNodes * groupIndex @@ -436,10 +467,8 @@ object DecisionTree extends Serializable with Logging { /** * Find bin for one feature. */ - def findBin( - featureIndex: Int, - labeledPoint: LabeledPoint, - isFeatureContinuous: Boolean): Int = { + def findBin(featureIndex: Int, labeledPoint: LabeledPoint, + isFeatureContinuous: Boolean, isSpaceSufficientForAllCategoricalSplits: Boolean): Int = { val binForFeatures = bins(featureIndex) val feature = labeledPoint.features(featureIndex) @@ -467,17 +496,28 @@ object DecisionTree extends Serializable with Logging { -1 } + /** + * Sequential search helper method to find bin for categorical feature in multiclass + * classification. The category is returned since each category can belong to multiple + * splits. The actual left/right child allocation per split is performed in the + * sequential phase of the bin aggregate operation. + */ + def sequentialBinSearchForUnorderedCategoricalFeatureInClassification(): Int = { + labeledPoint.features(featureIndex).toInt + } + /** * Sequential search helper method to find bin for categorical feature. 
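A minimal sketch of calling the new train overload that takes numClassesForClassification; the toy data, Gini impurity choice, and depth are illustrative.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.configuration.Algo.Classification
import org.apache.spark.mllib.tree.impurity.Gini

object MulticlassTreeExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("MulticlassTreeExample"))

    // Three-class toy data: labels must be 0, 1, ..., numClasses - 1.
    val data = sc.parallelize(Seq(
      LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
      LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
      LabeledPoint(2.0, Vectors.dense(1.0, 1.0))))

    val model = DecisionTree.train(
      data,
      algo = Classification,
      impurity = Gini,
      maxDepth = 4,
      numClassesForClassification = 3)

    println(model.predict(Vectors.dense(1.0, 1.0)))
    sc.stop()
  }
}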
*/ - def sequentialBinSearchForCategoricalFeature(): Int = { - val numCategoricalBins = strategy.categoricalFeaturesInfo(featureIndex) + def sequentialBinSearchForOrderedCategoricalFeatureInClassification(): Int = { + val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) + val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1 var binIndex = 0 while (binIndex < numCategoricalBins) { val bin = bins(featureIndex)(binIndex) - val category = bin.category + val categories = bin.highSplit.categories val features = labeledPoint.features - if (category == features(featureIndex)) { + if (categories.contains(features(featureIndex))) { return binIndex } binIndex += 1 @@ -494,7 +534,13 @@ object DecisionTree extends Serializable with Logging { binIndex } else { // Perform sequential search to find bin for categorical features. - val binIndex = sequentialBinSearchForCategoricalFeature() + val binIndex = { + if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) { + sequentialBinSearchForUnorderedCategoricalFeatureInClassification() + } else { + sequentialBinSearchForOrderedCategoricalFeatureInClassification() + } + } if (binIndex == -1){ throw new UnknownError("no bin was found for categorical variable.") } @@ -506,13 +552,16 @@ object DecisionTree extends Serializable with Logging { * Finds bins for all nodes (and all features) at a given level. * For l nodes, k features the storage is as follows: * label, b_11, b_12, .. , b_1k, b_21, b_22, .. , b_2k, b_l1, b_l2, .. , b_lk, - * where b_ij is an integer between 0 and numBins - 1. + * where b_ij is an integer between 0 and numBins - 1 for regressions and binary + * classification and the categorical feature value in multiclass classification. * Invalid sample is denoted by noting bin for feature 1 as -1. */ def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = { // Calculate bin index and label per feature per node. val arr = new Array[Double](1 + (numFeatures * numNodes)) + // First element of the array is the label of the instance. arr(0) = labeledPoint.label + // Iterate over nodes. var nodeIndex = 0 while (nodeIndex < numNodes) { val parentFilters = findParentFilters(nodeIndex) @@ -525,8 +574,19 @@ object DecisionTree extends Serializable with Logging { } else { var featureIndex = 0 while (featureIndex < numFeatures) { - val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty - arr(shift + featureIndex) = findBin(featureIndex, labeledPoint,isFeatureContinuous) + val featureInfo = strategy.categoricalFeaturesInfo.get(featureIndex) + val isFeatureContinuous = featureInfo.isEmpty + if (isFeatureContinuous) { + arr(shift + featureIndex) + = findBin(featureIndex, labeledPoint, isFeatureContinuous, false) + } else { + val featureCategories = featureInfo.get + val isSpaceSufficientForAllCategoricalSplits + = numBins > math.pow(2, featureCategories.toInt - 1) - 1 + arr(shift + featureIndex) + = findBin(featureIndex, labeledPoint, isFeatureContinuous, + isSpaceSufficientForAllCategoricalSplits) + } featureIndex += 1 } } @@ -535,18 +595,61 @@ object DecisionTree extends Serializable with Logging { arr } + // Find feature bins for all nodes at a level. + val binMappedRDD = input.map(x => findBinsForLevel(x)) + + def updateBinForOrderedFeature(arr: Array[Double], agg: Array[Double], nodeIndex: Int, + label: Double, featureIndex: Int) = { + + // Find the bin index for this feature. 
+ val arrShift = 1 + numFeatures * nodeIndex + val arrIndex = arrShift + featureIndex + // Update the left or right count for one bin. + val aggShift = numClasses * numBins * numFeatures * nodeIndex + val aggIndex + = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses + val labelInt = label.toInt + agg(aggIndex + labelInt) = agg(aggIndex + labelInt) + 1 + } + + def updateBinForUnorderedFeature(nodeIndex: Int, featureIndex: Int, arr: Array[Double], + label: Double, agg: Array[Double], rightChildShift: Int) = { + // Find the bin index for this feature. + val arrShift = 1 + numFeatures * nodeIndex + val arrIndex = arrShift + featureIndex + // Update the left or right count for one bin. + val aggShift = numClasses * numBins * numFeatures * nodeIndex + val aggIndex + = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses + // Find all matching bins and increment their values + val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) + val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1 + var binIndex = 0 + while (binIndex < numCategoricalBins) { + val labelInt = label.toInt + if (bins(featureIndex)(binIndex).highSplit.categories.contains(labelInt)) { + agg(aggIndex + binIndex) + = agg(aggIndex + binIndex) + 1 + } else { + agg(rightChildShift + aggIndex + binIndex) + = agg(rightChildShift + aggIndex + binIndex) + 1 + } + binIndex += 1 + } + } + /** * Performs a sequential aggregation over a partition for classification. For l nodes, * k features, either the left count or the right count of one of the p bins is * incremented based upon whether the feature is classified as 0 or 1. * * @param agg Array[Double] storing aggregate calculation of size - * 2 * numSplits * numFeatures*numNodes for classification + * numClasses * numSplits * numFeatures*numNodes for classification * @param arr Array[Double] of size 1 + (numFeatures * numNodes) * @return Array[Double] storing aggregate calculation of size * 2 * numSplits * numFeatures * numNodes for classification */ - def classificationBinSeqOp(arr: Array[Double], agg: Array[Double]) { + def orderedClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) = { // Iterate over all nodes. var nodeIndex = 0 while (nodeIndex < numNodes) { @@ -559,15 +662,52 @@ object DecisionTree extends Serializable with Logging { // Iterate over all features. var featureIndex = 0 while (featureIndex < numFeatures) { - // Find the bin index for this feature. - val arrShift = 1 + numFeatures * nodeIndex - val arrIndex = arrShift + featureIndex - // Update the left or right count for one bin. - val aggShift = 2 * numBins * numFeatures * nodeIndex - val aggIndex = aggShift + 2 * featureIndex * numBins + arr(arrIndex).toInt * 2 - label match { - case 0.0 => agg(aggIndex) = agg(aggIndex) + 1 - case 1.0 => agg(aggIndex + 1) = agg(aggIndex + 1) + 1 + updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex) + featureIndex += 1 + } + } + nodeIndex += 1 + } + } + + /** + * Performs a sequential aggregation over a partition for classification. For l nodes, + * k features, either the left count or the right count of one of the p bins is + * incremented based upon whether the feature is classified as 0 or 1. 
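The class-count aggregates above live in one flat array indexed by (node, feature, bin, class). A small sketch of that offset arithmetic, with hypothetical sizes:

// Offset arithmetic for the flat bin-aggregate array used above:
// index = node    * (numClasses * numBins * numFeatures)
//       + feature * (numClasses * numBins)
//       + bin     * numClasses
//       + classLabel
object BinAggIndex {
  def index(node: Int, feature: Int, bin: Int, classLabel: Int,
            numClasses: Int, numBins: Int, numFeatures: Int): Int = {
    node * numClasses * numBins * numFeatures +
      feature * numClasses * numBins +
      bin * numClasses +
      classLabel
  }

  def main(args: Array[String]): Unit = {
    // Hypothetical sizes: 3 classes, 4 bins, 2 features.
    // Node 1, feature 0, bin 2, class 1 -> 1*24 + 0 + 6 + 1 = 31.
    println(index(node = 1, feature = 0, bin = 2, classLabel = 1,
      numClasses = 3, numBins = 4, numFeatures = 2))  // 31
  }
}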
+ * + * @param agg Array[Double] storing aggregate calculation of size + * numClasses * numSplits * numFeatures*numNodes for classification + * @param arr Array[Double] of size 1 + (numFeatures * numNodes) + * @return Array[Double] storing aggregate calculation of size + * 2 * numClasses * numSplits * numFeatures * numNodes for classification + */ + def unorderedClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) = { + // Iterate over all nodes. + var nodeIndex = 0 + while (nodeIndex < numNodes) { + // Check whether the instance was valid for this nodeIndex. + val validSignalIndex = 1 + numFeatures * nodeIndex + val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex + if (isSampleValidForNode) { + val rightChildShift = numClasses * numBins * numFeatures * numNodes + // actual class label + val label = arr(0) + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < numFeatures) { + val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + if (isFeatureContinuous) { + updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex) + } else { + val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) + val isSpaceSufficientForAllCategoricalSplits + = numBins > math.pow(2, featureCategories.toInt - 1) - 1 + if (isSpaceSufficientForAllCategoricalSplits) { + updateBinForUnorderedFeature(nodeIndex, featureIndex, arr, label, agg, + rightChildShift) + } else { + updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex) + } } featureIndex += 1 } @@ -586,7 +726,7 @@ object DecisionTree extends Serializable with Logging { * @return Array[Double] storing aggregate calculation of size * 3 * numSplits * numFeatures * numNodes for regression */ - def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]) { + def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]) = { // Iterate over all nodes. var nodeIndex = 0 while (nodeIndex < numNodes) { @@ -620,17 +760,20 @@ object DecisionTree extends Serializable with Logging { */ def binSeqOp(agg: Array[Double], arr: Array[Double]): Array[Double] = { strategy.algo match { - case Classification => classificationBinSeqOp(arr, agg) + case Classification => + if(isMulticlassClassificationWithCategoricalFeatures) { + unorderedClassificationBinSeqOp(arr, agg) + } else { + orderedClassificationBinSeqOp(arr, agg) + } case Regression => regressionBinSeqOp(arr, agg) } agg } // Calculate bin aggregate length for classification or regression. - val binAggregateLength = strategy.algo match { - case Classification => 2 * numBins * numFeatures * numNodes - case Regression => 3 * numBins * numFeatures * numNodes - } + val binAggregateLength = numNodes * getElementsPerNode(numFeatures, numBins, numClasses, + isMulticlassClassificationWithCategoricalFeatures, strategy.algo) logDebug("binAggregateLength = " + binAggregateLength) /** @@ -649,9 +792,6 @@ object DecisionTree extends Serializable with Logging { combinedAggregate } - // Find feature bins for all nodes at a level. - val binMappedRDD = input.map(x => findBinsForLevel(x)) - // Calculate bin aggregates. 
val binAggregates = { binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp,binCombOp) @@ -668,42 +808,55 @@ object DecisionTree extends Serializable with Logging { * @return information gain and statistics for all splits */ def calculateGainForSplit( - leftNodeAgg: Array[Array[Double]], + leftNodeAgg: Array[Array[Array[Double]]], featureIndex: Int, splitIndex: Int, - rightNodeAgg: Array[Array[Double]], + rightNodeAgg: Array[Array[Array[Double]]], topImpurity: Double): InformationGainStats = { strategy.algo match { case Classification => - val left0Count = leftNodeAgg(featureIndex)(2 * splitIndex) - val left1Count = leftNodeAgg(featureIndex)(2 * splitIndex + 1) - val leftCount = left0Count + left1Count - - val right0Count = rightNodeAgg(featureIndex)(2 * splitIndex) - val right1Count = rightNodeAgg(featureIndex)(2 * splitIndex + 1) - val rightCount = right0Count + right1Count + var classIndex = 0 + val leftCounts: Array[Double] = new Array[Double](numClasses) + val rightCounts: Array[Double] = new Array[Double](numClasses) + var leftTotalCount = 0.0 + var rightTotalCount = 0.0 + while (classIndex < numClasses) { + val leftClassCount = leftNodeAgg(featureIndex)(splitIndex)(classIndex) + val rightClassCount = rightNodeAgg(featureIndex)(splitIndex)(classIndex) + leftCounts(classIndex) = leftClassCount + leftTotalCount += leftClassCount + rightCounts(classIndex) = rightClassCount + rightTotalCount += rightClassCount + classIndex += 1 + } val impurity = { if (level > 0) { topImpurity } else { // Calculate impurity for root node. - strategy.impurity.calculate(left0Count + right0Count, left1Count + right1Count) + val rootNodeCounts = new Array[Double](numClasses) + var classIndex = 0 + while (classIndex < numClasses) { + rootNodeCounts(classIndex) = leftCounts(classIndex) + rightCounts(classIndex) + classIndex += 1 + } + strategy.impurity.calculate(rootNodeCounts, leftTotalCount + rightTotalCount) } } - if (leftCount == 0) { - return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity,1) + if (leftTotalCount == 0) { + return new InformationGainStats(0, topImpurity, topImpurity, Double.MinValue, 1) } - if (rightCount == 0) { - return new InformationGainStats(0, topImpurity, topImpurity, Double.MinValue,0) + if (rightTotalCount == 0) { + return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity, 1) } - val leftImpurity = strategy.impurity.calculate(left0Count, left1Count) - val rightImpurity = strategy.impurity.calculate(right0Count, right1Count) + val leftImpurity = strategy.impurity.calculate(leftCounts, leftTotalCount) + val rightImpurity = strategy.impurity.calculate(rightCounts, rightTotalCount) - val leftWeight = leftCount.toDouble / (leftCount + rightCount) - val rightWeight = rightCount.toDouble / (leftCount + rightCount) + val leftWeight = leftTotalCount / (leftTotalCount + rightTotalCount) + val rightWeight = rightTotalCount / (leftTotalCount + rightTotalCount) val gain = { if (level > 0) { @@ -713,17 +866,34 @@ object DecisionTree extends Serializable with Logging { } } - val predict = (left1Count + right1Count) / (leftCount + rightCount) + val totalCount = leftTotalCount + rightTotalCount - new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict) + // Sum of count for each label + val leftRightCounts: Array[Double] + = leftCounts.zip(rightCounts) + .map{case (leftCount, rightCount) => leftCount + rightCount} + + def indexOfLargestArrayElement(array: Array[Double]): Int = { + val result = array.foldLeft(-1, 
Double.MinValue, 0) { + case ((maxIndex, maxValue, currentIndex), currentValue) => + if(currentValue > maxValue) (currentIndex, currentValue, currentIndex + 1) + else (maxIndex, maxValue, currentIndex + 1) + } + if (result._1 < 0) 0 else result._1 + } + + val predict = indexOfLargestArrayElement(leftRightCounts) + val prob = leftRightCounts(predict) / totalCount + + new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob) case Regression => - val leftCount = leftNodeAgg(featureIndex)(3 * splitIndex) - val leftSum = leftNodeAgg(featureIndex)(3 * splitIndex + 1) - val leftSumSquares = leftNodeAgg(featureIndex)(3 * splitIndex + 2) + val leftCount = leftNodeAgg(featureIndex)(splitIndex)(0) + val leftSum = leftNodeAgg(featureIndex)(splitIndex)(1) + val leftSumSquares = leftNodeAgg(featureIndex)(splitIndex)(2) - val rightCount = rightNodeAgg(featureIndex)(3 * splitIndex) - val rightSum = rightNodeAgg(featureIndex)(3 * splitIndex + 1) - val rightSumSquares = rightNodeAgg(featureIndex)(3 * splitIndex + 2) + val rightCount = rightNodeAgg(featureIndex)(splitIndex)(0) + val rightSum = rightNodeAgg(featureIndex)(splitIndex)(1) + val rightSumSquares = rightNodeAgg(featureIndex)(splitIndex)(2) val impurity = { if (level > 0) { @@ -768,104 +938,149 @@ object DecisionTree extends Serializable with Logging { /** * Extracts left and right split aggregates. * @param binData Array[Double] of size 2*numFeatures*numSplits - * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Double], - * Array[Double]) where each array is of size(numFeature,2*(numSplits-1)) + * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Array[Array[Double\]\]\], + * Array[Array[Array[Double\]\]\]) where each array is of size(numFeature, + * (numBins - 1), numClasses) */ def extractLeftRightNodeAggregates( - binData: Array[Double]): (Array[Array[Double]], Array[Array[Double]]) = { + binData: Array[Double]): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = { + + + def findAggForOrderedFeatureClassification( + leftNodeAgg: Array[Array[Array[Double]]], + rightNodeAgg: Array[Array[Array[Double]]], + featureIndex: Int) { + + // shift for this featureIndex + val shift = numClasses * featureIndex * numBins + + var classIndex = 0 + while (classIndex < numClasses) { + // left node aggregate for the lowest split + leftNodeAgg(featureIndex)(0)(classIndex) = binData(shift + classIndex) + // right node aggregate for the highest split + rightNodeAgg(featureIndex)(numBins - 2)(classIndex) + = binData(shift + (numClasses * (numBins - 1)) + classIndex) + classIndex += 1 + } + + // Iterate over all splits. 
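For intuition, the gain being compared across splits is the parent impurity minus the size-weighted child impurities. A small local illustration using the standard Gini impurity on made-up class counts (not the library's Impurity implementation):

object SplitGain {
  // Standard Gini impurity from per-class counts.
  def gini(counts: Array[Double]): Double = {
    val total = counts.sum
    if (total == 0.0) 0.0
    else 1.0 - counts.map(c => (c / total) * (c / total)).sum
  }

  def main(args: Array[String]): Unit = {
    val leftCounts  = Array(8.0, 2.0)   // mostly class 0
    val rightCounts = Array(1.0, 9.0)   // mostly class 1
    val parentCounts = Array(9.0, 11.0)

    val leftTotal = leftCounts.sum
    val rightTotal = rightCounts.sum
    val leftWeight = leftTotal / (leftTotal + rightTotal)
    val rightWeight = rightTotal / (leftTotal + rightTotal)

    // gain = parent impurity - weighted child impurities
    val gain = gini(parentCounts) - leftWeight * gini(leftCounts) - rightWeight * gini(rightCounts)
    println(f"gain = $gain%.4f")
  }
}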
+ var splitIndex = 1 + while (splitIndex < numBins - 1) { + // calculating left node aggregate for a split as a sum of left node aggregate of a + // lower split and the left bin aggregate of a bin where the split is a high split + var innerClassIndex = 0 + while (innerClassIndex < numClasses) { + leftNodeAgg(featureIndex)(splitIndex)(innerClassIndex) + = binData(shift + numClasses * splitIndex + innerClassIndex) + + leftNodeAgg(featureIndex)(splitIndex - 1)(innerClassIndex) + rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(innerClassIndex) = + binData(shift + (numClasses * (numBins - 1 - splitIndex) + innerClassIndex)) + + rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(innerClassIndex) + innerClassIndex += 1 + } + splitIndex += 1 + } + } + + def findAggForUnorderedFeatureClassification( + leftNodeAgg: Array[Array[Array[Double]]], + rightNodeAgg: Array[Array[Array[Double]]], + featureIndex: Int) { + + val rightChildShift = numClasses * numBins * numFeatures + var splitIndex = 0 + while (splitIndex < numBins - 1) { + var classIndex = 0 + while (classIndex < numClasses) { + // shift for this featureIndex + val shift = numClasses * featureIndex * numBins + splitIndex * numClasses + val leftBinValue = binData(shift + classIndex) + val rightBinValue = binData(rightChildShift + shift + classIndex) + leftNodeAgg(featureIndex)(splitIndex)(classIndex) = leftBinValue + rightNodeAgg(featureIndex)(splitIndex)(classIndex) = rightBinValue + classIndex += 1 + } + splitIndex += 1 + } + } + + def findAggForRegression( + leftNodeAgg: Array[Array[Array[Double]]], + rightNodeAgg: Array[Array[Array[Double]]], + featureIndex: Int) { + + // shift for this featureIndex + val shift = 3 * featureIndex * numBins + // left node aggregate for the lowest split + leftNodeAgg(featureIndex)(0)(0) = binData(shift + 0) + leftNodeAgg(featureIndex)(0)(1) = binData(shift + 1) + leftNodeAgg(featureIndex)(0)(2) = binData(shift + 2) + + // right node aggregate for the highest split + rightNodeAgg(featureIndex)(numBins - 2)(0) = + binData(shift + (3 * (numBins - 1))) + rightNodeAgg(featureIndex)(numBins - 2)(1) = + binData(shift + (3 * (numBins - 1)) + 1) + rightNodeAgg(featureIndex)(numBins - 2)(2) = + binData(shift + (3 * (numBins - 1)) + 2) + + // Iterate over all splits. + var splitIndex = 1 + while (splitIndex < numBins - 1) { + var i = 0 // index for regression histograms + while (i < 3) { // count, sum, sum^2 + // calculating left node aggregate for a split as a sum of left node aggregate of a + // lower split and the left bin aggregate of a bin where the split is a high split + leftNodeAgg(featureIndex)(splitIndex)(i) = binData(shift + 3 * splitIndex + i) + + leftNodeAgg(featureIndex)(splitIndex - 1)(i) + // calculating right node aggregate for a split as a sum of right node aggregate of a + // higher split and the right bin aggregate of a bin where the split is a low split + rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(i) = + binData(shift + (3 * (numBins - 1 - splitIndex) + i)) + + rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(i) + i += 1 + } + splitIndex += 1 + } + } + strategy.algo match { case Classification => // Initialize left and right split aggregates. - val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1)) - val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1)) - // Iterate over all features. 
+ val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses) + val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses) var featureIndex = 0 while (featureIndex < numFeatures) { - // shift for this featureIndex - val shift = 2 * featureIndex * numBins - - // left node aggregate for the lowest split - leftNodeAgg(featureIndex)(0) = binData(shift + 0) - leftNodeAgg(featureIndex)(1) = binData(shift + 1) - - // right node aggregate for the highest split - rightNodeAgg(featureIndex)(2 * (numBins - 2)) - = binData(shift + (2 * (numBins - 1))) - rightNodeAgg(featureIndex)(2 * (numBins - 2) + 1) - = binData(shift + (2 * (numBins - 1)) + 1) - - // Iterate over all splits. - var splitIndex = 1 - while (splitIndex < numBins - 1) { - // calculating left node aggregate for a split as a sum of left node aggregate of a - // lower split and the left bin aggregate of a bin where the split is a high split - leftNodeAgg(featureIndex)(2 * splitIndex) = binData(shift + 2 * splitIndex) + - leftNodeAgg(featureIndex)(2 * splitIndex - 2) - leftNodeAgg(featureIndex)(2 * splitIndex + 1) = binData(shift + 2 * splitIndex + 1) + - leftNodeAgg(featureIndex)(2 * splitIndex - 2 + 1) - - // calculating right node aggregate for a split as a sum of right node aggregate of a - // higher split and the right bin aggregate of a bin where the split is a low split - rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex)) = - binData(shift + (2 *(numBins - 1 - splitIndex))) + - rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex)) - rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex) + 1) = - binData(shift + (2* (numBins - 1 - splitIndex) + 1)) + - rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex) + 1) - - splitIndex += 1 + if (isMulticlassClassificationWithCategoricalFeatures){ + val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + if (isFeatureContinuous) { + findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) + } else { + val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) + val isSpaceSufficientForAllCategoricalSplits + = numBins > math.pow(2, featureCategories.toInt - 1) - 1 + if (isSpaceSufficientForAllCategoricalSplits) { + findAggForUnorderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) + } else { + findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) + } + } + } else { + findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) } featureIndex += 1 } + (leftNodeAgg, rightNodeAgg) case Regression => // Initialize left and right split aggregates. - val leftNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1)) - val rightNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1)) + val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3) + val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3) // Iterate over all features. 
var featureIndex = 0 while (featureIndex < numFeatures) { - // shift for this featureIndex - val shift = 3 * featureIndex * numBins - // left node aggregate for the lowest split - leftNodeAgg(featureIndex)(0) = binData(shift + 0) - leftNodeAgg(featureIndex)(1) = binData(shift + 1) - leftNodeAgg(featureIndex)(2) = binData(shift + 2) - - // right node aggregate for the highest split - rightNodeAgg(featureIndex)(3 * (numBins - 2)) = - binData(shift + (3 * (numBins - 1))) - rightNodeAgg(featureIndex)(3 * (numBins - 2) + 1) = - binData(shift + (3 * (numBins - 1)) + 1) - rightNodeAgg(featureIndex)(3 * (numBins - 2) + 2) = - binData(shift + (3 * (numBins - 1)) + 2) - - // Iterate over all splits. - var splitIndex = 1 - while (splitIndex < numBins - 1) { - // calculating left node aggregate for a split as a sum of left node aggregate of a - // lower split and the left bin aggregate of a bin where the split is a high split - leftNodeAgg(featureIndex)(3 * splitIndex) = binData(shift + 3 * splitIndex) + - leftNodeAgg(featureIndex)(3 * splitIndex - 3) - leftNodeAgg(featureIndex)(3 * splitIndex + 1) = binData(shift + 3 * splitIndex + 1) + - leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 1) - leftNodeAgg(featureIndex)(3 * splitIndex + 2) = binData(shift + 3 * splitIndex + 2) + - leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 2) - - // calculating right node aggregate for a split as a sum of right node aggregate of a - // higher split and the right bin aggregate of a bin where the split is a low split - rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex)) = - binData(shift + (3 * (numBins - 1 - splitIndex))) + - rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex)) - rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 1) = - binData(shift + (3 * (numBins - 1 - splitIndex) + 1)) + - rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1) - rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 2) = - binData(shift + (3 * (numBins - 1 - splitIndex) + 2)) + - rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2) - - splitIndex += 1 - } + findAggForRegression(leftNodeAgg, rightNodeAgg, featureIndex) featureIndex += 1 } (leftNodeAgg, rightNodeAgg) @@ -876,8 +1091,8 @@ object DecisionTree extends Serializable with Logging { * Calculates information gain for all nodes splits. */ def calculateGainsForAllNodeSplits( - leftNodeAgg: Array[Array[Double]], - rightNodeAgg: Array[Array[Double]], + leftNodeAgg: Array[Array[Array[Double]]], + rightNodeAgg: Array[Array[Array[Double]]], nodeImpurity: Double): Array[Array[InformationGainStats]] = { val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1) @@ -918,7 +1133,22 @@ object DecisionTree extends Serializable with Logging { while (featureIndex < numFeatures) { // Iterate over all splits. 
var splitIndex = 0 - while (splitIndex < numBins - 1) { + val maxSplitIndex : Double = { + val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + if (isFeatureContinuous) { + numBins - 1 + } else { // Categorical feature + val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) + val isSpaceSufficientForAllCategoricalSplits + = numBins > math.pow(2, featureCategories.toInt - 1) - 1 + if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) { + math.pow(2.0, featureCategories - 1).toInt - 1 + } else { // Binary classification + featureCategories + } + } + } + while (splitIndex < maxSplitIndex) { val gainStats = gains(featureIndex)(splitIndex) if (gainStats.gain > bestGainStats.gain) { bestGainStats = gainStats @@ -944,9 +1174,23 @@ object DecisionTree extends Serializable with Logging { def getBinDataForNode(node: Int): Array[Double] = { strategy.algo match { case Classification => - val shift = 2 * node * numBins * numFeatures - val binsForNode = binAggregates.slice(shift, shift + 2 * numBins * numFeatures) - binsForNode + if (isMulticlassClassificationWithCategoricalFeatures) { + val shift = numClasses * node * numBins * numFeatures + val rightChildShift = numClasses * numBins * numFeatures * numNodes + val binsForNode = { + val leftChildData + = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures) + val rightChildData + = binAggregates.slice(rightChildShift + shift, + rightChildShift + shift + numClasses * numBins * numFeatures) + leftChildData ++ rightChildData + } + binsForNode + } else { + val shift = numClasses * node * numBins * numFeatures + val binsForNode = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures) + binsForNode + } case Regression => val shift = 3 * node * numBins * numFeatures val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * numFeatures) @@ -963,14 +1207,26 @@ object DecisionTree extends Serializable with Logging { val binsForNode: Array[Double] = getBinDataForNode(node) logDebug("nodeImpurityIndex = " + nodeImpurityIndex) val parentNodeImpurity = parentImpurities(nodeImpurityIndex) - logDebug("node impurity = " + parentNodeImpurity) + logDebug("parent node impurity = " + parentNodeImpurity) bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity) node += 1 } - bestSplits } + private def getElementsPerNode(numFeatures: Int, numBins: Int, numClasses: Int, + isMulticlassClassificationWithCategoricalFeatures: Boolean, algo: Algo): Int = { + algo match { + case Classification => + if (isMulticlassClassificationWithCategoricalFeatures) { + 2 * numClasses * numBins * numFeatures + } else { + numClasses * numBins * numFeatures + } + case Regression => 3 * numBins * numFeatures + } + } + /** * Returns split and bins for decision tree calculation. * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data @@ -992,17 +1248,23 @@ object DecisionTree extends Serializable with Logging { val maxBins = strategy.maxBins val numBins = if (maxBins <= count) maxBins else count.toInt logDebug("numBins = " + numBins) + val isMulticlassClassification = strategy.isMulticlassClassification + logDebug("isMulticlassClassification = " + isMulticlassClassification) + /* - * TODO: Add a require statement ensuring #bins is always greater than the categories. + * Ensure #bins is always greater than the categories. For multiclass classification, + * #bins should be greater than 2^(maxCategories - 1) - 1. 
* It's a limitation of the current implementation but a reasonable trade-off since features * with large number of categories get favored over continuous features. */ if (strategy.categoricalFeaturesInfo.size > 0) { val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2 - require(numBins >= maxCategoriesForFeatures) + require(numBins > maxCategoriesForFeatures, "numBins should be greater than max categories " + + "in categorical features") } + // Calculate the number of sample for approximate quantile calculation. val requiredSamples = numBins*numBins val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0 @@ -1036,48 +1298,93 @@ object DecisionTree extends Serializable with Logging { val split = new Split(featureIndex, featureSamples(sampleIndex), Continuous, List()) splits(featureIndex)(index) = split } - } else { - val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex) - require(maxFeatureValue < numBins, "number of categories should be less than number " + - "of bins") - - // For categorical variables, each bin is a category. The bins are sorted and they - // are ordered by calculating the centroid of their corresponding labels. - val centroidForCategories = - sampledInput.map(lp => (lp.features(featureIndex),lp.label)) - .groupBy(_._1) - .mapValues(x => x.map(_._2).sum / x.map(_._1).length) - - // Check for missing categorical variables and putting them last in the sorted list. - val fullCentroidForCategories = scala.collection.mutable.Map[Double,Double]() - for (i <- 0 until maxFeatureValue) { - if (centroidForCategories.contains(i)) { - fullCentroidForCategories(i) = centroidForCategories(i) - } else { - fullCentroidForCategories(i) = Double.MaxValue - } - } - - // bins sorted by centroids - val categoriesSortedByCentroid = fullCentroidForCategories.toList.sortBy(_._2) - - logDebug("centriod for categorical variable = " + categoriesSortedByCentroid) - - var categoriesForSplit = List[Double]() - categoriesSortedByCentroid.iterator.zipWithIndex.foreach { - case ((key, value), index) => - categoriesForSplit = key :: categoriesForSplit - splits(featureIndex)(index) = new Split(featureIndex, Double.MinValue, Categorical, - categoriesForSplit) + } else { // Categorical feature + val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) + val isSpaceSufficientForAllCategoricalSplits + = numBins > math.pow(2, featureCategories.toInt - 1) - 1 + + // Use different bin/split calculation strategy for categorical features in multiclass + // classification that satisfy the space constraint + if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) { + // 2^(maxFeatureValue- 1) - 1 combinations + var index = 0 + while (index < math.pow(2.0, featureCategories - 1).toInt - 1) { + val categories: List[Double] + = extractMultiClassCategories(index + 1, featureCategories) + splits(featureIndex)(index) + = new Split(featureIndex, Double.MinValue, Categorical, categories) bins(featureIndex)(index) = { if (index == 0) { - new Bin(new DummyCategoricalSplit(featureIndex, Categorical), - splits(featureIndex)(0), Categorical, key) + new Bin( + new DummyCategoricalSplit(featureIndex, Categorical), + splits(featureIndex)(0), + Categorical, + Double.MinValue) } else { - new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index), - Categorical, key) + new Bin( + splits(featureIndex)(index - 1), + splits(featureIndex)(index), + Categorical, + Double.MinValue) } } + index += 1 + } + } else { + + val 
centroidForCategories = { + if (isMulticlassClassification) { + // For categorical variables in multiclass classification, + // each bin is a category. The bins are sorted and they + // are ordered by calculating the impurity of their corresponding labels. + sampledInput.map(lp => (lp.features(featureIndex), lp.label)) + .groupBy(_._1) + .mapValues(x => x.groupBy(_._2).mapValues(x => x.size.toDouble)) + .map(x => (x._1, x._2.values.toArray)) + .map(x => (x._1, strategy.impurity.calculate(x._2,x._2.sum))) + } else { // regression or binary classification + // For categorical variables in regression and binary classification, + // each bin is a category. The bins are sorted and they + // are ordered by calculating the centroid of their corresponding labels. + sampledInput.map(lp => (lp.features(featureIndex), lp.label)) + .groupBy(_._1) + .mapValues(x => x.map(_._2).sum / x.map(_._1).length) + } + } + + logDebug("centroid for categories = " + centroidForCategories.mkString(",")) + + // Check for missing categorical variables and put them last in the sorted list. + val fullCentroidForCategories = scala.collection.mutable.Map[Double,Double]() + for (i <- 0 until featureCategories) { + if (centroidForCategories.contains(i)) { + fullCentroidForCategories(i) = centroidForCategories(i) + } else { + fullCentroidForCategories(i) = Double.MaxValue + } + } + + // bins sorted by centroids + val categoriesSortedByCentroid = fullCentroidForCategories.toList.sortBy(_._2) + + logDebug("centroid for categorical variable = " + categoriesSortedByCentroid) + + var categoriesForSplit = List[Double]() + categoriesSortedByCentroid.iterator.zipWithIndex.foreach { + case ((key, value), index) => + categoriesForSplit = key :: categoriesForSplit + splits(featureIndex)(index) = new Split(featureIndex, Double.MinValue, + Categorical, categoriesForSplit) + bins(featureIndex)(index) = { + if (index == 0) { + new Bin(new DummyCategoricalSplit(featureIndex, Categorical), + splits(featureIndex)(0), Categorical, key) + } else { + new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index), + Categorical, key) + } + } + } } } featureIndex += 1 @@ -1107,4 +1414,29 @@ object DecisionTree extends Serializable with Logging { throw new UnsupportedOperationException("approximate histogram not supported yet.") } } + + /** + * Nested method to extract the list of eligible categories given an index. It extracts the + * position of ones in a binary representation of the input. If the binary + * representation of a number is 01101 (13), the output list should be (3.0, 2.0, + * 0.0). The maxFeatureValue depicts the number of rightmost digits that will be tested for ones. + */ + private[tree] def extractMultiClassCategories( + input: Int, + maxFeatureValue: Int): List[Double] = { + var categories = List[Double]() + var j = 0 + var bitShiftedInput = input + while (j < maxFeatureValue) { + if (bitShiftedInput % 2 != 0) { + // updating the list of categories.
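A quick worked check of the bit extraction described in the doc comment above (illustrative only, not part of the patch; the method body continues below): each set bit position becomes one category on the left side of the split, most significant position first.

// extractMultiClassCategories(13, 5) // 13 = 01101 -> List(3.0, 2.0, 0.0)
// extractMultiClassCategories(1, 3)  //  1 = 001   -> List(0.0)
// extractMultiClassCategories(6, 3)  //  6 = 110   -> List(2.0, 1.0)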
+ categories = j.toDouble :: categories + } + // Right shift by one + bitShiftedInput = bitShiftedInput >> 1 + j += 1 + } + categories + } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 1b505fd76eb75..7c027ac2fda6b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -28,6 +28,8 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * @param algo classification or regression * @param impurity criterion used for information gain calculation * @param maxDepth maximum depth of the tree + * @param numClassesForClassification number of classes for classification. Default value is 2 + * leads to binary classification * @param maxBins maximum number of bins used for splitting features * @param quantileCalculationStrategy algorithm for calculating quantiles * @param categoricalFeaturesInfo A map storing information about the categorical variables and the @@ -44,7 +46,15 @@ class Strategy ( val algo: Algo, val impurity: Impurity, val maxDepth: Int, + val numClassesForClassification: Int = 2, val maxBins: Int = 100, val quantileCalculationStrategy: QuantileStrategy = Sort, val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), - val maxMemoryInMB: Int = 128) extends Serializable + val maxMemoryInMB: Int = 128) extends Serializable { + + require(numClassesForClassification >= 2) + val isMulticlassClassification = numClassesForClassification > 2 + val isMulticlassWithCategoricalFeatures + = isMulticlassClassification && (categoricalFeaturesInfo.size > 0) + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 60f43e9278d2a..a0e2d91762782 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -31,23 +31,35 @@ object Entropy extends Impurity { /** * :: DeveloperApi :: - * entropy calculation - * @param c0 count of instances with label 0 - * @param c1 count of instances with label 1 - * @return entropy value + * information calculation for multiclass classification + * @param counts Array[Double] with counts for each label + * @param totalCount sum of counts for all labels + * @return information value */ @DeveloperApi - override def calculate(c0: Double, c1: Double): Double = { - if (c0 == 0 || c1 == 0) { - 0 - } else { - val total = c0 + c1 - val f0 = c0 / total - val f1 = c1 / total - -(f0 * log2(f0)) - (f1 * log2(f1)) + override def calculate(counts: Array[Double], totalCount: Double): Double = { + val numClasses = counts.length + var impurity = 0.0 + var classIndex = 0 + while (classIndex < numClasses) { + val classCount = counts(classIndex) + if (classCount != 0) { + val freq = classCount / totalCount + impurity -= freq * log2(freq) + } + classIndex += 1 } + impurity } + /** + * :: DeveloperApi :: + * variance calculation + * @param count number of instances + * @param sum sum of labels + * @param sumSquares summation of squares of the labels + */ + @DeveloperApi override def calculate(count: Double, sum: Double, sumSquares: Double): Double = throw new UnsupportedOperationException("Entropy.calculate") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index c51d76d9b4c5b..48144b5e6d1e4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -30,23 +30,32 @@ object Gini extends Impurity { /** * :: DeveloperApi :: - * Gini coefficient calculation - * @param c0 count of instances with label 0 - * @param c1 count of instances with label 1 - * @return Gini coefficient value + * information calculation for multiclass classification + * @param counts Array[Double] with counts for each label + * @param totalCount sum of counts for all labels + * @return information value */ @DeveloperApi - override def calculate(c0: Double, c1: Double): Double = { - if (c0 == 0 || c1 == 0) { - 0 - } else { - val total = c0 + c1 - val f0 = c0 / total - val f1 = c1 / total - 1 - f0 * f0 - f1 * f1 + override def calculate(counts: Array[Double], totalCount: Double): Double = { + val numClasses = counts.length + var impurity = 1.0 + var classIndex = 0 + while (classIndex < numClasses) { + val freq = counts(classIndex) / totalCount + impurity -= freq * freq + classIndex += 1 } + impurity } + /** + * :: DeveloperApi :: + * variance calculation + * @param count number of instances + * @param sum sum of labels + * @param sumSquares summation of squares of the labels + */ + @DeveloperApi override def calculate(count: Double, sum: Double, sumSquares: Double): Double = throw new UnsupportedOperationException("Gini.calculate") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 8eab247cf0932..7b2a9320cc21d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -28,13 +28,13 @@ trait Impurity extends Serializable { /** * :: DeveloperApi :: - * information calculation for binary classification - * @param c0 count of instances with label 0 - * @param c1 count of instances with label 1 + * information calculation for multiclass classification + * @param counts Array[Double] with counts for each label + * @param totalCount sum of counts for all labels * @return information value */ @DeveloperApi - def calculate(c0 : Double, c1 : Double): Double + def calculate(counts: Array[Double], totalCount: Double): Double /** * :: DeveloperApi :: diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index 47d07122af30f..97149a99ead59 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -25,7 +25,16 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental} */ @Experimental object Variance extends Impurity { - override def calculate(c0: Double, c1: Double): Double = + + /** + * :: DeveloperApi :: + * information calculation for multiclass classification + * @param counts Array[Double] with counts for each label + * @param totalCount sum of counts for all labels + * @return information value + */ + @DeveloperApi + override def calculate(counts: Array[Double], totalCount: Double): Double = throw new UnsupportedOperationException("Variance.calculate") /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala index 2d71e1e366069..c89c1e371a40e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala @@ -28,7 +28,7 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._ * @param highSplit signifying the upper threshold for the continuous feature to be * accepted in the bin * @param featureType type of feature -- categorical or continuous - * @param category categorical label value accepted in the bin + * @param category categorical label value accepted in the bin for binary classification */ private[tree] case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index cc8a24cce9614..fb12298e0f5d3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -27,6 +27,7 @@ import org.apache.spark.annotation.DeveloperApi * @param leftImpurity left node impurity * @param rightImpurity right node impurity * @param predict predicted value + * @param prob probability of the label (classification only) */ @DeveloperApi class InformationGainStats( @@ -34,10 +35,11 @@ class InformationGainStats( val impurity: Double, val leftImpurity: Double, val rightImpurity: Double, - val predict: Double) extends Serializable { + val predict: Double, + val prob: Double = 0.0) extends Serializable { override def toString = { - "gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f" - .format(gain, impurity, leftImpurity, rightImpurity, predict) + "gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f, prob = %f" + .format(gain, impurity, leftImpurity, rightImpurity, predict, prob) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index aaf92a1a8869a..30de24ad89f98 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -264,8 +264,8 @@ object MLUtils { (1 to numFolds).map { fold => val sampler = new BernoulliSampler[T]((fold - 1) / numFoldsF, fold / numFoldsF, complement = false) - val validation = new PartitionwiseSampledRDD(rdd, sampler, seed) - val training = new PartitionwiseSampledRDD(rdd, sampler.cloneComplement(), seed) + val validation = new PartitionwiseSampledRDD(rdd, sampler, true, seed) + val training = new PartitionwiseSampledRDD(rdd, sampler.cloneComplement(), true, seed) (training, validation) }.toArray } diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java index faa675b59cd50..862221d48798a 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java @@ -92,8 +92,6 @@ public void runLRUsingStaticMethods() { testRDD.rdd(), 100, 1.0, 1.0); int numAccurate = validatePrediction(validationData, model); - System.out.println(numAccurate); Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); } - } 
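The multiclass calculate(counts, totalCount) methods added to Entropy and Gini above replace the old two-count form with a label histogram. The standalone sketch below mirrors those formulas; the object name and the example counts are illustrative only and are not part of the patch.

object ImpuritySketch {
  private def log2(x: Double): Double = math.log(x) / math.log(2)

  // Mirrors the new Gini.calculate(counts, totalCount): 1 - sum(freq^2).
  def gini(counts: Array[Double], totalCount: Double): Double =
    1.0 - counts.map { c => val f = c / totalCount; f * f }.sum

  // Mirrors the new Entropy.calculate(counts, totalCount): -sum(freq * log2(freq)), skipping zero counts.
  def entropy(counts: Array[Double], totalCount: Double): Double =
    -counts.filter(_ != 0).map { c => val f = c / totalCount; f * log2(f) }.sum

  def main(args: Array[String]): Unit = {
    val counts = Array(4.0, 4.0, 1.0) // label histogram over 9 instances
    println(gini(counts, 9.0))    // ~0.5926 = 1 - (4/9)^2 - (4/9)^2 - (1/9)^2
    println(entropy(counts, 9.0)) // ~1.3921 = -2*(4/9)log2(4/9) - (1/9)log2(1/9)
  }
}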
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala index 642843f90204c..d94cfa2fcec81 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala @@ -57,4 +57,12 @@ class PythonMLLibAPISuite extends FunSuite { assert(q.features === p.features) } } + + test("double serialization") { + for (x <- List(123.0, -10.0, 0.0, Double.MaxValue, Double.MinValue)) { + val bytes = py.serializeDouble(x) + val deser = py.deserializeDouble(bytes) + assert(x === deser) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index 44b757b6a1fb7..3f6ff859374c7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -25,7 +25,7 @@ import org.scalatest.Matchers import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} object LogisticRegressionSuite { @@ -126,3 +126,19 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match validatePrediction(validationData.map(row => model.predict(row.features)), validationData) } } + +class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkContext { + + test("task size should be small in both training and prediction") { + val m = 4 + val n = 200000 + val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) => + val random = new Random(idx) + iter.map(i => LabeledPoint(1.0, Vectors.dense(Array.fill(n)(random.nextDouble())))) + }.cache() + // If we serialize data directly in the task closure, the size of the serialized task would be + // greater than 1MB and hence Spark would throw an error. 
+ val model = LogisticRegressionWithSGD.train(points, 2) + val predictions = model.predict(points.map(_.features)) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index 516895d04222d..06cdd04f5fdae 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -23,7 +23,7 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} object NaiveBayesSuite { @@ -96,3 +96,21 @@ class NaiveBayesSuite extends FunSuite with LocalSparkContext { validatePrediction(validationData.map(row => model.predict(row.features)), validationData) } } + +class NaiveBayesClusterSuite extends FunSuite with LocalClusterSparkContext { + + test("task size should be small in both training and prediction") { + val m = 10 + val n = 200000 + val examples = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) => + val random = new Random(idx) + iter.map { i => + LabeledPoint(random.nextInt(2), Vectors.dense(Array.fill(n)(random.nextDouble()))) + } + } + // If we serialize data directly in the task closure, the size of the serialized task would be + // greater than 1MB and hence Spark would throw an error. + val model = NaiveBayes.train(examples) + val predictions = model.predict(examples.map(_.features)) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala index 886c71dde3af7..65e5df58db4c7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala @@ -17,17 +17,16 @@ package org.apache.spark.mllib.classification -import scala.util.Random import scala.collection.JavaConversions._ - -import org.scalatest.FunSuite +import scala.util.Random import org.jblas.DoubleMatrix +import org.scalatest.FunSuite import org.apache.spark.SparkException -import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.LocalSparkContext import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression._ +import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} object SVMSuite { @@ -193,3 +192,19 @@ class SVMSuite extends FunSuite with LocalSparkContext { new SVMWithSGD().setValidateData(false).run(testRDDInvalid) } } + +class SVMClusterSuite extends FunSuite with LocalClusterSparkContext { + + test("task size should be small in both training and prediction") { + val m = 4 + val n = 200000 + val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) => + val random = new Random(idx) + iter.map(i => LabeledPoint(1.0, Vectors.dense(Array.fill(n)(random.nextDouble())))) + }.cache() + // If we serialize data directly in the task closure, the size of the serialized task would be + // greater than 1MB and hence Spark would throw an error. 
+ val model = SVMWithSGD.train(points, 2) + val predictions = model.predict(points.map(_.features)) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala index 560a4ad71a4de..34bc4537a7b3a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -17,14 +17,16 @@ package org.apache.spark.mllib.clustering +import scala.util.Random + import org.scalatest.FunSuite -import org.apache.spark.mllib.util.LocalSparkContext import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} class KMeansSuite extends FunSuite with LocalSparkContext { - import KMeans.{RANDOM, K_MEANS_PARALLEL} + import org.apache.spark.mllib.clustering.KMeans.{K_MEANS_PARALLEL, RANDOM} test("single cluster") { val data = sc.parallelize(Array( @@ -38,29 +40,55 @@ class KMeansSuite extends FunSuite with LocalSparkContext { // No matter how many runs or iterations we use, we should get one cluster, // centered at the mean of the points - var model = KMeans.train(data, k=1, maxIterations=1) + var model = KMeans.train(data, k = 1, maxIterations = 1) assert(model.clusterCenters.head === center) - model = KMeans.train(data, k=1, maxIterations=2) + model = KMeans.train(data, k = 1, maxIterations = 2) assert(model.clusterCenters.head === center) - model = KMeans.train(data, k=1, maxIterations=5) + model = KMeans.train(data, k = 1, maxIterations = 5) assert(model.clusterCenters.head === center) - model = KMeans.train(data, k=1, maxIterations=1, runs=5) + model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5) assert(model.clusterCenters.head === center) - model = KMeans.train(data, k=1, maxIterations=1, runs=5) + model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5) assert(model.clusterCenters.head === center) - model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=RANDOM) + model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = RANDOM) assert(model.clusterCenters.head === center) model = KMeans.train( - data, k=1, maxIterations=1, runs=1, initializationMode=K_MEANS_PARALLEL) + data, k = 1, maxIterations = 1, runs = 1, initializationMode = K_MEANS_PARALLEL) assert(model.clusterCenters.head === center) } + test("no distinct points") { + val data = sc.parallelize( + Array( + Vectors.dense(1.0, 2.0, 3.0), + Vectors.dense(1.0, 2.0, 3.0), + Vectors.dense(1.0, 2.0, 3.0)), + 2) + val center = Vectors.dense(1.0, 2.0, 3.0) + + // Make sure code runs. + var model = KMeans.train(data, k=2, maxIterations=1) + assert(model.clusterCenters.size === 2) + } + + test("more clusters than points") { + val data = sc.parallelize( + Array( + Vectors.dense(1.0, 2.0, 3.0), + Vectors.dense(1.0, 3.0, 4.0)), + 2) + + // Make sure code runs. 
+ var model = KMeans.train(data, k=3, maxIterations=1) + assert(model.clusterCenters.size === 3) + } + test("single cluster with big dataset") { val smallData = Array( Vectors.dense(1.0, 2.0, 6.0), @@ -74,26 +102,27 @@ class KMeansSuite extends FunSuite with LocalSparkContext { val center = Vectors.dense(1.0, 3.0, 4.0) - var model = KMeans.train(data, k=1, maxIterations=1) + var model = KMeans.train(data, k = 1, maxIterations = 1) assert(model.clusterCenters.size === 1) assert(model.clusterCenters.head === center) - model = KMeans.train(data, k=1, maxIterations=2) + model = KMeans.train(data, k = 1, maxIterations = 2) assert(model.clusterCenters.head === center) - model = KMeans.train(data, k=1, maxIterations=5) + model = KMeans.train(data, k = 1, maxIterations = 5) assert(model.clusterCenters.head === center) - model = KMeans.train(data, k=1, maxIterations=1, runs=5) + model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5) assert(model.clusterCenters.head === center) - model = KMeans.train(data, k=1, maxIterations=1, runs=5) + model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5) assert(model.clusterCenters.head === center) - model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=RANDOM) + model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = RANDOM) assert(model.clusterCenters.head === center) - model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=K_MEANS_PARALLEL) + model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, + initializationMode = K_MEANS_PARALLEL) assert(model.clusterCenters.head === center) } @@ -119,25 +148,26 @@ class KMeansSuite extends FunSuite with LocalSparkContext { val center = Vectors.sparse(n, Seq((0, 1.0), (1, 3.0), (2, 4.0))) - var model = KMeans.train(data, k=1, maxIterations=1) + var model = KMeans.train(data, k = 1, maxIterations = 1) assert(model.clusterCenters.head === center) - model = KMeans.train(data, k=1, maxIterations=2) + model = KMeans.train(data, k = 1, maxIterations = 2) assert(model.clusterCenters.head === center) - model = KMeans.train(data, k=1, maxIterations=5) + model = KMeans.train(data, k = 1, maxIterations = 5) assert(model.clusterCenters.head === center) - model = KMeans.train(data, k=1, maxIterations=1, runs=5) + model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5) assert(model.clusterCenters.head === center) - model = KMeans.train(data, k=1, maxIterations=1, runs=5) + model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5) assert(model.clusterCenters.head === center) - model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=RANDOM) + model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = RANDOM) assert(model.clusterCenters.head === center) - model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=K_MEANS_PARALLEL) + model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, + initializationMode = K_MEANS_PARALLEL) assert(model.clusterCenters.head === center) data.unpersist() @@ -157,15 +187,15 @@ class KMeansSuite extends FunSuite with LocalSparkContext { // it will make at least five passes, and it will give non-zero probability to each // unselected point as long as it hasn't yet selected all of them - var model = KMeans.train(rdd, k=5, maxIterations=1) + var model = KMeans.train(rdd, k = 5, maxIterations = 1) assert(Set(model.clusterCenters: _*) === Set(points: _*)) // Iterations of Lloyd's should not change the answer either - model = 
KMeans.train(rdd, k=5, maxIterations=10) + model = KMeans.train(rdd, k = 5, maxIterations = 10) assert(Set(model.clusterCenters: _*) === Set(points: _*)) // Neither should more runs - model = KMeans.train(rdd, k=5, maxIterations=10, runs=5) + model = KMeans.train(rdd, k = 5, maxIterations = 10, runs = 5) assert(Set(model.clusterCenters: _*) === Set(points: _*)) } @@ -194,3 +224,22 @@ class KMeansSuite extends FunSuite with LocalSparkContext { } } } + +class KMeansClusterSuite extends FunSuite with LocalClusterSparkContext { + + test("task size should be small in both training and prediction") { + val m = 4 + val n = 200000 + val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) => + val random = new Random(idx) + iter.map(i => Vectors.dense(Array.fill(n)(random.nextDouble))) + }.cache() + for (initMode <- Seq(KMeans.RANDOM, KMeans.K_MEANS_PARALLEL)) { + // If we serialize data directly in the task closure, the size of the serialized task would be + // greater than 1MB and hence Spark would throw an error. + val model = KMeans.train(points, 2, 2, 1, initMode) + val predictions = model.predict(points).collect() + val cost = model.computeCost(points) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala index 9d16182f9d8c4..94db1dc183230 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala @@ -20,8 +20,26 @@ package org.apache.spark.mllib.evaluation import org.scalatest.FunSuite import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.TestingUtils.DoubleWithAlmostEquals class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext { + + // TODO: move utility functions to TestingUtils. 
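The LocalClusterSparkContext suites in this patch (KMeansClusterSuite above, and the regression and optimization suites below) all generate their large rows inside mapPartitionsWithIndex so that task closures stay small. The sketch below contrasts that with the anti-pattern the suites guard against; the object name, master URL, and sizes are illustrative only and are not part of the patch.

import scala.util.Random
import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vectors

object TaskSizeSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local[2]", "task-size-sketch")
    val n = 200000

    // Small closure: each partition builds its own rows, so tasks only carry `n` and the index.
    val generated = sc.parallelize(0 until 4, 2).mapPartitionsWithIndex { (idx, iter) =>
      val random = new Random(idx)
      iter.map(_ => Vectors.dense(Array.fill(n)(random.nextDouble())))
    }

    // Large closure: `bigArray` lives on the driver and is serialized into every task.
    val bigArray = Array.fill(n)(1.0)
    val shipped = sc.parallelize(0 until 4, 2).map(_ => bigArray.sum)

    println(generated.count() + " " + shipped.count())
    sc.stop()
  }
}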
+ + def elementsAlmostEqual(actual: Seq[Double], expected: Seq[Double]): Boolean = { + actual.zip(expected).forall { case (x1, x2) => + x1.almostEquals(x2) + } + } + + def elementsAlmostEqual( + actual: Seq[(Double, Double)], + expected: Seq[(Double, Double)])(implicit dummy: DummyImplicit): Boolean = { + actual.zip(expected).forall { case ((x1, y1), (x2, y2)) => + x1.almostEquals(x2) && y1.almostEquals(y2) + } + } + test("binary evaluation metrics") { val scoreAndLabels = sc.parallelize( Seq((0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)), 2) @@ -41,14 +59,14 @@ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext { val prCurve = Seq((0.0, 1.0)) ++ pr val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r) } val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)} - assert(metrics.thresholds().collect().toSeq === threshold) - assert(metrics.roc().collect().toSeq === rocCurve) - assert(metrics.areaUnderROC() === AreaUnderCurve.of(rocCurve)) - assert(metrics.pr().collect().toSeq === prCurve) - assert(metrics.areaUnderPR() === AreaUnderCurve.of(prCurve)) - assert(metrics.fMeasureByThreshold().collect().toSeq === threshold.zip(f1)) - assert(metrics.fMeasureByThreshold(2.0).collect().toSeq === threshold.zip(f2)) - assert(metrics.precisionByThreshold().collect().toSeq === threshold.zip(precision)) - assert(metrics.recallByThreshold().collect().toSeq === threshold.zip(recall)) + assert(elementsAlmostEqual(metrics.thresholds().collect(), threshold)) + assert(elementsAlmostEqual(metrics.roc().collect(), rocCurve)) + assert(metrics.areaUnderROC().almostEquals(AreaUnderCurve.of(rocCurve))) + assert(elementsAlmostEqual(metrics.pr().collect(), prCurve)) + assert(metrics.areaUnderPR().almostEquals(AreaUnderCurve.of(prCurve))) + assert(elementsAlmostEqual(metrics.fMeasureByThreshold().collect(), threshold.zip(f1))) + assert(elementsAlmostEqual(metrics.fMeasureByThreshold(2.0).collect(), threshold.zip(f2))) + assert(elementsAlmostEqual(metrics.precisionByThreshold().collect(), threshold.zip(precision))) + assert(elementsAlmostEqual(metrics.recallByThreshold().collect(), threshold.zip(recall))) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala new file mode 100644 index 0000000000000..1ea503971c864 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.evaluation + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.Matrices +import org.apache.spark.mllib.util.LocalSparkContext + +class MulticlassMetricsSuite extends FunSuite with LocalSparkContext { + test("Multiclass evaluation metrics") { + /* + * Confusion matrix for 3-class classification with total 9 instances: + * |2|1|1| true class0 (4 instances) + * |1|3|0| true class1 (4 instances) + * |0|0|1| true class2 (1 instance) + */ + val confusionMatrix = Matrices.dense(3, 3, Array(2, 1, 0, 1, 3, 0, 1, 0, 1)) + val labels = Array(0.0, 1.0, 2.0) + val predictionAndLabels = sc.parallelize( + Seq((0.0, 0.0), (0.0, 1.0), (0.0, 0.0), (1.0, 0.0), (1.0, 1.0), + (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)), 2) + val metrics = new MulticlassMetrics(predictionAndLabels) + val delta = 0.0000001 + val fpRate0 = 1.0 / (9 - 4) + val fpRate1 = 1.0 / (9 - 4) + val fpRate2 = 1.0 / (9 - 1) + val precision0 = 2.0 / (2 + 1) + val precision1 = 3.0 / (3 + 1) + val precision2 = 1.0 / (1 + 1) + val recall0 = 2.0 / (2 + 2) + val recall1 = 3.0 / (3 + 1) + val recall2 = 1.0 / (1 + 0) + val f1measure0 = 2 * precision0 * recall0 / (precision0 + recall0) + val f1measure1 = 2 * precision1 * recall1 / (precision1 + recall1) + val f1measure2 = 2 * precision2 * recall2 / (precision2 + recall2) + val f2measure0 = (1 + 2 * 2) * precision0 * recall0 / (2 * 2 * precision0 + recall0) + val f2measure1 = (1 + 2 * 2) * precision1 * recall1 / (2 * 2 * precision1 + recall1) + val f2measure2 = (1 + 2 * 2) * precision2 * recall2 / (2 * 2 * precision2 + recall2) + + assert(metrics.confusionMatrix.toArray.sameElements(confusionMatrix.toArray)) + assert(math.abs(metrics.falsePositiveRate(0.0) - fpRate0) < delta) + assert(math.abs(metrics.falsePositiveRate(1.0) - fpRate1) < delta) + assert(math.abs(metrics.falsePositiveRate(2.0) - fpRate2) < delta) + assert(math.abs(metrics.precision(0.0) - precision0) < delta) + assert(math.abs(metrics.precision(1.0) - precision1) < delta) + assert(math.abs(metrics.precision(2.0) - precision2) < delta) + assert(math.abs(metrics.recall(0.0) - recall0) < delta) + assert(math.abs(metrics.recall(1.0) - recall1) < delta) + assert(math.abs(metrics.recall(2.0) - recall2) < delta) + assert(math.abs(metrics.fMeasure(0.0) - f1measure0) < delta) + assert(math.abs(metrics.fMeasure(1.0) - f1measure1) < delta) + assert(math.abs(metrics.fMeasure(2.0) - f1measure2) < delta) + assert(math.abs(metrics.fMeasure(0.0, 2.0) - f2measure0) < delta) + assert(math.abs(metrics.fMeasure(1.0, 2.0) - f2measure1) < delta) + assert(math.abs(metrics.fMeasure(2.0, 2.0) - f2measure2) < delta) + + assert(math.abs(metrics.recall - + (2.0 + 3.0 + 1.0) / ((2 + 3 + 1) + (1 + 1 + 1))) < delta) + assert(math.abs(metrics.recall - metrics.precision) < delta) + assert(math.abs(metrics.recall - metrics.fMeasure) < delta) + assert(math.abs(metrics.recall - metrics.weightedRecall) < delta) + assert(math.abs(metrics.weightedFalsePositiveRate - + ((4.0 / 9) * fpRate0 + (4.0 / 9) * fpRate1 + (1.0 / 9) * fpRate2)) < delta) + assert(math.abs(metrics.weightedPrecision - + ((4.0 / 9) * precision0 + (4.0 / 9) * precision1 + (1.0 / 9) * precision2)) < delta) + assert(math.abs(metrics.weightedRecall - + ((4.0 / 9) * recall0 + (4.0 / 9) * recall1 + (1.0 / 9) * recall2)) < delta) + assert(math.abs(metrics.weightedFMeasure - + ((4.0 / 9) * f1measure0 + (4.0 / 9) * f1measure1 + (1.0 / 9) * f1measure2)) < delta) + assert(math.abs(metrics.weightedFMeasure(2.0) - + ((4.0 / 9) * f2measure0 + (4.0 / 9) * 
f2measure1 + (1.0 / 9) * f2measure2)) < delta) + assert(metrics.labels.sameElements(labels)) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala index a961f89456a18..325b817980f68 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala @@ -17,12 +17,13 @@ package org.apache.spark.mllib.linalg.distributed -import org.scalatest.FunSuite +import scala.util.Random import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, norm => brzNorm, svd => brzSvd} +import org.scalatest.FunSuite -import org.apache.spark.mllib.util.LocalSparkContext import org.apache.spark.mllib.linalg.{Matrices, Vectors, Vector} +import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} class RowMatrixSuite extends FunSuite with LocalSparkContext { @@ -193,3 +194,27 @@ class RowMatrixSuite extends FunSuite with LocalSparkContext { } } } + +class RowMatrixClusterSuite extends FunSuite with LocalClusterSparkContext { + + var mat: RowMatrix = _ + + override def beforeAll() { + super.beforeAll() + val m = 4 + val n = 200000 + val rows = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) => + val random = new Random(idx) + iter.map(i => Vectors.dense(Array.fill(n)(random.nextDouble()))) + } + mat = new RowMatrix(rows) + } + + test("task size should be small in svd") { + val svd = mat.computeSVD(1, computeU = true) + } + + test("task size should be small in summarize") { + val summary = mat.computeColumnSummaryStatistics() + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala index 951b4f7c6e6f4..dfb2eb7f0d14e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.mllib.optimization -import scala.util.Random import scala.collection.JavaConversions._ +import scala.util.Random -import org.scalatest.FunSuite -import org.scalatest.Matchers +import org.scalatest.{FunSuite, Matchers} -import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.LocalSparkContext import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression._ +import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} object GradientDescentSuite { @@ -46,7 +45,7 @@ object GradientDescentSuite { val rnd = new Random(seed) val x1 = Array.fill[Double](nPoints)(rnd.nextGaussian()) - val unifRand = new scala.util.Random(45) + val unifRand = new Random(45) val rLogis = (0 until nPoints).map { i => val u = unifRand.nextDouble() math.log(u) - math.log(1.0-u) @@ -144,3 +143,26 @@ class GradientDescentSuite extends FunSuite with LocalSparkContext with Matchers "should be initialWeightsWithIntercept.") } } + +class GradientDescentClusterSuite extends FunSuite with LocalClusterSparkContext { + + test("task size should be small") { + val m = 4 + val n = 200000 + val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) => + val random = new Random(idx) + iter.map(i => (1.0, Vectors.dense(Array.fill(n)(random.nextDouble())))) + }.cache() + // If we serialize data directly in the task 
closure, the size of the serialized task would be + // greater than 1MB and hence Spark would throw an error. + val (weights, loss) = GradientDescent.runMiniBatchSGD( + points, + new LogisticGradient, + new SquaredL2Updater, + 0.1, + 2, + 1.0, + 1.0, + Vectors.dense(new Array[Double](n))) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala index fe7a9033cd5f4..ff414742e8393 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala @@ -17,12 +17,13 @@ package org.apache.spark.mllib.optimization -import org.scalatest.FunSuite -import org.scalatest.Matchers +import scala.util.Random + +import org.scalatest.{FunSuite, Matchers} -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { @@ -230,3 +231,24 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { "The weight differences between LBFGS and GD should be within 2%.") } } + +class LBFGSClusterSuite extends FunSuite with LocalClusterSparkContext { + + test("task size should be small") { + val m = 10 + val n = 200000 + val examples = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) => + val random = new Random(idx) + iter.map(i => (1.0, Vectors.dense(Array.fill(n)(random.nextDouble)))) + }.cache() + val lbfgs = new LBFGS(new LogisticGradient, new SquaredL2Updater) + .setNumCorrections(1) + .setConvergenceTol(1e-12) + .setMaxNumIterations(1) + .setRegParam(1.0) + val random = new Random(0) + // If we serialize data directly in the task closure, the size of the serialized task would be + // greater than 1MB and hence Spark would throw an error. + val weights = lbfgs.optimize(examples, Vectors.dense(Array.fill(n)(random.nextDouble))) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/DistributionGeneratorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/DistributionGeneratorSuite.scala new file mode 100644 index 0000000000000..974dec4c0b5ee --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/random/DistributionGeneratorSuite.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.random + +import org.scalatest.FunSuite + +import org.apache.spark.util.StatCounter + +// TODO update tests to use TestingUtils for floating point comparison after PR 1367 is merged +class DistributionGeneratorSuite extends FunSuite { + + def apiChecks(gen: DistributionGenerator) { + + // resetting seed should generate the same sequence of random numbers + gen.setSeed(42L) + val array1 = (0 until 1000).map(_ => gen.nextValue()) + gen.setSeed(42L) + val array2 = (0 until 1000).map(_ => gen.nextValue()) + assert(array1.equals(array2)) + + // newInstance should contain a different instance of the rng + // i.e. setting different seeds for different instances produces different sequences of + // random numbers. + val gen2 = gen.copy() + gen.setSeed(0L) + val array3 = (0 until 1000).map(_ => gen.nextValue()) + gen2.setSeed(1L) + val array4 = (0 until 1000).map(_ => gen2.nextValue()) + // Compare arrays instead of elements since individual elements can coincide by chance but the + // sequences should differ given two different seeds. + assert(!array3.equals(array4)) + + // test that setting the same seed in the copied instance produces the same sequence of numbers + gen.setSeed(0L) + val array5 = (0 until 1000).map(_ => gen.nextValue()) + gen2.setSeed(0L) + val array6 = (0 until 1000).map(_ => gen2.nextValue()) + assert(array5.equals(array6)) + } + + def distributionChecks(gen: DistributionGenerator, + mean: Double = 0.0, + stddev: Double = 1.0, + epsilon: Double = 0.01) { + for (seed <- 0 until 5) { + gen.setSeed(seed.toLong) + val sample = (0 until 100000).map { _ => gen.nextValue()} + val stats = new StatCounter(sample) + assert(math.abs(stats.mean - mean) < epsilon) + assert(math.abs(stats.stdev - stddev) < epsilon) + } + } + + test("UniformGenerator") { + val uniform = new UniformGenerator() + apiChecks(uniform) + // Stddev of uniform distribution = (ub - lb) / math.sqrt(12) + distributionChecks(uniform, 0.5, 1 / math.sqrt(12)) + } + + test("StandardNormalGenerator") { + val normal = new StandardNormalGenerator() + apiChecks(normal) + distributionChecks(normal, 0.0, 1.0) + } + + test("PoissonGenerator") { + // mean = 0.0 will not pass the API checks since 0.0 is always deterministically produced. + for (mean <- List(1.0, 5.0, 100.0)) { + val poisson = new PoissonGenerator(mean) + apiChecks(poisson) + distributionChecks(poisson, mean, math.sqrt(mean), 0.1) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala new file mode 100644 index 0000000000000..6aa4f803df0f7 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.random + +import scala.collection.mutable.ArrayBuffer + +import org.scalatest.FunSuite + +import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.rdd.{RandomRDDPartition, RandomRDD} +import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.util.StatCounter + +/* + * Note: avoid including APIs that do not set the seed for the RNG in unit tests + * in order to guarantee deterministic behavior. + * + * TODO update tests to use TestingUtils for floating point comparison after PR 1367 is merged + */ +class RandomRDDGeneratorsSuite extends FunSuite with LocalSparkContext with Serializable { + + def testGeneratedRDD(rdd: RDD[Double], + expectedSize: Long, + expectedNumPartitions: Int, + expectedMean: Double, + expectedStddev: Double, + epsilon: Double = 0.01) { + val stats = rdd.stats() + assert(expectedSize === stats.count) + assert(expectedNumPartitions === rdd.partitions.size) + assert(math.abs(stats.mean - expectedMean) < epsilon) + assert(math.abs(stats.stdev - expectedStddev) < epsilon) + } + + // assume test RDDs are small + def testGeneratedVectorRDD(rdd: RDD[Vector], + expectedRows: Long, + expectedColumns: Int, + expectedNumPartitions: Int, + expectedMean: Double, + expectedStddev: Double, + epsilon: Double = 0.01) { + assert(expectedNumPartitions === rdd.partitions.size) + val values = new ArrayBuffer[Double]() + rdd.collect.foreach { vector => { + assert(vector.size === expectedColumns) + values ++= vector.toArray + }} + assert(expectedRows === values.size / expectedColumns) + val stats = new StatCounter(values) + assert(math.abs(stats.mean - expectedMean) < epsilon) + assert(math.abs(stats.stdev - expectedStddev) < epsilon) + } + + test("RandomRDD sizes") { + + // some cases where size % numParts != 0 to test getPartitions behaves correctly + for ((size, numPartitions) <- List((10000, 6), (12345, 1), (1000, 101))) { + val rdd = new RandomRDD(sc, size, numPartitions, new UniformGenerator, 0L) + assert(rdd.count() === size) + assert(rdd.partitions.size === numPartitions) + + // check that partition sizes are balanced + val partSizes = rdd.partitions.map(p => p.asInstanceOf[RandomRDDPartition].size.toDouble) + val partStats = new StatCounter(partSizes) + assert(partStats.max - partStats.min <= 1) + } + + // size > Int.MaxValue + val size = Int.MaxValue.toLong * 100L + val numPartitions = 101 + val rdd = new RandomRDD(sc, size, numPartitions, new UniformGenerator, 0L) + assert(rdd.partitions.size === numPartitions) + val count = rdd.partitions.foldLeft(0L) { (count, part) => + count + part.asInstanceOf[RandomRDDPartition].size + } + assert(count === size) + + // size needs to be positive + intercept[IllegalArgumentException] { new RandomRDD(sc, 0, 10, new UniformGenerator, 0L) } + + // numPartitions needs to be positive + intercept[IllegalArgumentException] { new RandomRDD(sc, 100, 0, new UniformGenerator, 0L) } + + // partition size needs to be <= Int.MaxValue + intercept[IllegalArgumentException] { + new RandomRDD(sc, Int.MaxValue.toLong * 100L, 99, new UniformGenerator, 0L) + } + } + + test("randomRDD for different distributions") { + val size = 100000L + val numPartitions = 10 + val poissonMean = 100.0 + + for (seed <- 0 until 5) { + val uniform = RandomRDDGenerators.uniformRDD(sc, size, numPartitions, seed) + 
testGeneratedRDD(uniform, size, numPartitions, 0.5, 1 / math.sqrt(12)) + + val normal = RandomRDDGenerators.normalRDD(sc, size, numPartitions, seed) + testGeneratedRDD(normal, size, numPartitions, 0.0, 1.0) + + val poisson = RandomRDDGenerators.poissonRDD(sc, poissonMean, size, numPartitions, seed) + testGeneratedRDD(poisson, size, numPartitions, poissonMean, math.sqrt(poissonMean), 0.1) + } + + // mock distribution to check that partitions have unique seeds + val random = RandomRDDGenerators.randomRDD(sc, new MockDistro(), 1000L, 1000, 0L) + assert(random.collect.size === random.collect.distinct.size) + } + + test("randomVectorRDD for different distributions") { + val rows = 1000L + val cols = 100 + val parts = 10 + val poissonMean = 100.0 + + for (seed <- 0 until 5) { + val uniform = RandomRDDGenerators.uniformVectorRDD(sc, rows, cols, parts, seed) + testGeneratedVectorRDD(uniform, rows, cols, parts, 0.5, 1 / math.sqrt(12)) + + val normal = RandomRDDGenerators.normalVectorRDD(sc, rows, cols, parts, seed) + testGeneratedVectorRDD(normal, rows, cols, parts, 0.0, 1.0) + + val poisson = RandomRDDGenerators.poissonVectorRDD(sc, poissonMean, rows, cols, parts, seed) + testGeneratedVectorRDD(poisson, rows, cols, parts, poissonMean, math.sqrt(poissonMean), 0.1) + } + } +} + +private[random] class MockDistro extends DistributionGenerator { + + var seed = 0L + + // This allows us to check that each partition has a different seed + override def nextValue(): Double = seed.toDouble + + override def setSeed(seed: Long) = this.seed = seed + + override def copy(): MockDistro = new MockDistro +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala index bfa42959c8ead..7aa96421aed87 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala @@ -17,10 +17,13 @@ package org.apache.spark.mllib.regression +import scala.util.Random + import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext} +import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, + LocalSparkContext} class LassoSuite extends FunSuite with LocalSparkContext { @@ -113,3 +116,19 @@ class LassoSuite extends FunSuite with LocalSparkContext { validatePrediction(validationData.map(row => model.predict(row.features)), validationData) } } + +class LassoClusterSuite extends FunSuite with LocalClusterSparkContext { + + test("task size should be small in both training and prediction") { + val m = 4 + val n = 200000 + val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) => + val random = new Random(idx) + iter.map(i => LabeledPoint(1.0, Vectors.dense(Array.fill(n)(random.nextDouble())))) + }.cache() + // If we serialize data directly in the task closure, the size of the serialized task would be + // greater than 1MB and hence Spark would throw an error. 
+ val model = LassoWithSGD.train(points, 2) + val predictions = model.predict(points.map(_.features)) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala index 7aaad7d7a3e39..4f89112b650c5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala @@ -17,10 +17,13 @@ package org.apache.spark.mllib.regression +import scala.util.Random + import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext} +import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, + LocalSparkContext} class LinearRegressionSuite extends FunSuite with LocalSparkContext { @@ -122,3 +125,19 @@ class LinearRegressionSuite extends FunSuite with LocalSparkContext { sparseValidationData.map(row => model.predict(row.features)), sparseValidationData) } } + +class LinearRegressionClusterSuite extends FunSuite with LocalClusterSparkContext { + + test("task size should be small in both training and prediction") { + val m = 4 + val n = 200000 + val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) => + val random = new Random(idx) + iter.map(i => LabeledPoint(1.0, Vectors.dense(Array.fill(n)(random.nextDouble())))) + }.cache() + // If we serialize data directly in the task closure, the size of the serialized task would be + // greater than 1MB and hence Spark would throw an error. + val model = LinearRegressionWithSGD.train(points, 2) + val predictions = model.predict(points.map(_.features)) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala index 67768e17fbe6d..727bbd051ff15 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala @@ -17,11 +17,14 @@ package org.apache.spark.mllib.regression -import org.scalatest.FunSuite +import scala.util.Random import org.jblas.DoubleMatrix +import org.scalatest.FunSuite -import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext} +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, + LocalSparkContext} class RidgeRegressionSuite extends FunSuite with LocalSparkContext { @@ -73,3 +76,19 @@ class RidgeRegressionSuite extends FunSuite with LocalSparkContext { "ridgeError (" + ridgeErr + ") was not less than linearError(" + linearErr + ")") } } + +class RidgeRegressionClusterSuite extends FunSuite with LocalClusterSparkContext { + + test("task size should be small in both training and prediction") { + val m = 4 + val n = 200000 + val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) => + val random = new Random(idx) + iter.map(i => LabeledPoint(1.0, Vectors.dense(Array.fill(n)(random.nextDouble())))) + }.cache() + // If we serialize data directly in the task closure, the size of the serialized task would be + // greater than 1MB and hence Spark would throw an error. 
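The comment repeated in each of these cluster suites is the point of the exercise: if large data or weights were captured directly in a task closure, every task would ship a multi-hundred-kilobyte payload. A hypothetical, self-contained sketch of the anti-pattern and the usual fix; only SparkContext, parallelize, broadcast, map and collect are real Spark APIs here, the rest is illustrative.

import org.apache.spark.{SparkConf, SparkContext}

object ClosureSizeExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("closure-size"))
    val bigWeights = Array.fill(200000)(scala.util.Random.nextDouble())
    val data = sc.parallelize(1 to 4, 2)

    // Anti-pattern: bigWeights is captured by the closure, so ~1.6 MB is serialized per task.
    val captured = data.map(i => bigWeights(i) * 2.0).collect()

    // Usual fix: broadcast once; each task then carries only a small broadcast handle.
    val bw = sc.broadcast(bigWeights)
    val viaBroadcast = data.map(i => bw.value(i) * 2.0).collect()

    sc.stop()
  }
}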
+ val model = RidgeRegressionWithSGD.train(points, 2) + val predictions = model.predict(points.map(_.features)) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala new file mode 100644 index 0000000000000..bce4251426df7 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat + +import org.scalatest.FunSuite + +import breeze.linalg.{DenseMatrix => BDM, Matrix => BM} + +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.stat.correlation.{Correlations, PearsonCorrelation, + SpearmanCorrelation} +import org.apache.spark.mllib.util.LocalSparkContext + +class CorrelationSuite extends FunSuite with LocalSparkContext { + + // test input data + val xData = Array(1.0, 0.0, -2.0) + val yData = Array(4.0, 5.0, 3.0) + val data = Seq( + Vectors.dense(1.0, 0.0, 0.0, -2.0), + Vectors.dense(4.0, 5.0, 0.0, 3.0), + Vectors.dense(6.0, 7.0, 0.0, 8.0), + Vectors.dense(9.0, 0.0, 0.0, 1.0) + ) + + test("corr(x, y) default, pearson") { + val x = sc.parallelize(xData) + val y = sc.parallelize(yData) + val expected = 0.6546537 + val default = Statistics.corr(x, y) + val p1 = Statistics.corr(x, y, "pearson") + assert(approxEqual(expected, default)) + assert(approxEqual(expected, p1)) + } + + test("corr(x, y) spearman") { + val x = sc.parallelize(xData) + val y = sc.parallelize(yData) + val expected = 0.5 + val s1 = Statistics.corr(x, y, "spearman") + assert(approxEqual(expected, s1)) + } + + test("corr(X) default, pearson") { + val X = sc.parallelize(data) + val defaultMat = Statistics.corr(X) + val pearsonMat = Statistics.corr(X, "pearson") + val expected = BDM( + (1.00000000, 0.05564149, Double.NaN, 0.4004714), + (0.05564149, 1.00000000, Double.NaN, 0.9135959), + (Double.NaN, Double.NaN, 1.00000000, Double.NaN), + (0.40047142, 0.91359586, Double.NaN,1.0000000)) + assert(matrixApproxEqual(defaultMat.toBreeze, expected)) + assert(matrixApproxEqual(pearsonMat.toBreeze, expected)) + } + + test("corr(X) spearman") { + val X = sc.parallelize(data) + val spearmanMat = Statistics.corr(X, "spearman") + val expected = BDM( + (1.0000000, 0.1054093, Double.NaN, 0.4000000), + (0.1054093, 1.0000000, Double.NaN, 0.9486833), + (Double.NaN, Double.NaN, 1.00000000, Double.NaN), + (0.4000000, 0.9486833, Double.NaN, 1.0000000)) + assert(matrixApproxEqual(spearmanMat.toBreeze, expected)) + } + + test("method identification") { + val pearson = PearsonCorrelation + val spearman = SpearmanCorrelation + + assert(Correlations.getCorrelationFromName("pearson") === pearson) + 
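For reference, the expected value 0.6546537 in the Pearson test above can be reproduced with the textbook formula. A local, non-distributed sketch for illustration; Statistics.corr computes the same quantity over RDDs.

def pearson(x: Array[Double], y: Array[Double]): Double = {
  require(x.length == y.length && x.nonEmpty)
  val n = x.length
  val mx = x.sum / n
  val my = y.sum / n
  val cov = x.zip(y).map { case (a, b) => (a - mx) * (b - my) }.sum
  val sx = math.sqrt(x.map(a => (a - mx) * (a - mx)).sum)
  val sy = math.sqrt(y.map(b => (b - my) * (b - my)).sum)
  cov / (sx * sy)
}
// pearson(Array(1.0, 0.0, -2.0), Array(4.0, 5.0, 3.0)) is about 0.6546537, the value above.
// A zero-variance column (like the all-zero third column of `data`) makes sx or sy zero,
// which is why the expected correlation matrices contain Double.NaN in that row and column.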
assert(Correlations.getCorrelationFromName("spearman") === spearman) + + // Should throw IllegalArgumentException + try { + Correlations.getCorrelationFromName("kendall") + assert(false) + } catch { + case ie: IllegalArgumentException => + } + } + + def approxEqual(v1: Double, v2: Double, threshold: Double = 1e-6): Boolean = { + if (v1.isNaN) { + v2.isNaN + } else { + math.abs(v1 - v2) <= threshold + } + } + + def matrixApproxEqual(A: BM[Double], B: BM[Double], threshold: Double = 1e-6): Boolean = { + for (i <- 0 until A.rows; j <- 0 until A.cols) { + if (!approxEqual(A(i, j), B(i, j), threshold)) { + println("i, j = " + i + ", " + j + " actual: " + A(i, j) + " expected:" + B(i, j)) + return false + } + } + true + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index bcb11876b8f4f..5961a618c59d9 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.mllib.tree import org.scalatest.FunSuite -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} import org.apache.spark.mllib.tree.model.Filter import org.apache.spark.mllib.tree.model.Split @@ -28,6 +27,7 @@ import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.regression.LabeledPoint class DecisionTreeSuite extends FunSuite with LocalSparkContext { @@ -35,7 +35,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Gini, 3, 100) + val strategy = new Strategy(Classification, Gini, 3, 2, 100) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) assert(splits.length === 2) assert(bins.length === 2) @@ -51,6 +51,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { Classification, Gini, maxDepth = 3, + numClassesForClassification = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) @@ -130,8 +131,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { Classification, Gini, maxDepth = 3, + numClassesForClassification = 2, maxBins = 100, - categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) // Check splits. 
@@ -231,6 +233,162 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(1)(3) === null) } + test("extract categories from a number for multiclass classification") { + val l = DecisionTree.extractMultiClassCategories(13, 10) + assert(l.length === 3) + assert(List(3.0, 2.0, 0.0).toSeq == l.toSeq) + } + + test("split and bin calculations for unordered categorical variables with multiclass " + + "classification") { + val arr = DecisionTreeSuite.generateCategoricalDataPoints() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy( + Classification, + Gini, + maxDepth = 3, + numClassesForClassification = 100, + maxBins = 100, + categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + + // Expecting 2^2 - 1 = 3 bins/splits + assert(splits(0)(0).feature === 0) + assert(splits(0)(0).threshold === Double.MinValue) + assert(splits(0)(0).featureType === Categorical) + assert(splits(0)(0).categories.length === 1) + assert(splits(0)(0).categories.contains(0.0)) + assert(splits(1)(0).feature === 1) + assert(splits(1)(0).threshold === Double.MinValue) + assert(splits(1)(0).featureType === Categorical) + assert(splits(1)(0).categories.length === 1) + assert(splits(1)(0).categories.contains(0.0)) + + assert(splits(0)(1).feature === 0) + assert(splits(0)(1).threshold === Double.MinValue) + assert(splits(0)(1).featureType === Categorical) + assert(splits(0)(1).categories.length === 1) + assert(splits(0)(1).categories.contains(1.0)) + assert(splits(1)(1).feature === 1) + assert(splits(1)(1).threshold === Double.MinValue) + assert(splits(1)(1).featureType === Categorical) + assert(splits(1)(1).categories.length === 1) + assert(splits(1)(1).categories.contains(1.0)) + + assert(splits(0)(2).feature === 0) + assert(splits(0)(2).threshold === Double.MinValue) + assert(splits(0)(2).featureType === Categorical) + assert(splits(0)(2).categories.length === 2) + assert(splits(0)(2).categories.contains(0.0)) + assert(splits(0)(2).categories.contains(1.0)) + assert(splits(1)(2).feature === 1) + assert(splits(1)(2).threshold === Double.MinValue) + assert(splits(1)(2).featureType === Categorical) + assert(splits(1)(2).categories.length === 2) + assert(splits(1)(2).categories.contains(0.0)) + assert(splits(1)(2).categories.contains(1.0)) + + assert(splits(0)(3) === null) + assert(splits(1)(3) === null) + + + // Check bins. 
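The extractMultiClassCategories test above expects 13 with 10 possible categories to decode to List(3.0, 2.0, 0.0): 13 is binary 1101, so bits 0, 2 and 3 are set. A sketch of that bit decoding, as an illustration of the idea rather than the DecisionTree implementation; decodeCategories is a hypothetical name.

def decodeCategories(input: Int, maxCategories: Int): List[Double] = {
  (0 until maxCategories).foldLeft(List.empty[Double]) { (acc, bit) =>
    if (((input >> bit) & 1) == 1) bit.toDouble :: acc else acc
  }
}
// decodeCategories(13, 10) == List(3.0, 2.0, 0.0), matching the assertion above.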
+ + assert(bins(0)(0).category === Double.MinValue) + assert(bins(0)(0).lowSplit.categories.length === 0) + assert(bins(0)(0).highSplit.categories.length === 1) + assert(bins(0)(0).highSplit.categories.contains(0.0)) + assert(bins(1)(0).category === Double.MinValue) + assert(bins(1)(0).lowSplit.categories.length === 0) + assert(bins(1)(0).highSplit.categories.length === 1) + assert(bins(1)(0).highSplit.categories.contains(0.0)) + + assert(bins(0)(1).category === Double.MinValue) + assert(bins(0)(1).lowSplit.categories.length === 1) + assert(bins(0)(1).lowSplit.categories.contains(0.0)) + assert(bins(0)(1).highSplit.categories.length === 1) + assert(bins(0)(1).highSplit.categories.contains(1.0)) + assert(bins(1)(1).category === Double.MinValue) + assert(bins(1)(1).lowSplit.categories.length === 1) + assert(bins(1)(1).lowSplit.categories.contains(0.0)) + assert(bins(1)(1).highSplit.categories.length === 1) + assert(bins(1)(1).highSplit.categories.contains(1.0)) + + assert(bins(0)(2).category === Double.MinValue) + assert(bins(0)(2).lowSplit.categories.length === 1) + assert(bins(0)(2).lowSplit.categories.contains(1.0)) + assert(bins(0)(2).highSplit.categories.length === 2) + assert(bins(0)(2).highSplit.categories.contains(1.0)) + assert(bins(0)(2).highSplit.categories.contains(0.0)) + assert(bins(1)(2).category === Double.MinValue) + assert(bins(1)(2).lowSplit.categories.length === 1) + assert(bins(1)(2).lowSplit.categories.contains(1.0)) + assert(bins(1)(2).highSplit.categories.length === 2) + assert(bins(1)(2).highSplit.categories.contains(1.0)) + assert(bins(1)(2).highSplit.categories.contains(0.0)) + + assert(bins(0)(3) === null) + assert(bins(1)(3) === null) + + } + + test("split and bin calculations for ordered categorical variables with multiclass " + + "classification") { + val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures() + assert(arr.length === 3000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy( + Classification, + Gini, + maxDepth = 3, + numClassesForClassification = 100, + maxBins = 100, + categoricalFeaturesInfo = Map(0 -> 10, 1-> 10)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + + // 2^10 - 1 > 100, so categorical variables will be ordered + + assert(splits(0)(0).feature === 0) + assert(splits(0)(0).threshold === Double.MinValue) + assert(splits(0)(0).featureType === Categorical) + assert(splits(0)(0).categories.length === 1) + assert(splits(0)(0).categories.contains(1.0)) + + assert(splits(0)(1).feature === 0) + assert(splits(0)(1).threshold === Double.MinValue) + assert(splits(0)(1).featureType === Categorical) + assert(splits(0)(1).categories.length === 2) + assert(splits(0)(1).categories.contains(2.0)) + + assert(splits(0)(2).feature === 0) + assert(splits(0)(2).threshold === Double.MinValue) + assert(splits(0)(2).featureType === Categorical) + assert(splits(0)(2).categories.length === 3) + assert(splits(0)(2).categories.contains(2.0)) + assert(splits(0)(2).categories.contains(1.0)) + + assert(splits(0)(10) === null) + assert(splits(1)(10) === null) + + + // Check bins. 
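The two tests above exercise the threshold for treating a categorical feature as unordered: with 3 categories the suite expects 2^2 - 1 = 3 candidate splits (the non-trivial binary partitions of the categories), while with 10 categories the count exceeds maxBins = 100 and the feature falls back to ordered handling. A small helper showing the count, assuming the 2^(k-1) - 1 formula implied by the 3-category case; the second test's comment writes 2^10 - 1, but the threshold is exceeded under either count.

def numUnorderedSplits(arity: Int): Int = (1 << (arity - 1)) - 1
// numUnorderedSplits(3) == 3, as asserted in the unordered test;
// numUnorderedSplits(10) == 511 > 100 == maxBins, so the 10-category feature is ordered.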
+ + assert(bins(0)(0).category === 1.0) + assert(bins(0)(0).lowSplit.categories.length === 0) + assert(bins(0)(0).highSplit.categories.length === 1) + assert(bins(0)(0).highSplit.categories.contains(1.0)) + assert(bins(0)(1).category === 2.0) + assert(bins(0)(1).lowSplit.categories.length === 1) + assert(bins(0)(1).highSplit.categories.length === 2) + assert(bins(0)(1).highSplit.categories.contains(1.0)) + assert(bins(0)(1).highSplit.categories.contains(2.0)) + + assert(bins(0)(10) === null) + + } + + test("classification stump with all categorical variables") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) @@ -238,6 +396,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val strategy = new Strategy( Classification, Gini, + numClassesForClassification = 2, maxDepth = 3, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) @@ -253,8 +412,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val stats = bestSplits(0)._2 assert(stats.gain > 0) - assert(stats.predict > 0.5) - assert(stats.predict < 0.7) + assert(stats.predict === 1) + assert(stats.prob == 0.6) assert(stats.impurity > 0.2) } @@ -280,8 +439,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val stats = bestSplits(0)._2 assert(stats.gain > 0) - assert(stats.predict > 0.5) - assert(stats.predict < 0.7) + assert(stats.predict == 0.6) assert(stats.impurity > 0.2) } @@ -289,7 +447,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Gini, 3, 100) + val strategy = new Strategy(Classification, Gini, 3, 2, 100) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) assert(splits.length === 2) assert(splits(0).length === 99) @@ -312,7 +470,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Gini, 3, 100) + val strategy = new Strategy(Classification, Gini, 3, 2, 100) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) assert(splits.length === 2) assert(splits(0).length === 99) @@ -336,7 +494,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Entropy, 3, 100) + val strategy = new Strategy(Classification, Entropy, 3, 2, 100) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) assert(splits.length === 2) assert(splits(0).length === 99) @@ -360,7 +518,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Entropy, 3, 100) + val strategy = new Strategy(Classification, Entropy, 3, 2, 100) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) assert(splits.length === 2) assert(splits(0).length === 99) @@ -380,11 +538,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplits(0)._2.predict === 1) } - test("test second level node building with/without groups") { + test("second level node building with/without groups") { 
val arr = DecisionTreeSuite.generateOrderedLabeledPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Entropy, 3, 100) + val strategy = new Strategy(Classification, Entropy, 3, 2, 100) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) assert(splits.length === 2) assert(splits(0).length === 99) @@ -426,6 +584,82 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { } + test("stump with categorical variables for multiclass classification") { + val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() + val input = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) + assert(strategy.isMulticlassClassification) + val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) + val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + Array[List[Filter]](), splits, bins, 10) + + assert(bestSplits.length === 1) + val bestSplit = bestSplits(0)._1 + assert(bestSplit.feature === 0) + assert(bestSplit.categories.length === 1) + assert(bestSplit.categories.contains(1)) + assert(bestSplit.featureType === Categorical) + } + + test("stump with continuous variables for multiclass classification") { + val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass() + val input = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + numClassesForClassification = 3) + assert(strategy.isMulticlassClassification) + val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) + val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + Array[List[Filter]](), splits, bins, 10) + + assert(bestSplits.length === 1) + val bestSplit = bestSplits(0)._1 + + assert(bestSplit.feature === 1) + assert(bestSplit.featureType === Continuous) + assert(bestSplit.threshold > 1980) + assert(bestSplit.threshold < 2020) + + } + + test("stump with continuous + categorical variables for multiclass classification") { + val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass() + val input = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3)) + assert(strategy.isMulticlassClassification) + val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) + val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + Array[List[Filter]](), splits, bins, 10) + + assert(bestSplits.length === 1) + val bestSplit = bestSplits(0)._1 + + assert(bestSplit.feature === 1) + assert(bestSplit.featureType === Continuous) + assert(bestSplit.threshold > 1980) + assert(bestSplit.threshold < 2020) + } + + test("stump with categorical variables for ordered multiclass classification") { + val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures() + val input = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) + assert(strategy.isMulticlassClassification) + val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) + val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + Array[List[Filter]](), splits, bins, 10) + + assert(bestSplits.length === 1) + 
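All of these stump tests rank splits by Gini impurity. As a point of reference, a generic impurity calculation over per-class counts; this is an illustration only, the patched Gini.calculate signature is referenced just indirectly via the MimaExcludes entries later in this diff.

def giniImpurity(classCounts: Array[Double]): Double = {
  val total = classCounts.sum
  if (total == 0.0) 0.0
  else 1.0 - classCounts.map(c => (c / total) * (c / total)).sum
}
// giniImpurity(Array(600.0, 400.0)) == 0.48, consistent with the earlier stump assertions
// that predict the majority class (prob 0.6) with impurity above 0.2;
// a pure node, giniImpurity(Array(1000.0, 0.0)), scores 0.0.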
val bestSplit = bestSplits(0)._1 + assert(bestSplit.feature === 0) + assert(bestSplit.categories.length === 1) + assert(bestSplit.categories.contains(1.0)) + assert(bestSplit.featureType === Categorical) + } + + } object DecisionTreeSuite { @@ -473,4 +707,47 @@ object DecisionTreeSuite { } arr } + + def generateCategoricalDataPointsForMulticlass(): Array[LabeledPoint] = { + val arr = new Array[LabeledPoint](3000) + for (i <- 0 until 3000) { + if (i < 1000) { + arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, 2.0)) + } else if (i < 2000) { + arr(i) = new LabeledPoint(1.0, Vectors.dense(1.0, 2.0)) + } else { + arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, 2.0)) + } + } + arr + } + + def generateContinuousDataPointsForMulticlass(): Array[LabeledPoint] = { + val arr = new Array[LabeledPoint](3000) + for (i <- 0 until 3000) { + if (i < 2000) { + arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, i)) + } else { + arr(i) = new LabeledPoint(1.0, Vectors.dense(2.0, i)) + } + } + arr + } + + def generateCategoricalDataPointsForMulticlassForOrderedFeatures(): + Array[LabeledPoint] = { + val arr = new Array[LabeledPoint](3000) + for (i <- 0 until 3000) { + if (i < 1000) { + arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, 2.0)) + } else if (i < 2000) { + arr(i) = new LabeledPoint(1.0, Vectors.dense(1.0, 2.0)) + } else { + arr(i) = new LabeledPoint(1.0, Vectors.dense(2.0, 2.0)) + } + } + arr + } + + } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala new file mode 100644 index 0000000000000..5e9101cdd3804 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.util + +import org.scalatest.{Suite, BeforeAndAfterAll} + +import org.apache.spark.{SparkConf, SparkContext} + +trait LocalClusterSparkContext extends BeforeAndAfterAll { self: Suite => + @transient var sc: SparkContext = _ + + override def beforeAll() { + val conf = new SparkConf() + .setMaster("local-cluster[2, 1, 512]") + .setAppName("test-cluster") + .set("spark.akka.frameSize", "1") // set to 1MB to detect direct serialization of data + sc = new SparkContext(conf) + super.beforeAll() + } + + override def afterAll() { + if (sc != null) { + sc.stop() + } + super.afterAll() + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala index 0d4868f3d9e42..7857d9e5ee5c4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala @@ -20,13 +20,16 @@ package org.apache.spark.mllib.util import org.scalatest.Suite import org.scalatest.BeforeAndAfterAll -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} trait LocalSparkContext extends BeforeAndAfterAll { self: Suite => @transient var sc: SparkContext = _ override def beforeAll() { - sc = new SparkContext("local", "test") + val conf = new SparkConf() + .setMaster("local") + .setAppName("test") + sc = new SparkContext(conf) super.beforeAll() } diff --git a/pom.xml b/pom.xml index f163b0cd1d972..972387605068c 100644 --- a/pom.xml +++ b/pom.xml @@ -303,6 +303,11 @@ snappy-java 1.0.5 + + net.jpountz.lz4 + lz4 + 1.2.0 + com.clearspring.analytics stream @@ -615,7 +620,6 @@ net.java.dev.jets3t jets3t ${jets3t.version} - runtime commons-logging @@ -959,6 +963,30 @@ org.apache.maven.plugins maven-source-plugin + + org.scalastyle + scalastyle-maven-plugin + 0.4.0 + + false + true + false + false + ${basedir}/src/main/scala + ${basedir}/src/test/scala + scalastyle-config.xml + scalastyle-output.xml + UTF-8 + + + + package + + check + + + + diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index d67c6571a0623..5ff88f0dd1cac 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -59,10 +59,16 @@ object MimaExcludes { ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.api.java.JavaDoubleRDD.countApproxDistinct$default$1"), ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.storage.MemoryStore.Entry"), + "org.apache.spark.storage.MemoryStore.Entry") + ) ++ + Seq( + // Renamed putValues -> putArray + putIterator + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.storage.MemoryStore.putValues"), ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$rdd$PairRDDFunctions$$" - + "createZero$1") + "org.apache.spark.storage.DiskStore.putValues"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.storage.TachyonStore.putValues") ) ++ Seq( ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.FlumeReceiver.this") @@ -73,7 +79,9 @@ object MimaExcludes { ProblemFilters.exclude[MissingMethodProblem]( // The only public constructor is the one without arguments. 
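The 1 MB Akka frame size set in LocalClusterSparkContext above is what turns an oversized task closure into a hard failure. To see why roughly 200,000 doubles trip it, the serialized size can be measured directly; this standalone sketch uses plain Java serialization, which differs from Spark's closure serializer in detail but not in order of magnitude.

import java.io.{ByteArrayOutputStream, ObjectOutputStream}

def serializedSizeBytes(obj: AnyRef): Int = {
  val bytes = new ByteArrayOutputStream()
  val out = new ObjectOutputStream(bytes)
  out.writeObject(obj)
  out.close()
  bytes.size()
}
// serializedSizeBytes(Array.fill(200000)(0.0)) is roughly 1.6 MB (8 bytes per element plus a
// small header), comfortably over the 1 MB frame size, so a closure capturing such an array
// makes the cluster suites above fail as intended.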
"org.apache.spark.mllib.recommendation.ALS.this"), ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$$default$7") + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$$default$7"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures") ) ++ MimaBuild.excludeSparkClass("mllib.linalg.distributed.ColumnStatisticsAggregator") ++ MimaBuild.excludeSparkClass("rdd.ZippedRDD") ++ @@ -81,7 +89,15 @@ object MimaExcludes { MimaBuild.excludeSparkClass("util.SerializableHyperLogLog") ++ MimaBuild.excludeSparkClass("storage.Values") ++ MimaBuild.excludeSparkClass("storage.Entry") ++ - MimaBuild.excludeSparkClass("storage.MemoryStore$Entry") + MimaBuild.excludeSparkClass("storage.MemoryStore$Entry") ++ + Seq( + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.mllib.tree.impurity.Gini.calculate"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.mllib.tree.impurity.Entropy.calculate"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.mllib.tree.impurity.Variance.calculate") + ) case v if v.startsWith("1.0") => Seq( MimaBuild.excludeSparkPackage("api.java"), diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 44abbc152f99f..62576f84dd031 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -19,7 +19,9 @@ import scala.util.Properties import scala.collection.JavaConversions._ import sbt._ +import sbt.Classpaths.publishTask import sbt.Keys._ +import sbtunidoc.Plugin.genjavadocSettings import org.scalastyle.sbt.ScalastylePlugin.{Settings => ScalaStyleSettings} import com.typesafe.sbt.pom.{PomBuild, SbtPomKeys} import net.virtualvoid.sbt.graph.Plugin.graphSettings @@ -103,12 +105,23 @@ object SparkBuild extends PomBuild { override val userPropertiesMap = System.getProperties.toMap - lazy val sharedSettings = graphSettings ++ ScalaStyleSettings ++ Seq ( + lazy val MavenCompile = config("m2r") extend(Compile) + lazy val publishLocalBoth = TaskKey[Unit]("publish-local", "publish local for m2 and ivy") + + lazy val sharedSettings = graphSettings ++ ScalaStyleSettings ++ genjavadocSettings ++ Seq ( javaHome := Properties.envOrNone("JAVA_HOME").map(file), incOptions := incOptions.value.withNameHashing(true), retrieveManaged := true, retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", - publishMavenStyle := true + publishMavenStyle := true, + + otherResolvers <<= SbtPomKeys.mvnLocalRepository(dotM2 => Seq(Resolver.file("dotM2", dotM2))), + publishLocalConfiguration in MavenCompile <<= (packagedArtifacts, deliverLocal, ivyLoggingLevel) map { + (arts, _, level) => new PublishConfiguration(None, "dotM2", arts, Seq(), level) + }, + publishMavenStyle in MavenCompile := true, + publishLocal in MavenCompile <<= publishTask(publishLocalConfiguration in MavenCompile, deliverLocal), + publishLocalBoth <<= Seq(publishLocal in MavenCompile, publishLocal).dependOn ) /** Following project only exists to pull previous artifacts of Spark for generating @@ -154,6 +167,9 @@ object SparkBuild extends PomBuild { /* Enable unidoc only for the root spark project */ enable(Unidoc.settings)(spark) + /* Spark SQL Core console settings */ + enable(SQL.settings)(sql) + /* Hive console settings */ enable(Hive.settings)(hive) @@ -167,6 +183,27 @@ object SparkBuild extends PomBuild { } +object 
SQL { + + lazy val settings = Seq( + + initialCommands in console := + """ + |import org.apache.spark.sql.catalyst.analysis._ + |import org.apache.spark.sql.catalyst.dsl._ + |import org.apache.spark.sql.catalyst.errors._ + |import org.apache.spark.sql.catalyst.expressions._ + |import org.apache.spark.sql.catalyst.plans.logical._ + |import org.apache.spark.sql.catalyst.rules._ + |import org.apache.spark.sql.catalyst.types._ + |import org.apache.spark.sql.catalyst.util._ + |import org.apache.spark.sql.execution + |import org.apache.spark.sql.test.TestSQLContext._ + |import org.apache.spark.sql.parquet.ParquetTestData""".stripMargin + ) + +} + object Hive { lazy val settings = Seq( diff --git a/python/epydoc.conf b/python/epydoc.conf index b73860bad8263..51c0faf359939 100644 --- a/python/epydoc.conf +++ b/python/epydoc.conf @@ -35,4 +35,4 @@ private: no exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.java_gateway pyspark.examples pyspark.shell pyspark.tests pyspark.rddsampler pyspark.daemon pyspark.mllib._common - pyspark.mllib.tests + pyspark.mllib.tests pyspark.shuffle diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 07df8697bd1a8..312c75d112cbf 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -59,4 +59,5 @@ from pyspark.storagelevel import StorageLevel -__all__ = ["SparkConf", "SparkContext", "SQLContext", "RDD", "SchemaRDD", "SparkFiles", "StorageLevel", "Row"] +__all__ = ["SparkConf", "SparkContext", "SQLContext", "RDD", "SchemaRDD", + "SparkFiles", "StorageLevel", "Row"] diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index eb5dbb8de2b39..4fda2a9b950b8 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -243,10 +243,10 @@ def save_function(self, obj, name=None, pack=struct.pack): # if func is lambda, def'ed at prompt, is in main, or is nested, then # we'll pickle the actual function object rather than simply saving a # reference (as is done in default pickler), via save_function_tuple. 
- if islambda(obj) or obj.func_code.co_filename == '' or themodule == None: + if islambda(obj) or obj.func_code.co_filename == '' or themodule is None: #Force server to import modules that have been imported in main modList = None - if themodule == None and not self.savedForceImports: + if themodule is None and not self.savedForceImports: mainmod = sys.modules['__main__'] if useForcedImports and hasattr(mainmod,'___pyc_forcedImports__'): modList = list(mainmod.___pyc_forcedImports__) diff --git a/python/pyspark/conf.py b/python/pyspark/conf.py index 8eff4a242a529..b4c82f519bd53 100644 --- a/python/pyspark/conf.py +++ b/python/pyspark/conf.py @@ -30,7 +30,7 @@ u'local' >>> sc.appName u'My app' ->>> sc.sparkHome == None +>>> sc.sparkHome is None True >>> conf = SparkConf(loadDefaults=False) @@ -50,7 +50,8 @@ spark.executorEnv.VAR4=value4 spark.home=/path >>> sorted(conf.getAll(), key=lambda p: p[0]) -[(u'spark.executorEnv.VAR1', u'value1'), (u'spark.executorEnv.VAR3', u'value3'), (u'spark.executorEnv.VAR4', u'value4'), (u'spark.home', u'/path')] +[(u'spark.executorEnv.VAR1', u'value1'), (u'spark.executorEnv.VAR3', u'value3'), \ +(u'spark.executorEnv.VAR4', u'value4'), (u'spark.home', u'/path')] """ @@ -99,6 +100,12 @@ def set(self, key, value): self._jconf.set(key, unicode(value)) return self + def setIfMissing(self, key, value): + """Set a configuration property, if not already set.""" + if self.get(key) is None: + self.set(key, value) + return self + def setMaster(self, value): """Set master URL to connect to.""" self._jconf.setMaster(value) @@ -116,11 +123,11 @@ def setSparkHome(self, value): def setExecutorEnv(self, key=None, value=None, pairs=None): """Set an environment variable to be passed to executors.""" - if (key != None and pairs != None) or (key == None and pairs == None): + if (key is not None and pairs is not None) or (key is None and pairs is None): raise Exception("Either pass one key-value pair or a list of pairs") - elif key != None: + elif key is not None: self._jconf.setExecutorEnv(key, value) - elif pairs != None: + elif pairs is not None: for (k, v) in pairs: self._jconf.setExecutorEnv(k, v) return self @@ -137,7 +144,7 @@ def setAll(self, pairs): def get(self, key, defaultValue=None): """Get the configured value for some key, or return a default otherwise.""" - if defaultValue == None: # Py4J doesn't call the right get() if we pass None + if defaultValue is None: # Py4J doesn't call the right get() if we pass None if not self._jconf.contains(key): return None return self._jconf.get(key) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 95c54e7a5ad63..830a6ee03f2a6 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -29,7 +29,7 @@ from pyspark.files import SparkFiles from pyspark.java_gateway import launch_gateway from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \ - PairDeserializer + PairDeserializer from pyspark.storagelevel import StorageLevel from pyspark import rdd from pyspark.rdd import RDD @@ -37,6 +37,15 @@ from py4j.java_collections import ListConverter +# These are special default configs for PySpark, they will overwrite +# the default ones for Spark if they are not configured by user. +DEFAULT_CONFIGS = { + "spark.serializer": "org.apache.spark.serializer.KryoSerializer", + "spark.serializer.objectStreamReset": 100, + "spark.rdd.compress": True, +} + + class SparkContext(object): """ Main entry point for Spark functionality. 
A SparkContext represents the @@ -50,12 +59,11 @@ class SparkContext(object): _next_accum_id = 0 _active_spark_context = None _lock = Lock() - _python_includes = None # zip and egg files that need to be added to PYTHONPATH - + _python_includes = None # zip and egg files that need to be added to PYTHONPATH def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, - environment=None, batchSize=1024, serializer=PickleSerializer(), conf=None, - gateway=None): + environment=None, batchSize=1024, serializer=PickleSerializer(), conf=None, + gateway=None): """ Create a new SparkContext. At least the master and app name should be set, either through the named parameters here or through C{conf}. @@ -92,7 +100,16 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, tempNamedTuple = namedtuple("Callsite", "function file linenum") self._callsite = tempNamedTuple(function=None, file=None, linenum=None) SparkContext._ensure_initialized(self, gateway=gateway) - + try: + self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer, + conf) + except: + # If an error occurs, clean up in order to allow future SparkContext creation: + self.stop() + raise + + def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, serializer, + conf): self.environment = environment or {} self._conf = conf or SparkConf(_jvm=self._jvm) self._batchSize = batchSize # -1 represents an unlimited batch size @@ -113,6 +130,8 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, if environment: for key, value in environment.iteritems(): self._conf.setExecutorEnv(key, value) + for key, value in DEFAULT_CONFIGS.items(): + self._conf.setIfMissing(key, value) # Check that we have at least the required parameters if not self._conf.contains("spark.master"): @@ -138,8 +157,8 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, self._accumulatorServer = accumulators._start_update_server() (host, port) = self._accumulatorServer.server_address self._javaAccumulator = self._jsc.accumulator( - self._jvm.java.util.ArrayList(), - self._jvm.PythonAccumulatorParam(host, port)) + self._jvm.java.util.ArrayList(), + self._jvm.PythonAccumulatorParam(host, port)) self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python') @@ -165,7 +184,7 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, (dirname, filename) = os.path.split(path) self._python_includes.append(filename) sys.path.append(path) - if not dirname in sys.path: + if dirname not in sys.path: sys.path.append(dirname) # Create a temporary directory inside spark.local.dir: @@ -192,15 +211,19 @@ def _ensure_initialized(cls, instance=None, gateway=None): SparkContext._writeToFile = SparkContext._jvm.PythonRDD.writeToFile if instance: - if SparkContext._active_spark_context and SparkContext._active_spark_context != instance: + if (SparkContext._active_spark_context and + SparkContext._active_spark_context != instance): currentMaster = SparkContext._active_spark_context.master currentAppName = SparkContext._active_spark_context.appName callsite = SparkContext._active_spark_context._callsite # Raise error if there is already a running Spark context - raise ValueError("Cannot run multiple SparkContexts at once; existing SparkContext(app=%s, master=%s)" \ - " created by %s at %s:%s " \ - % (currentAppName, currentMaster, callsite.function, callsite.file, callsite.linenum)) + raise ValueError( + "Cannot run multiple SparkContexts at once; " 
+ "existing SparkContext(app=%s, master=%s)" + " created by %s at %s:%s " + % (currentAppName, currentMaster, + callsite.function, callsite.file, callsite.linenum)) else: SparkContext._active_spark_context = instance @@ -213,6 +236,13 @@ def setSystemProperty(cls, key, value): SparkContext._ensure_initialized() SparkContext._jvm.java.lang.System.setProperty(key, value) + @property + def version(self): + """ + The version of Spark on which this application is running. + """ + return self._jsc.version() + @property def defaultParallelism(self): """ @@ -228,17 +258,14 @@ def defaultMinPartitions(self): """ return self._jsc.sc().defaultMinPartitions() - def __del__(self): - self.stop() - def stop(self): """ Shut down the SparkContext. """ - if self._jsc: + if getattr(self, "_jsc", None): self._jsc.stop() self._jsc = None - if self._accumulatorServer: + if getattr(self, "_accumulatorServer", None): self._accumulatorServer.shutdown() self._accumulatorServer = None with SparkContext._lock: @@ -290,7 +317,7 @@ def textFile(self, name, minPartitions=None): Read a text file from HDFS, a local file system (available on all nodes), or any Hadoop-supported file system URI, and return it as an RDD of Strings. - + >>> path = os.path.join(tempdir, "sample-text.txt") >>> with open(path, "w") as testFile: ... testFile.write("Hello world!") @@ -584,11 +611,12 @@ def addPyFile(self, path): HTTP, HTTPS or FTP URI. """ self.addFile(path) - (dirname, filename) = os.path.split(path) # dirname may be directory or HDFS/S3 prefix + (dirname, filename) = os.path.split(path) # dirname may be directory or HDFS/S3 prefix if filename.endswith('.zip') or filename.endswith('.ZIP') or filename.endswith('.egg'): self._python_includes.append(filename) - sys.path.append(os.path.join(SparkFiles.getRootDirectory(), filename)) # for tests in local mode + # for tests in local mode + sys.path.append(os.path.join(SparkFiles.getRootDirectory(), filename)) def setCheckpointDir(self, dirName): """ @@ -649,9 +677,9 @@ def setJobGroup(self, groupId, description, interruptOnCancel=False): Cancelled If interruptOnCancel is set to true for the job group, then job cancellation will result - in Thread.interrupt() being called on the job's executor threads. This is useful to help ensure - that the tasks are actually stopped in a timely manner, but is off by default due to HDFS-1208, - where HDFS may respond to Thread.interrupt() by marking nodes as dead. + in Thread.interrupt() being called on the job's executor threads. This is useful to help + ensure that the tasks are actually stopped in a timely manner, but is off by default due + to HDFS-1208, where HDFS may respond to Thread.interrupt() by marking nodes as dead. """ self._jsc.setJobGroup(groupId, description, interruptOnCancel) @@ -688,7 +716,7 @@ def cancelAllJobs(self): """ self._jsc.sc().cancelAllJobs() - def runJob(self, rdd, partitionFunc, partitions = None, allowLocal = False): + def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False): """ Executes the given partitionFunc on the specified set of partitions, returning the result as an array of elements. 
@@ -703,7 +731,7 @@ def runJob(self, rdd, partitionFunc, partitions = None, allowLocal = False): >>> sc.runJob(myRDD, lambda part: [x * x for x in part], [0, 2], True) [0, 1, 16, 25] """ - if partitions == None: + if partitions is None: partitions = range(rdd._jrdd.partitions().size()) javaPartitions = ListConverter().convert(partitions, self._gateway._gateway_client) @@ -714,6 +742,7 @@ def runJob(self, rdd, partitionFunc, partitions = None, allowLocal = False): it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal) return list(mappedRDD._collect_iterator_through_file(it)) + def _test(): import atexit import doctest diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index 5eb1c63bf206b..8a5873ded2b8b 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -42,12 +42,12 @@ def should_exit(): def compute_real_exit_code(exit_code): - # SystemExit's code can be integer or string, but os._exit only accepts integers - import numbers - if isinstance(exit_code, numbers.Integral): - return exit_code - else: - return 1 + # SystemExit's code can be integer or string, but os._exit only accepts integers + import numbers + if isinstance(exit_code, numbers.Integral): + return exit_code + else: + return 1 def worker(listen_sock): diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 2a17127a7e0f9..2c129679f47f3 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -24,6 +24,7 @@ from threading import Thread from py4j.java_gateway import java_import, JavaGateway, GatewayClient + def launch_gateway(): SPARK_HOME = os.environ["SPARK_HOME"] diff --git a/python/pyspark/join.py b/python/pyspark/join.py index 5f3a7e71f7866..b0f1cc1927066 100644 --- a/python/pyspark/join.py +++ b/python/pyspark/join.py @@ -33,10 +33,11 @@ from pyspark.resultiterable import ResultIterable + def _do_python_join(rdd, other, numPartitions, dispatch): vs = rdd.map(lambda (k, v): (k, (1, v))) ws = other.map(lambda (k, v): (k, (2, v))) - return vs.union(ws).groupByKey(numPartitions).flatMapValues(lambda x : dispatch(x.__iter__())) + return vs.union(ws).groupByKey(numPartitions).flatMapValues(lambda x: dispatch(x.__iter__())) def python_join(rdd, other, numPartitions): @@ -85,6 +86,7 @@ def make_mapper(i): vrdds = [rdd.map(make_mapper(i)) for i, rdd in enumerate(rdds)] union_vrdds = reduce(lambda acc, other: acc.union(other), vrdds) rdd_len = len(vrdds) + def dispatch(seq): bufs = [[] for i in range(rdd_len)] for (n, v) in seq: diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py index e609b60a0f968..8e3ad6b783b6c 100644 --- a/python/pyspark/mllib/_common.py +++ b/python/pyspark/mllib/_common.py @@ -72,9 +72,9 @@ # Python interpreter must agree on what endian the machine is. -DENSE_VECTOR_MAGIC = 1 +DENSE_VECTOR_MAGIC = 1 SPARSE_VECTOR_MAGIC = 2 -DENSE_MATRIX_MAGIC = 3 +DENSE_MATRIX_MAGIC = 3 LABELED_POINT_MAGIC = 4 @@ -97,8 +97,28 @@ def _deserialize_numpy_array(shape, ba, offset, dtype=float64): return ar.copy() +def _serialize_double(d): + """ + Serialize a double (float or numpy.float64) into a mutually understood format. + """ + if type(d) == float or type(d) == float64: + d = float64(d) + ba = bytearray(8) + _copyto(d, buffer=ba, offset=0, shape=[1], dtype=float64) + return ba + else: + raise TypeError("_serialize_double called on non-float input") + + def _serialize_double_vector(v): - """Serialize a double vector into a mutually understood format. 
+ """ + Serialize a double vector into a mutually understood format. + + Note: we currently do not use a magic byte for double for storage + efficiency. This should be reconsidered when we add Ser/De for other + 8-byte types (e.g. Long), for safety. The corresponding deserializer, + _deserialize_double, needs to be modified as well if the serialization + scheme changes. >>> x = array([1,2,3]) >>> y = _deserialize_double_vector(_serialize_double_vector(x)) @@ -148,6 +168,28 @@ def _serialize_sparse_vector(v): return ba +def _deserialize_double(ba, offset=0): + """Deserialize a double from a mutually understood format. + + >>> import sys + >>> _deserialize_double(_serialize_double(123.0)) == 123.0 + True + >>> _deserialize_double(_serialize_double(float64(0.0))) == 0.0 + True + >>> x = sys.float_info.max + >>> _deserialize_double(_serialize_double(sys.float_info.max)) == x + True + >>> y = float64(sys.float_info.max) + >>> _deserialize_double(_serialize_double(sys.float_info.max)) == y + True + """ + if type(ba) != bytearray: + raise TypeError("_deserialize_double called on a %s; wanted bytearray" % type(ba)) + if len(ba) - offset != 8: + raise TypeError("_deserialize_double called on a %d-byte array; wanted 8 bytes." % nb) + return struct.unpack("d", ba[offset:])[0] + + def _deserialize_double_vector(ba, offset=0): """Deserialize a double vector from a mutually understood format. @@ -164,7 +206,7 @@ def _deserialize_double_vector(ba, offset=0): nb = len(ba) - offset if nb < 5: raise TypeError("_deserialize_double_vector called on a %d-byte array, " - "which is too short" % nb) + "which is too short" % nb) if ba[offset] == DENSE_VECTOR_MAGIC: return _deserialize_dense_vector(ba, offset) elif ba[offset] == SPARSE_VECTOR_MAGIC: @@ -272,6 +314,7 @@ def _serialize_labeled_point(p): header_float[0] = p.label return header + serialized_features + def _deserialize_labeled_point(ba, offset=0): """Deserialize a LabeledPoint from a mutually understood format.""" from pyspark.mllib.regression import LabeledPoint @@ -283,6 +326,7 @@ def _deserialize_labeled_point(ba, offset=0): features = _deserialize_double_vector(ba, offset + 9) return LabeledPoint(label, features) + def _copyto(array, buffer, offset, shape, dtype): """ Copy the contents of a vector to a destination bytearray at the diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 1c0c536c4fb3d..9e28dfbb9145d 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -63,7 +63,10 @@ class LogisticRegressionModel(LinearModel): def predict(self, x): _linear_predictor_typecheck(x, self._coeff) margin = _dot(x, self._coeff) + self._intercept - prob = 1/(1 + exp(-margin)) + if margin > 0: + prob = 1 / (1 + exp(-margin)) + else: + prob = 1 - 1 / (1 + exp(margin)) return 1 if prob > 0.5 else 0 diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index db39ed0acdb66..71f4ad1a8d44e 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -247,6 +247,7 @@ def stringify(vector): else: return "[" + ",".join([str(v) for v in vector]) + "]" + def _test(): import doctest (failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS) diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index e24c144f458bd..a707a9dcd5b49 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -24,7 +24,6 @@ from pyspark.serializers import NoOpSerializer - class MLUtils: """ Helper methods 
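The LogisticRegressionModel.predict change above guards against overflow: for a large negative margin, exp(-margin) overflows (in Python, math.exp raises OverflowError), so branching on the sign of the margin keeps the exponent non-positive. The same algebra rendered as a standalone Scala sketch for illustration.

def stableSigmoid(margin: Double): Double = {
  if (margin > 0) {
    1.0 / (1.0 + math.exp(-margin))
  } else {
    // equal to 1 - 1/(1 + exp(margin)); exp never sees a positive argument here
    val e = math.exp(margin)
    e / (1.0 + e)
  }
}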
to load, save and pre-process data used in MLlib. @@ -154,7 +153,6 @@ def saveAsLibSVMFile(data, dir): lines = data.map(lambda p: MLUtils._convert_labeled_point_to_libsvm(p)) lines.saveAsTextFile(dir) - @staticmethod def loadLabeledPoints(sc, path, minPartitions=None): """ diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 0c35c666805dd..b84d976114f0d 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -39,15 +39,46 @@ from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_cogroup from pyspark.statcounter import StatCounter -from pyspark.rddsampler import RDDSampler +from pyspark.rddsampler import RDDSampler, RDDStratifiedSampler from pyspark.storagelevel import StorageLevel from pyspark.resultiterable import ResultIterable +from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \ + get_used_memory from py4j.java_collections import ListConverter, MapConverter __all__ = ["RDD"] +# TODO: for Python 3.3+, PYTHONHASHSEED should be reset to disable randomized +# hash for string +def portable_hash(x): + """ + This function returns consistant hash code for builtin types, especially + for None and tuple with None. + + The algrithm is similar to that one used by CPython 2.7 + + >>> portable_hash(None) + 0 + >>> portable_hash((None, 1)) + 219750521 + """ + if x is None: + return 0 + if isinstance(x, tuple): + h = 0x345678 + for i in x: + h ^= portable_hash(i) + h *= 1000003 + h &= 0xffffffff + h ^= len(x) + if h == -1: + h = -2 + return h + return hash(x) + + def _extract_concise_traceback(): """ This function returns the traceback info for a callsite, returns a dict @@ -168,6 +199,22 @@ def _replaceRoot(self, value): self._sink(1) +def _parse_memory(s): + """ + Parse a memory string in the format supported by Java (e.g. 1g, 200m) and + return the value in MB + + >>> _parse_memory("256m") + 256 + >>> _parse_memory("2g") + 2048 + """ + units = {'g': 1024, 'm': 1, 't': 1 << 20, 'k': 1.0 / 1024} + if s[-1] not in units: + raise ValueError("invalid format: " + s) + return int(float(s[:-1]) * units[s[-1].lower()]) + + class RDD(object): """ @@ -202,10 +249,10 @@ def context(self): def cache(self): """ - Persist this RDD with the default storage level (C{MEMORY_ONLY}). + Persist this RDD with the default storage level (C{MEMORY_ONLY_SER}). """ self.is_cached = True - self._jrdd.cache() + self.persist(StorageLevel.MEMORY_ONLY_SER) return self def persist(self, storageLevel): @@ -364,7 +411,7 @@ def sample(self, withReplacement, fraction, seed=None): >>> sc.parallelize(range(0, 100)).sample(False, 0.1, 2).collect() #doctest: +SKIP [2, 3, 20, 21, 24, 41, 42, 66, 67, 89, 90, 98] """ - assert fraction >= 0.0, "Invalid fraction value: %s" % fraction + assert fraction >= 0.0, "Negative fraction value: %s" % fraction return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True) # this is ported from scala/spark/RDD.scala @@ -1164,7 +1211,9 @@ def rightOuterJoin(self, other, numPartitions=None): return python_right_outer_join(self, other, numPartitions) # TODO: add option to control map-side combining - def partitionBy(self, numPartitions, partitionFunc=None): + # portable_hash is used as default, because builtin hash of None is different + # cross machines. + def partitionBy(self, numPartitions, partitionFunc=portable_hash): """ Return a copy of the RDD partitioned using the specified partitioner. 
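The _parse_memory helper added above converts Java-style memory strings into megabytes so that spark.python.worker.memory can bound the shuffle buffers used further down in partitionBy and combineByKey. A Scala rendering of the same conversion, for illustration; parseMemoryToMb is a hypothetical name.

def parseMemoryToMb(s: String): Int = {
  val units = Map('g' -> 1024.0, 'm' -> 1.0, 't' -> 1024.0 * 1024.0, 'k' -> 1.0 / 1024.0)
  val unit = s.last.toLower
  require(units.contains(unit), s"invalid format: $s")
  (s.dropRight(1).toDouble * units(unit)).toInt
}
// parseMemoryToMb("512m") == 512 and parseMemoryToMb("2g") == 2048, matching the doctest above.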
@@ -1176,22 +1225,49 @@ def partitionBy(self, numPartitions, partitionFunc=None): if numPartitions is None: numPartitions = self._defaultReducePartitions() - if partitionFunc is None: - partitionFunc = lambda x: 0 if x is None else hash(x) - # Transferring O(n) objects to Java is too expensive. Instead, we'll - # form the hash buckets in Python, transferring O(numPartitions) objects - # to Java. Each object is a (splitNumber, [objects]) pair. + # Transferring O(n) objects to Java is too expensive. + # Instead, we'll form the hash buckets in Python, + # transferring O(numPartitions) objects to Java. + # Each object is a (splitNumber, [objects]) pair. + # In order to avoid too huge objects, the objects are + # grouped into chunks. outputSerializer = self.ctx._unbatched_serializer + limit = (_parse_memory(self.ctx._conf.get( + "spark.python.worker.memory", "512m")) / 2) + def add_shuffle_key(split, iterator): buckets = defaultdict(list) + c, batch = 0, min(10 * numPartitions, 1000) for (k, v) in iterator: buckets[partitionFunc(k) % numPartitions].append((k, v)) + c += 1 + + # check used memory and avg size of chunk of objects + if (c % 1000 == 0 and get_used_memory() > limit + or c > batch): + n, size = len(buckets), 0 + for split in buckets.keys(): + yield pack_long(split) + d = outputSerializer.dumps(buckets[split]) + del buckets[split] + yield d + size += len(d) + + avg = (size / n) >> 20 + # let 1M < avg < 10M + if avg < 1: + batch *= 1.5 + elif avg > 10: + batch = max(batch / 1.5, 1) + c = 0 + for (split, items) in buckets.iteritems(): yield pack_long(split) yield outputSerializer.dumps(items) + keyed = PipelinedRDD(self, add_shuffle_key) keyed._bypass_serializer = True with _JavaStackTrace(self.context) as st: @@ -1201,8 +1277,8 @@ def add_shuffle_key(split, iterator): id(partitionFunc)) jrdd = pairRDD.partitionBy(partitioner).values() rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer)) - # This is required so that id(partitionFunc) remains unique, even if - # partitionFunc is a lambda: + # This is required so that id(partitionFunc) remains unique, + # even if partitionFunc is a lambda: rdd._partitionFunc = partitionFunc return rdd @@ -1236,26 +1312,28 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners, if numPartitions is None: numPartitions = self._defaultReducePartitions() + serializer = self.ctx.serializer + spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() + == 'true') + memory = _parse_memory(self.ctx._conf.get( + "spark.python.worker.memory", "512m")) + agg = Aggregator(createCombiner, mergeValue, mergeCombiners) + def combineLocally(iterator): - combiners = {} - for x in iterator: - (k, v) = x - if k not in combiners: - combiners[k] = createCombiner(v) - else: - combiners[k] = mergeValue(combiners[k], v) - return combiners.iteritems() + merger = ExternalMerger(agg, memory * 0.9, serializer) \ + if spill else InMemoryMerger(agg) + merger.mergeValues(iterator) + return merger.iteritems() + locally_combined = self.mapPartitions(combineLocally) shuffled = locally_combined.partitionBy(numPartitions) def _mergeCombiners(iterator): - combiners = {} - for (k, v) in iterator: - if k not in combiners: - combiners[k] = v - else: - combiners[k] = mergeCombiners(combiners[k], v) - return combiners.iteritems() + merger = ExternalMerger(agg, memory, serializer) \ + if spill else InMemoryMerger(agg) + merger.mergeCombiners(iterator) + return merger.iteritems() + return shuffled.mapPartitions(_mergeCombiners) def aggregateByKey(self, zeroValue, 
seqFunc, combFunc, numPartitions=None): @@ -1314,7 +1392,8 @@ def mergeValue(xs, x): return xs def mergeCombiners(a, b): - return a + b + a.extend(b) + return a return self.combineByKey(createCombiner, mergeValue, mergeCombiners, numPartitions).mapValues(lambda x: ResultIterable(x)) @@ -1377,6 +1456,27 @@ def cogroup(self, other, numPartitions=None): """ return python_cogroup((self, other), numPartitions) + def sampleByKey(self, withReplacement, fractions, seed=None): + """ + Return a subset of this RDD sampled by key (via stratified sampling). + Create a sample of this RDD using variable sampling rates for + different keys as specified by fractions, a key to sampling rate map. + + >>> fractions = {"a": 0.2, "b": 0.1} + >>> rdd = sc.parallelize(fractions.keys()).cartesian(sc.parallelize(range(0, 1000))) + >>> sample = dict(rdd.sampleByKey(False, fractions, 2).groupByKey().collect()) + >>> 100 < len(sample["a"]) < 300 and 50 < len(sample["b"]) < 150 + True + >>> max(sample["a"]) <= 999 and min(sample["a"]) >= 0 + True + >>> max(sample["b"]) <= 999 and min(sample["b"]) >= 0 + True + """ + for fraction in fractions.values(): + assert fraction >= 0.0, "Negative fraction value: %s" % fraction + return self.mapPartitionsWithIndex( \ + RDDStratifiedSampler(withReplacement, fractions, seed).func, True) + def subtractByKey(self, other, numPartitions=None): """ Return each (key, value) pair in C{self} that has no pair with matching @@ -1587,7 +1687,6 @@ def _jrdd(self): [x._jbroadcast for x in self.ctx._pickled_broadcast_vars], self.ctx._gateway._gateway_client) self.ctx._pickled_broadcast_vars.clear() - class_tag = self._prev_jrdd.classTag() env = MapConverter().convert(self.ctx.environment, self.ctx._gateway._gateway_client) includes = ListConverter().convert(self.ctx._python_includes, @@ -1596,8 +1695,7 @@ def _jrdd(self): bytearray(pickled_command), env, includes, self.preservesPartitioning, self.ctx.pythonExec, - broadcast_vars, self.ctx._javaAccumulator, - class_tag) + broadcast_vars, self.ctx._javaAccumulator) self._jrdd_val = python_rdd.asJavaRDD() return self._jrdd_val diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py index 845a267e311c5..2df000fdb08ca 100644 --- a/python/pyspark/rddsampler.py +++ b/python/pyspark/rddsampler.py @@ -18,18 +18,20 @@ import sys import random -class RDDSampler(object): - def __init__(self, withReplacement, fraction, seed=None): + +class RDDSamplerBase(object): + def __init__(self, withReplacement, seed=None): try: import numpy self._use_numpy = True except ImportError: - print >> sys.stderr, "NumPy does not appear to be installed. Falling back to default random generator for sampling." + print >> sys.stderr, ( + "NumPy does not appear to be installed. 
" + "Falling back to default random generator for sampling.") self._use_numpy = False self._seed = seed if seed is not None else random.randint(0, sys.maxint) self._withReplacement = withReplacement - self._fraction = fraction self._random = None self._split = None self._rand_initialized = False @@ -61,7 +63,7 @@ def getUniformSample(self, split): def getPoissonSample(self, split, mean): if not self._rand_initialized or split != self._split: self.initRandomGenerator(split) - + if self._use_numpy: return self._random.poisson(mean) else: @@ -80,23 +82,30 @@ def getPoissonSample(self, split, mean): num_arrivals += 1 return (num_arrivals - 1) - + def shuffle(self, vals): - if self._random == None: + if self._random is None: self.initRandomGenerator(0) # this should only ever called on the master so # the split does not matter - + if self._use_numpy: self._random.shuffle(vals) else: self._random.shuffle(vals, self._random.random) + +class RDDSampler(RDDSamplerBase): + def __init__(self, withReplacement, fraction, seed=None): + RDDSamplerBase.__init__(self, withReplacement, seed) + self._fraction = fraction + def func(self, split, iterator): - if self._withReplacement: + if self._withReplacement: for obj in iterator: - # For large datasets, the expected number of occurrences of each element in a sample with - # replacement is Poisson(frac). We use that to get a count for each element. - count = self.getPoissonSample(split, mean = self._fraction) + # For large datasets, the expected number of occurrences of each element in + # a sample with replacement is Poisson(frac). We use that to get a count for + # each element. + count = self.getPoissonSample(split, mean=self._fraction) for _ in range(0, count): yield obj else: @@ -104,6 +113,21 @@ def func(self, split, iterator): if self.getUniformSample(split) <= self._fraction: yield obj - - - +class RDDStratifiedSampler(RDDSamplerBase): + def __init__(self, withReplacement, fractions, seed=None): + RDDSamplerBase.__init__(self, withReplacement, seed) + self._fractions = fractions + + def func(self, split, iterator): + if self._withReplacement: + for key, val in iterator: + # For large datasets, the expected number of occurrences of each element in + # a sample with replacement is Poisson(frac). We use that to get a count for + # each element. + count = self.getPoissonSample(split, mean=self._fractions[key]) + for _ in range(0, count): + yield key, val + else: + for key, val in iterator: + if self.getUniformSample(split) <= self._fractions[key]: + yield key, val diff --git a/python/pyspark/resultiterable.py b/python/pyspark/resultiterable.py index 7f418f8d2e29a..df34740fc8176 100644 --- a/python/pyspark/resultiterable.py +++ b/python/pyspark/resultiterable.py @@ -19,6 +19,7 @@ import collections + class ResultIterable(collections.Iterable): """ A special result iterable. 
This is used because the standard iterator can not be pickled @@ -27,7 +28,9 @@ def __init__(self, data): self.data = data self.index = 0 self.maxindex = len(data) + def __iter__(self): return iter(self.data) + def __len__(self): return len(self.data) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index b253807974a2e..03b31ae9624c2 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -91,7 +91,6 @@ def load_stream(self, stream): """ raise NotImplementedError - def _load_stream_without_unbatching(self, stream): return self.load_stream(stream) @@ -194,11 +193,11 @@ def load_stream(self, stream): return chain.from_iterable(self._load_stream_without_unbatching(stream)) def _load_stream_without_unbatching(self, stream): - return self.serializer.load_stream(stream) + return self.serializer.load_stream(stream) def __eq__(self, other): - return isinstance(other, BatchedSerializer) and \ - other.serializer == self.serializer + return (isinstance(other, BatchedSerializer) and + other.serializer == self.serializer) def __str__(self): return "BatchedSerializer<%s>" % str(self.serializer) @@ -229,8 +228,8 @@ def load_stream(self, stream): yield pair def __eq__(self, other): - return isinstance(other, CartesianDeserializer) and \ - self.key_ser == other.key_ser and self.val_ser == other.val_ser + return (isinstance(other, CartesianDeserializer) and + self.key_ser == other.key_ser and self.val_ser == other.val_ser) def __str__(self): return "CartesianDeserializer<%s, %s>" % \ @@ -252,18 +251,20 @@ def load_stream(self, stream): yield pair def __eq__(self, other): - return isinstance(other, PairDeserializer) and \ - self.key_ser == other.key_ser and self.val_ser == other.val_ser + return (isinstance(other, PairDeserializer) and + self.key_ser == other.key_ser and self.val_ser == other.val_ser) def __str__(self): - return "PairDeserializer<%s, %s>" % \ - (str(self.key_ser), str(self.val_ser)) + return "PairDeserializer<%s, %s>" % (str(self.key_ser), str(self.val_ser)) class NoOpSerializer(FramedSerializer): - def loads(self, obj): return obj - def dumps(self, obj): return obj + def loads(self, obj): + return obj + + def dumps(self, obj): + return obj class PickleSerializer(FramedSerializer): @@ -276,12 +277,16 @@ class PickleSerializer(FramedSerializer): not be as fast as more specialized serializers. """ - def dumps(self, obj): return cPickle.dumps(obj, 2) + def dumps(self, obj): + return cPickle.dumps(obj, 2) + loads = cPickle.loads + class CloudPickleSerializer(PickleSerializer): - def dumps(self, obj): return cloudpickle.dumps(obj, 2) + def dumps(self, obj): + return cloudpickle.dumps(obj, 2) class MarshalSerializer(FramedSerializer): @@ -297,6 +302,33 @@ class MarshalSerializer(FramedSerializer): loads = marshal.loads +class AutoSerializer(FramedSerializer): + """ + Choose marshal or cPickle as serialization protocol autumatically + """ + def __init__(self): + FramedSerializer.__init__(self) + self._type = None + + def dumps(self, obj): + if self._type is not None: + return 'P' + cPickle.dumps(obj, -1) + try: + return 'M' + marshal.dumps(obj) + except Exception: + self._type = 'P' + return 'P' + cPickle.dumps(obj, -1) + + def loads(self, obj): + _type = obj[0] + if _type == 'M': + return marshal.loads(obj[1:]) + elif _type == 'P': + return cPickle.loads(obj[1:]) + else: + raise ValueError("invalid sevialization type: %s" % _type) + + class UTF8Deserializer(Serializer): """ Deserializes streams written by String.getBytes. 
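The AutoSerializer added above prefixes every serialized record with a one-byte tag ('M' for marshal, 'P' for pickle) and permanently switches to pickle once marshal fails, so a stream that mixes both encodings can still be decoded record by record. A rough standalone sketch of that tagging scheme (plain Python standard library; class and method names here are illustrative, not the pyspark API):

    import marshal
    import pickle

    class TaggedAutoSerializer(object):
        """Prefer marshal (fast, limited types); fall back to pickle and remember it."""

        def __init__(self):
            self._force_pickle = False

        def dumps(self, obj):
            if not self._force_pickle:
                try:
                    return b'M' + marshal.dumps(obj)
                except Exception:
                    self._force_pickle = True  # e.g. instances of user classes
            return b'P' + pickle.dumps(obj, 2)

        def loads(self, data):
            tag, body = data[:1], data[1:]
            if tag == b'M':
                return marshal.loads(body)
            if tag == b'P':
                return pickle.loads(body)
            raise ValueError("unknown serialization tag: %r" % tag)

    ser = TaggedAutoSerializer()
    assert ser.loads(ser.dumps({"a": [1, 2, 3]})) == {"a": [1, 2, 3]}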
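Further up in this patch, the with-replacement branch of RDDSampler and the new RDDStratifiedSampler draw, for each element, a count from a Poisson distribution whose mean is the sampling fraction, instead of materialising the whole dataset, because that is approximately how often an element occurs in a large sample taken with replacement. A tiny standalone sketch of that per-element trick (plain Python; the Poisson draw uses Knuth's method since the standard library has no Poisson sampler, whereas the patch uses numpy when available):

    import math
    import random

    def poisson(mean, rng=random):
        # Knuth's algorithm; adequate for the small means used as sampling fractions.
        threshold = math.exp(-mean)
        k, p = 0, 1.0
        while True:
            p *= rng.random()
            if p <= threshold:
                return k
            k += 1

    def sample_with_replacement(iterator, fraction, rng=random):
        for obj in iterator:
            for _ in range(poisson(fraction, rng)):
                yield obj

    random.seed(42)
    sampled = list(sample_with_replacement(range(100000), 0.01))
    print(len(sampled))  # close to 100000 * 0.01 = 1000 in expectation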
diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index ebd714db7a918..e1e7cd954189f 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -35,7 +35,8 @@ from pyspark.storagelevel import StorageLevel # this is the equivalent of ADD_JARS -add_files = os.environ.get("ADD_FILES").split(',') if os.environ.get("ADD_FILES") != None else None +add_files = (os.environ.get("ADD_FILES").split(',') + if os.environ.get("ADD_FILES") is not None else None) if os.environ.get("SPARK_EXECUTOR_URI"): SparkContext.setSystemProperty("spark.executor.uri", os.environ["SPARK_EXECUTOR_URI"]) @@ -55,7 +56,7 @@ platform.python_build()[1])) print("SparkContext available as sc.") -if add_files != None: +if add_files is not None: print("Adding files: [%s]" % ", ".join(add_files)) # The ./bin/pyspark script stores the old PYTHONSTARTUP value in OLD_PYTHONSTARTUP, diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py new file mode 100644 index 0000000000000..e3923d1c36c57 --- /dev/null +++ b/python/pyspark/shuffle.py @@ -0,0 +1,439 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import sys +import platform +import shutil +import warnings +import gc + +from pyspark.serializers import BatchedSerializer, PickleSerializer + +try: + import psutil + + def get_used_memory(): + """ Return the used memory in MB """ + process = psutil.Process(os.getpid()) + if hasattr(process, "memory_info"): + info = process.memory_info() + else: + info = process.get_memory_info() + return info.rss >> 20 +except ImportError: + + def get_used_memory(): + """ Return the used memory in MB """ + if platform.system() == 'Linux': + for line in open('/proc/self/status'): + if line.startswith('VmRSS:'): + return int(line.split()[1]) >> 10 + else: + warnings.warn("Please install psutil to have better " + "support with spilling") + if platform.system() == "Darwin": + import resource + rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + return rss >> 20 + # TODO: support windows + return 0 + + +class Aggregator(object): + + """ + Aggregator has tree functions to merge values into combiner. 
+ + createCombiner: (value) -> combiner + mergeValue: (combiner, value) -> combiner + mergeCombiners: (combiner, combiner) -> combiner + """ + + def __init__(self, createCombiner, mergeValue, mergeCombiners): + self.createCombiner = createCombiner + self.mergeValue = mergeValue + self.mergeCombiners = mergeCombiners + + + class SimpleAggregator(Aggregator): + + """ + SimpleAggregator is useful for the cases where the combiners + have the same type as the values + """ + + def __init__(self, combiner): + Aggregator.__init__(self, lambda x: x, combiner, combiner) + + + class Merger(object): + + """ + Merge shuffled data together using an aggregator + """ + + def __init__(self, aggregator): + self.agg = aggregator + + def mergeValues(self, iterator): + """ Combine the items by creator and combiner """ + raise NotImplementedError + + def mergeCombiners(self, iterator): + """ Merge the combined items by mergeCombiners """ + raise NotImplementedError + + def iteritems(self): + """ Return the merged items as an iterator """ + raise NotImplementedError + + + class InMemoryMerger(Merger): + + """ + In memory merger based on in-memory dict. + """ + + def __init__(self, aggregator): + Merger.__init__(self, aggregator) + self.data = {} + + def mergeValues(self, iterator): + """ Combine the items by creator and combiner """ + # speed up attributes lookup + d, creator = self.data, self.agg.createCombiner + comb = self.agg.mergeValue + for k, v in iterator: + d[k] = comb(d[k], v) if k in d else creator(v) + + def mergeCombiners(self, iterator): + """ Merge the combined items by mergeCombiners """ + # speed up attributes lookup + d, comb = self.data, self.agg.mergeCombiners + for k, v in iterator: + d[k] = comb(d[k], v) if k in d else v + + def iteritems(self): + """ Return the merged items as an iterator """ + return self.data.iteritems() + + + class ExternalMerger(Merger): + + """ + ExternalMerger dumps the aggregated data to disk when memory + usage goes above the limit, then merges the spilled data + together. + + This class works as follows: + + - It repeatedly combines the items and saves them in one dict + in memory. + + - When the used memory goes above the memory limit, it splits + the combined data into partitions by hash code and dumps them + to disk, one file per partition. + + - Then it goes through the rest of the iterator, combining items + into different dicts by hash. When the used memory goes over + the memory limit again, it dumps all the dicts to disk, one + file per dict, and repeats this until all the items have been + combined. + + - Before returning any items, it loads each partition and + combines it separately, yielding the items of one partition + before loading the next. + + - While loading a partition, if the memory goes over the limit, + it partitions the loaded data, dumps it to disk, and loads it + again partition by partition. + + `data` and `pdata` are used to hold the merged items in memory. + At first, all the items are merged into `data`. Once the used + memory goes over the limit, the items in `data` are dumped to + disk and `data` is cleared; the remaining items are then merged + into `pdata` and later dumped to disk as well. Before returning, + any items still in `pdata` are dumped to disk. + + Finally, if any items were spilled to disk, each partition is + merged into `data`, yielded, and then cleared.
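The spill-and-merge scheme described above can be boiled down to a short standalone sketch: combine into an in-memory dict, hash-partition the dict out to per-partition files whenever a threshold is crossed, and finally merge the spilled files back one partition at a time so only a fraction of the keys is ever resident. The sketch below is illustrative only (plain Python; a fixed item-count threshold stands in for the real get_used_memory() check, and the names are made up):

    import os
    import pickle
    import tempfile
    from collections import defaultdict

    def external_reduce_by_key(items, combine, partitions=4, spill_threshold=50):
        workdir = tempfile.mkdtemp()
        data = {}                        # in-memory combiners
        spill_files = defaultdict(list)  # partition id -> spill file paths
        spills = 0

        for k, v in items:
            data[k] = combine(data[k], v) if k in data else v
            if len(data) >= spill_threshold:
                # Spill: one file per hash partition, then start over empty.
                for p in range(partitions):
                    chunk = {key: val for key, val in data.items()
                             if hash(key) % partitions == p}
                    path = os.path.join(workdir, "spill-%d-part-%d" % (spills, p))
                    with open(path, "wb") as f:
                        pickle.dump(chunk, f)
                    spill_files[p].append(path)
                data.clear()
                spills += 1

        if not spills:
            return data

        # Merge spilled files partition by partition, so only one partition's
        # worth of keys is held in memory at a time.
        merged = {}
        for p in range(partitions):
            partial = {key: val for key, val in data.items()
                       if hash(key) % partitions == p}
            for path in spill_files[p]:
                with open(path, "rb") as f:
                    for key, val in pickle.load(f).items():
                        partial[key] = combine(partial[key], val) if key in partial else val
            merged.update(partial)
        return merged

    counts = external_reduce_by_key(((i % 100, 1) for i in range(1000)), lambda a, b: a + b)
    assert counts[0] == 10  # key 0 appears 10 times in 0..999 stepping by 100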
+ + >>> agg = SimpleAggregator(lambda x, y: x + y) + >>> merger = ExternalMerger(agg, 10) + >>> N = 10000 + >>> merger.mergeValues(zip(xrange(N), xrange(N)) * 10) + >>> assert merger.spills > 0 + >>> sum(v for k,v in merger.iteritems()) + 499950000 + + >>> merger = ExternalMerger(agg, 10) + >>> merger.mergeCombiners(zip(xrange(N), xrange(N)) * 10) + >>> assert merger.spills > 0 + >>> sum(v for k,v in merger.iteritems()) + 499950000 + """ + + # the max total partitions created recursively + MAX_TOTAL_PARTITIONS = 4096 + + def __init__(self, aggregator, memory_limit=512, serializer=None, + localdirs=None, scale=1, partitions=59, batch=1000): + Merger.__init__(self, aggregator) + self.memory_limit = memory_limit + # default serializer is only used for tests + self.serializer = serializer or \ + BatchedSerializer(PickleSerializer(), 1024) + self.localdirs = localdirs or self._get_dirs() + # number of partitions when spill data into disks + self.partitions = partitions + # check the memory after # of items merged + self.batch = batch + # scale is used to scale down the hash of key for recursive hash map + self.scale = scale + # unpartitioned merged data + self.data = {} + # partitioned merged data, list of dicts + self.pdata = [] + # number of chunks dumped into disks + self.spills = 0 + # randomize the hash of key, id(o) is the address of o (aligned by 8) + self._seed = id(self) + 7 + + def _get_dirs(self): + """ Get all the directories """ + path = os.environ.get("SPARK_LOCAL_DIR", "/tmp") + dirs = path.split(",") + return [os.path.join(d, "python", str(os.getpid()), str(id(self))) + for d in dirs] + + def _get_spill_dir(self, n): + """ Choose one directory for spill by number n """ + return os.path.join(self.localdirs[n % len(self.localdirs)], str(n)) + + def _next_limit(self): + """ + Return the next memory limit. If the memory is not released + after spilling, it will dump the data only when the used memory + starts to increase. 
+ """ + return max(self.memory_limit, get_used_memory() * 1.05) + + def mergeValues(self, iterator): + """ Combine the items by creator and combiner """ + iterator = iter(iterator) + # speedup attribute lookup + creator, comb = self.agg.createCombiner, self.agg.mergeValue + d, c, batch = self.data, 0, self.batch + + for k, v in iterator: + d[k] = comb(d[k], v) if k in d else creator(v) + + c += 1 + if c % batch == 0 and get_used_memory() > self.memory_limit: + self._spill() + self._partitioned_mergeValues(iterator, self._next_limit()) + break + + def _partition(self, key): + """ Return the partition for key """ + return hash((key, self._seed)) % self.partitions + + def _partitioned_mergeValues(self, iterator, limit=0): + """ Partition the items by key, then combine them """ + # speedup attribute lookup + creator, comb = self.agg.createCombiner, self.agg.mergeValue + c, pdata, hfun, batch = 0, self.pdata, self._partition, self.batch + + for k, v in iterator: + d = pdata[hfun(k)] + d[k] = comb(d[k], v) if k in d else creator(v) + if not limit: + continue + + c += 1 + if c % batch == 0 and get_used_memory() > limit: + self._spill() + limit = self._next_limit() + + def mergeCombiners(self, iterator, check=True): + """ Merge (K,V) pair by mergeCombiner """ + iterator = iter(iterator) + # speedup attribute lookup + d, comb, batch = self.data, self.agg.mergeCombiners, self.batch + c = 0 + for k, v in iterator: + d[k] = comb(d[k], v) if k in d else v + if not check: + continue + + c += 1 + if c % batch == 0 and get_used_memory() > self.memory_limit: + self._spill() + self._partitioned_mergeCombiners(iterator, self._next_limit()) + break + + def _partitioned_mergeCombiners(self, iterator, limit=0): + """ Partition the items by key, then merge them """ + comb, pdata = self.agg.mergeCombiners, self.pdata + c, hfun = 0, self._partition + for k, v in iterator: + d = pdata[hfun(k)] + d[k] = comb(d[k], v) if k in d else v + if not limit: + continue + + c += 1 + if c % self.batch == 0 and get_used_memory() > limit: + self._spill() + limit = self._next_limit() + + def _spill(self): + """ + dump already partitioned data into disks. + + It will dump the data in batch for better performance. + """ + path = self._get_spill_dir(self.spills) + if not os.path.exists(path): + os.makedirs(path) + + if not self.pdata: + # The data has not been partitioned, it will iterator the + # dataset once, write them into different files, has no + # additional memory. It only called when the memory goes + # above limit at the first time. 
+ + # open all the files for writing + streams = [open(os.path.join(path, str(i)), 'w') + for i in range(self.partitions)] + + for k, v in self.data.iteritems(): + h = self._partition(k) + # put one item in batch, make it compatitable with load_stream + # it will increase the memory if dump them in batch + self.serializer.dump_stream([(k, v)], streams[h]) + + for s in streams: + s.close() + + self.data.clear() + self.pdata = [{} for i in range(self.partitions)] + + else: + for i in range(self.partitions): + p = os.path.join(path, str(i)) + with open(p, "w") as f: + # dump items in batch + self.serializer.dump_stream(self.pdata[i].iteritems(), f) + self.pdata[i].clear() + + self.spills += 1 + gc.collect() # release the memory as much as possible + + def iteritems(self): + """ Return all merged items as iterator """ + if not self.pdata and not self.spills: + return self.data.iteritems() + return self._external_items() + + def _external_items(self): + """ Return all partitioned items as iterator """ + assert not self.data + if any(self.pdata): + self._spill() + hard_limit = self._next_limit() + + try: + for i in range(self.partitions): + self.data = {} + for j in range(self.spills): + path = self._get_spill_dir(j) + p = os.path.join(path, str(i)) + # do not check memory during merging + self.mergeCombiners(self.serializer.load_stream(open(p)), + False) + + # limit the total partitions + if (self.scale * self.partitions < self.MAX_TOTAL_PARTITIONS + and j < self.spills - 1 + and get_used_memory() > hard_limit): + self.data.clear() # will read from disk again + gc.collect() # release the memory as much as possible + for v in self._recursive_merged_items(i): + yield v + return + + for v in self.data.iteritems(): + yield v + self.data.clear() + gc.collect() + + # remove the merged partition + for j in range(self.spills): + path = self._get_spill_dir(j) + os.remove(os.path.join(path, str(i))) + + finally: + self._cleanup() + + def _cleanup(self): + """ Clean up all the files in disks """ + for d in self.localdirs: + shutil.rmtree(d, True) + + def _recursive_merged_items(self, start): + """ + merge the partitioned items and return the as iterator + + If one partition can not be fit in memory, then them will be + partitioned and merged recursively. + """ + # make sure all the data are dumps into disks. + assert not self.data + if any(self.pdata): + self._spill() + assert self.spills > 0 + + for i in range(start, self.partitions): + subdirs = [os.path.join(d, "parts", str(i)) + for d in self.localdirs] + m = ExternalMerger(self.agg, self.memory_limit, self.serializer, + subdirs, self.scale * self.partitions) + m.pdata = [{} for _ in range(self.partitions)] + limit = self._next_limit() + + for j in range(self.spills): + path = self._get_spill_dir(j) + p = os.path.join(path, str(i)) + m._partitioned_mergeCombiners( + self.serializer.load_stream(open(p))) + + if get_used_memory() > limit: + m._spill() + limit = self._next_limit() + + for v in m._external_items(): + yield v + + # remove the merged partition + for j in range(self.spills): + path = self._get_spill_dir(j) + os.remove(os.path.join(path, str(i))) + + +if __name__ == "__main__": + import doctest + doctest.testmod() diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index ffe177576f363..cb83e89176823 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -30,7 +30,7 @@ class SQLContext: tables, execute SQL over tables, cache tables, and read parquet files. 
""" - def __init__(self, sparkContext, sqlContext = None): + def __init__(self, sparkContext, sqlContext=None): """Create a new SQLContext. @param sparkContext: The SparkContext to wrap. @@ -137,7 +137,6 @@ def parquetFile(self, path): jschema_rdd = self._ssql_ctx.parquetFile(path) return SchemaRDD(jschema_rdd, self) - def jsonFile(self, path): """Loads a text file storing one JSON object per line, returning the result as a L{SchemaRDD}. @@ -234,8 +233,8 @@ def _ssql_ctx(self): self._scala_HiveContext = self._get_hive_ctx() return self._scala_HiveContext except Py4JError as e: - raise Exception("You must build Spark with Hive. Export 'SPARK_HIVE=true' and run " \ - "sbt/sbt assembly" , e) + raise Exception("You must build Spark with Hive. Export 'SPARK_HIVE=true' and run " + "sbt/sbt assembly", e) def _get_hive_ctx(self): return self._jvm.HiveContext(self._jsc.sc()) @@ -377,7 +376,7 @@ def registerAsTable(self, name): """ self._jschema_rdd.registerAsTable(name) - def insertInto(self, tableName, overwrite = False): + def insertInto(self, tableName, overwrite=False): """Inserts the contents of this SchemaRDD into the specified table. Optionally overwriting any existing data. @@ -420,7 +419,7 @@ def _toPython(self): # in Java land in the javaToPython function. May require a custom # pickle serializer in Pyrolite return RDD(jrdd, self._sc, BatchedSerializer( - PickleSerializer())).map(lambda d: Row(d)) + PickleSerializer())).map(lambda d: Row(d)) # We override the default cache/persist/checkpoint behavior as we want to cache the underlying # SchemaRDD object in the JVM, not the PythonRDD checkpointed by the super class @@ -483,6 +482,7 @@ def subtract(self, other, numPartitions=None): else: raise ValueError("Can only subtract another SchemaRDD") + def _test(): import doctest from array import array @@ -493,20 +493,25 @@ def _test(): sc = SparkContext('local[4]', 'PythonTest', batchSize=2) globs['sc'] = sc globs['sqlCtx'] = SQLContext(sc) - globs['rdd'] = sc.parallelize([{"field1" : 1, "field2" : "row1"}, - {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}]) - jsonStrings = ['{"field1": 1, "field2": "row1", "field3":{"field4":11}}', - '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]}, "field6":[{"field7": "row2"}]}', - '{"field1" : null, "field2": "row3", "field3":{"field4":33, "field5": []}}'] + globs['rdd'] = sc.parallelize( + [{"field1": 1, "field2": "row1"}, + {"field1": 2, "field2": "row2"}, + {"field1": 3, "field2": "row3"}] + ) + jsonStrings = [ + '{"field1": 1, "field2": "row1", "field3":{"field4":11}}', + '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]}, "field6":[{"field7": "row2"}]}', + '{"field1" : null, "field2": "row3", "field3":{"field4":33, "field5": []}}' + ] globs['jsonStrings'] = jsonStrings globs['json'] = sc.parallelize(jsonStrings) globs['nestedRdd1'] = sc.parallelize([ - {"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}}, - {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}]) + {"f1": array('i', [1, 2]), "f2": {"row1": 1.0}}, + {"f1": array('i', [2, 3]), "f2": {"row2": 2.0}}]) globs['nestedRdd2'] = sc.parallelize([ - {"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)}, - {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}]) - (failure_count, test_count) = doctest.testmod(globs=globs,optionflags=doctest.ELLIPSIS) + {"f1": [[1, 2], [2, 3]], "f2": set([1, 2]), "f3": (1, 2)}, + {"f1": [[2, 3], [3, 4]], "f2": set([2, 3]), "f3": (2, 3)}]) + (failure_count, test_count) = doctest.testmod(globs=globs, 
optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: exit(-1) @@ -514,4 +519,3 @@ def _test(): if __name__ == "__main__": _test() - diff --git a/python/pyspark/statcounter.py b/python/pyspark/statcounter.py index 080325061a697..e287bd3da1f61 100644 --- a/python/pyspark/statcounter.py +++ b/python/pyspark/statcounter.py @@ -20,18 +20,19 @@ import copy import math + class StatCounter(object): - + def __init__(self, values=[]): self.n = 0L # Running count of our values self.mu = 0.0 # Running mean of our values self.m2 = 0.0 # Running variance numerator (sum of (x - mean)^2) self.maxValue = float("-inf") self.minValue = float("inf") - + for v in values: self.merge(v) - + # Add a value into this StatCounter, updating the internal statistics. def merge(self, value): delta = value - self.mu @@ -42,7 +43,7 @@ def merge(self, value): self.maxValue = value if self.minValue > value: self.minValue = value - + return self # Merge another StatCounter into this one, adding up the internal statistics. @@ -50,7 +51,7 @@ def mergeStats(self, other): if not isinstance(other, StatCounter): raise Exception("Can only merge Statcounters!") - if other is self: # reference equality holds + if other is self: # reference equality holds self.merge(copy.deepcopy(other)) # Avoid overwriting fields in a weird order else: if self.n == 0: @@ -59,8 +60,8 @@ def mergeStats(self, other): self.n = other.n self.maxValue = other.maxValue self.minValue = other.minValue - - elif other.n != 0: + + elif other.n != 0: delta = other.mu - self.mu if other.n * 10 < self.n: self.mu = self.mu + (delta * other.n) / (self.n + other.n) @@ -68,10 +69,10 @@ def mergeStats(self, other): self.mu = other.mu - (delta * self.n) / (self.n + other.n) else: self.mu = (self.mu * self.n + other.mu * other.n) / (self.n + other.n) - + self.maxValue = max(self.maxValue, other.maxValue) self.minValue = min(self.minValue, other.minValue) - + self.m2 += other.m2 + (delta * delta * self.n * other.n) / (self.n + other.n) self.n += other.n return self @@ -94,7 +95,7 @@ def min(self): def max(self): return self.maxValue - + # Return the variance of the values. def variance(self): if self.n == 0: @@ -124,5 +125,5 @@ def sampleStdev(self): return math.sqrt(self.sampleVariance()) def __repr__(self): - return "(count: %s, mean: %s, stdev: %s, max: %s, min: %s)" % (self.count(), self.mean(), self.stdev(), self.max(), self.min()) - + return ("(count: %s, mean: %s, stdev: %s, max: %s, min: %s)" % + (self.count(), self.mean(), self.stdev(), self.max(), self.min())) diff --git a/python/pyspark/storagelevel.py b/python/pyspark/storagelevel.py index 3a18ea54eae4c..5d77a131f2856 100644 --- a/python/pyspark/storagelevel.py +++ b/python/pyspark/storagelevel.py @@ -17,6 +17,7 @@ __all__ = ["StorageLevel"] + class StorageLevel: """ Flags for controlling the storage of an RDD. Each StorageLevel records whether to use memory, @@ -25,7 +26,7 @@ class StorageLevel: Also contains static constants for some commonly used storage levels, such as MEMORY_ONLY. 
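The StatCounter.mergeStats arithmetic in the statcounter.py changes above combines two sets of running statistics without revisiting the data: the merged mean is a count-weighted average, and the merged sum of squared deviations picks up the correction term delta^2 * n1 * n2 / (n1 + n2), the standard pairwise update for computing variance in parallel, which is how Spark merges per-partition statistics. A quick standalone check of that update rule (plain Python, not the pyspark class):

    def merge_stats(n1, mu1, m2_1, n2, mu2, m2_2):
        # Combine two (count, mean, sum of squared deviations) triples.
        delta = mu2 - mu1
        n = n1 + n2
        mu = (mu1 * n1 + mu2 * n2) / n
        m2 = m2_1 + m2_2 + delta * delta * n1 * n2 / n
        return n, mu, m2

    def stats(xs):
        n = len(xs)
        mu = sum(xs) / float(n)
        return n, mu, sum((x - mu) ** 2 for x in xs)

    a, b = [1.0, 2.0, 3.0, 4.0], [10.0, 20.0, 30.0]
    n, mu, m2 = merge_stats(*(stats(a) + stats(b)))
    exact_n, exact_mu, exact_m2 = stats(a + b)
    assert n == exact_n and abs(mu - exact_mu) < 1e-9 and abs(m2 - exact_m2) < 1e-9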
""" - def __init__(self, useDisk, useMemory, useOffHeap, deserialized, replication = 1): + def __init__(self, useDisk, useMemory, useOffHeap, deserialized, replication=1): self.useDisk = useDisk self.useMemory = useMemory self.useOffHeap = useOffHeap @@ -55,4 +56,4 @@ def __str__(self): StorageLevel.MEMORY_AND_DISK_2 = StorageLevel(True, True, False, True, 2) StorageLevel.MEMORY_AND_DISK_SER = StorageLevel(True, True, False, False) StorageLevel.MEMORY_AND_DISK_SER_2 = StorageLevel(True, True, False, False, 2) -StorageLevel.OFF_HEAP = StorageLevel(False, False, True, False, 1) \ No newline at end of file +StorageLevel.OFF_HEAP = StorageLevel(False, False, True, False, 1) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index c15bb457759ed..63cc5e9ad96fa 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -34,6 +34,7 @@ from pyspark.context import SparkContext from pyspark.files import SparkFiles from pyspark.serializers import read_int +from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger _have_scipy = False try: @@ -47,17 +48,74 @@ SPARK_HOME = os.environ["SPARK_HOME"] +class TestMerger(unittest.TestCase): + + def setUp(self): + self.N = 1 << 16 + self.l = [i for i in xrange(self.N)] + self.data = zip(self.l, self.l) + self.agg = Aggregator(lambda x: [x], + lambda x, y: x.append(y) or x, + lambda x, y: x.extend(y) or x) + + def test_in_memory(self): + m = InMemoryMerger(self.agg) + m.mergeValues(self.data) + self.assertEqual(sum(sum(v) for k, v in m.iteritems()), + sum(xrange(self.N))) + + m = InMemoryMerger(self.agg) + m.mergeCombiners(map(lambda (x, y): (x, [y]), self.data)) + self.assertEqual(sum(sum(v) for k, v in m.iteritems()), + sum(xrange(self.N))) + + def test_small_dataset(self): + m = ExternalMerger(self.agg, 1000) + m.mergeValues(self.data) + self.assertEqual(m.spills, 0) + self.assertEqual(sum(sum(v) for k, v in m.iteritems()), + sum(xrange(self.N))) + + m = ExternalMerger(self.agg, 1000) + m.mergeCombiners(map(lambda (x, y): (x, [y]), self.data)) + self.assertEqual(m.spills, 0) + self.assertEqual(sum(sum(v) for k, v in m.iteritems()), + sum(xrange(self.N))) + + def test_medium_dataset(self): + m = ExternalMerger(self.agg, 10) + m.mergeValues(self.data) + self.assertTrue(m.spills >= 1) + self.assertEqual(sum(sum(v) for k, v in m.iteritems()), + sum(xrange(self.N))) + + m = ExternalMerger(self.agg, 10) + m.mergeCombiners(map(lambda (x, y): (x, [y]), self.data * 3)) + self.assertTrue(m.spills >= 1) + self.assertEqual(sum(sum(v) for k, v in m.iteritems()), + sum(xrange(self.N)) * 3) + + def test_huge_dataset(self): + m = ExternalMerger(self.agg, 10) + m.mergeCombiners(map(lambda (k, v): (k, [str(v)]), self.data * 10)) + self.assertTrue(m.spills >= 1) + self.assertEqual(sum(len(v) for k, v in m._recursive_merged_items(0)), + self.N * 10) + m._cleanup() + + class PySparkTestCase(unittest.TestCase): def setUp(self): self._old_sys_path = list(sys.path) class_name = self.__class__.__name__ - self.sc = SparkContext('local[4]', class_name , batchSize=2) + self.sc = SparkContext('local[4]', class_name, batchSize=2) def tearDown(self): self.sc.stop() sys.path = self._old_sys_path + class TestCheckpoint(PySparkTestCase): def setUp(self): @@ -151,6 +209,12 @@ def func(): class TestRDDFunctions(PySparkTestCase): + def test_failed_sparkcontext_creation(self): + # Regression test for SPARK-1550 + self.sc.stop() + self.assertRaises(Exception, lambda: SparkContext("an-invalid-master-name")) + self.sc = SparkContext("local") + def 
test_save_as_textfile_with_unicode(self): # Regression test for SPARK-970 x = u"\u00A1Hola, mundo!" @@ -168,6 +232,15 @@ def test_transforming_cartesian_result(self): cart = rdd1.cartesian(rdd2) result = cart.map(lambda (x, y): x + y).collect() + def test_transforming_pickle_file(self): + # Regression test for SPARK-2601 + data = self.sc.parallelize(["Hello", "World!"]) + tempFile = tempfile.NamedTemporaryFile(delete=True) + tempFile.close() + data.saveAsPickleFile(tempFile.name) + pickled_file = self.sc.pickleFile(tempFile.name) + pickled_file.map(lambda x: x).collect() + def test_cartesian_on_textfile(self): # Regression test for path = os.path.join(SPARK_HOME, "python/test_support/hello.txt") @@ -190,6 +263,7 @@ def test_deleting_input_files(self): def testAggregateByKey(self): data = self.sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2) + def seqOp(x, y): x.add(y) return x @@ -197,17 +271,19 @@ def seqOp(x, y): def combOp(x, y): x |= y return x - + sets = dict(data.aggregateByKey(set(), seqOp, combOp).collect()) self.assertEqual(3, len(sets)) self.assertEqual(set([1]), sets[1]) self.assertEqual(set([2]), sets[3]) self.assertEqual(set([1, 3]), sets[5]) + class TestIO(PySparkTestCase): def test_stdout_redirection(self): import subprocess + def func(x): subprocess.check_call('ls', shell=True) self.sc.parallelize([1]).foreach(func) @@ -479,7 +555,7 @@ def test_module_dependency(self): | return x + 1 """) proc = subprocess.Popen([self.sparkSubmit, "--py-files", zip, script], - stdout=subprocess.PIPE) + stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) self.assertIn("[2, 3, 4]", out) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index f43210c6c0301..24d41b12d1b1a 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -57,8 +57,8 @@ def main(infile, outfile): SparkFiles._is_running_on_worker = True # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH - sys.path.append(spark_files_dir) # *.py files that were added will be copied here - num_python_includes = read_int(infile) + sys.path.append(spark_files_dir) # *.py files that were added will be copied here + num_python_includes = read_int(infile) for _ in range(num_python_includes): filename = utf8_deserializer.loads(infile) sys.path.append(os.path.join(spark_files_dir, filename)) diff --git a/python/run-tests b/python/run-tests index 9282aa47e8375..29f755fc0dcd3 100755 --- a/python/run-tests +++ b/python/run-tests @@ -61,6 +61,7 @@ run_test "pyspark/broadcast.py" run_test "pyspark/accumulators.py" run_test "pyspark/serializers.py" unset PYSPARK_DOC_TEST +run_test "pyspark/shuffle.py" run_test "pyspark/tests.py" run_test "pyspark/mllib/_common.py" run_test "pyspark/mllib/classification.py" diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkImports.scala b/repl/src/main/scala/org/apache/spark/repl/SparkImports.scala index bce5c74b9d0da..9099e052f5796 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkImports.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkImports.scala @@ -197,7 +197,7 @@ trait SparkImports { for (imv <- x.definedNames) { if (currentImps contains imv) addWrapper() val objName = req.lineRep.readPath - val valName = "$VAL" + req.lineRep.lineId + val valName = "$VAL" + newValId() if(!code.toString.endsWith(".`" + imv + "`;\n")) { // Which means already imported code.append("val " + valName + " = " + objName + ".INSTANCE;\n") @@ -222,4 +222,10 @@ trait SparkImports { private def 
membersAtPickler(sym: Symbol): List[Symbol] = beforePickler(sym.info.nonPrivateMembers.toList) + private var curValId = 0 + + private def newValId(): Int = { + curValId += 1 + curValId + } } diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala index f2aa42dbcb4fc..e2d8d5ff38dbe 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -235,7 +235,7 @@ class ReplSuite extends FunSuite { assertContains("res4: Array[Int] = Array(0, 0, 0, 0, 0)", output) } - test("SPARK-1199-simple-reproduce") { + test("SPARK-1199 two instances of same class don't type check.") { val output = runInterpreter("local-cluster[1,1,512]", """ |case class Sum(exp: String, exp2: String) @@ -247,6 +247,15 @@ class ReplSuite extends FunSuite { assertDoesNotContain("Exception", output) } + test("SPARK-2452 compound statements.") { + val output = runInterpreter("local", + """ + |val x = 4 ; def f() = x + |f() + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + } if (System.getenv("MESOS_NATIVE_LIBRARY") != null) { test("running on Mesos") { val output = runInterpreter("localquiet", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index e5653c5b14ac1..a34b236c8ac6a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -120,7 +120,8 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected val WHERE = Keyword("WHERE") protected val INTERSECT = Keyword("INTERSECT") protected val EXCEPT = Keyword("EXCEPT") - + protected val SUBSTR = Keyword("SUBSTR") + protected val SUBSTRING = Keyword("SUBSTRING") // Use reflection to find the reserved words defined in this class. protected val reservedWords = @@ -316,6 +317,12 @@ class SqlParser extends StandardTokenParsers with PackratParsers { IF ~> "(" ~> expression ~ "," ~ expression ~ "," ~ expression <~ ")" ^^ { case c ~ "," ~ t ~ "," ~ f => If(c,t,f) } | + (SUBSTR | SUBSTRING) ~> "(" ~> expression ~ "," ~ expression <~ ")" ^^ { + case s ~ "," ~ p => Substring(s,p,Literal(Integer.MAX_VALUE)) + } | + (SUBSTR | SUBSTRING) ~> "(" ~> expression ~ "," ~ expression ~ "," ~ expression <~ ")" ^^ { + case s ~ "," ~ p ~ "," ~ l => Substring(s,p,l) + } | ident ~ "(" ~ repsep(expression, ",") <~ ")" ^^ { case udfName ~ _ ~ exprs => UnresolvedFunction(udfName, exprs) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c7188469bfb86..02bdb64f308a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -22,7 +22,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ - /** * A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. 
Used for testing * when all relations are already filled in and the analyser needs only to resolve attribute @@ -54,6 +53,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool StarExpansion :: ResolveFunctions :: GlobalAggregates :: + UnresolvedHavingClauseAttributes :: typeCoercionRules :_*), Batch("Check Analysis", Once, CheckResolution), @@ -151,6 +151,31 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool } } + /** + * This rule finds expressions in HAVING clause filters that depend on + * unresolved attributes. It pushes these expressions down to the underlying + * aggregates and then projects them away above the filter. + */ + object UnresolvedHavingClauseAttributes extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case filter @ Filter(havingCondition, aggregate @ Aggregate(_, originalAggExprs, _)) + if !filter.resolved && aggregate.resolved && containsAggregate(havingCondition) => { + val evaluatedCondition = Alias(havingCondition, "havingCondition")() + val aggExprsWithHaving = evaluatedCondition +: originalAggExprs + + Project(aggregate.output, + Filter(evaluatedCondition.toAttribute, + aggregate.copy(aggregateExpressions = aggExprsWithHaving))) + } + + } + + protected def containsAggregate(condition: Expression): Boolean = + condition + .collect { case ae: AggregateExpression => ae } + .nonEmpty + } + /** * When a SELECT clause has only a single expression and that expression is a * [[catalyst.expressions.Generator Generator]] we convert the diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index 0d05d9808b407..616f1e2ecb60f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -113,11 +113,12 @@ trait OverrideCatalog extends Catalog { alias: Option[String] = None): LogicalPlan = { val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName) val overriddenTable = overrides.get((dbName, tblName)) + val tableWithQualifers = overriddenTable.map(r => Subquery(tblName, r)) // If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are // properly qualified with this alias. val withAlias = - overriddenTable.map(r => alias.map(a => Subquery(a, r)).getOrElse(r)) + tableWithQualifers.map(r => alias.map(a => Subquery(a, r)).getOrElse(r)) withAlias.getOrElse(super.lookupRelation(dbName, tblName, alias)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 76ddeba9cb312..47c7ad076ad07 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -50,6 +50,7 @@ trait HiveTypeCoercion { StringToIntegralCasts :: FunctionArgumentConversion :: CastNulls :: + Division :: Nil /** @@ -231,11 +232,23 @@ trait HiveTypeCoercion { * Changes Boolean values to Bytes so that expressions like true < false can be Evaluated. 
*/ object BooleanComparisons extends Rule[LogicalPlan] { + val trueValues = Seq(1, 1L, 1.toByte, 1.toShort, BigDecimal(1)).map(Literal(_)) + val falseValues = Seq(0, 0L, 0.toByte, 0.toShort, BigDecimal(0)).map(Literal(_)) + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - // No need to change EqualTo operators as that actually makes sense for boolean types. + + // Hive treats (true = 1) as true and (false = 0) as true. + case EqualTo(l @ BooleanType(), r) if trueValues.contains(r) => l + case EqualTo(l, r @ BooleanType()) if trueValues.contains(l) => r + case EqualTo(l @ BooleanType(), r) if falseValues.contains(r) => Not(l) + case EqualTo(l, r @ BooleanType()) if falseValues.contains(l) => Not(r) + + // No need to change other EqualTo operators as that actually makes sense for boolean types. case e: EqualTo => e + // No need to change the EqualNullSafe operators, too + case e: EqualNullSafe => e // Otherwise turn them to Byte types so that there exists and ordering. case p: BinaryComparison if p.left.dataType == BooleanType && p.right.dataType == BooleanType => @@ -305,6 +318,23 @@ trait HiveTypeCoercion { } } + /** + * Hive only performs integral division with the DIV operator. The arguments to / are always + * converted to fractional types. + */ + object Division extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + // Skip nodes who's children have not been resolved yet. + case e if !e.childrenResolved => e + + // Decimal and Double remain the same + case d: Divide if d.dataType == DoubleType => d + case d: Divide if d.dataType == DecimalType => d + + case Divide(l, r) => Divide(Cast(l, DoubleType), Cast(r, DoubleType)) + } + } + /** * Ensures that NullType gets casted to some other types under certain circumstances. 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 1b503b957d146..5c8c810d9135a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -77,10 +77,27 @@ package object dsl { def > (other: Expression) = GreaterThan(expr, other) def >= (other: Expression) = GreaterThanOrEqual(expr, other) def === (other: Expression) = EqualTo(expr, other) + def <=> (other: Expression) = EqualNullSafe(expr, other) def !== (other: Expression) = Not(EqualTo(expr, other)) + def in(list: Expression*) = In(expr, list) + def like(other: Expression) = Like(expr, other) def rlike(other: Expression) = RLike(expr, other) + def contains(other: Expression) = Contains(expr, other) + def startsWith(other: Expression) = StartsWith(expr, other) + def endsWith(other: Expression) = EndsWith(expr, other) + def substr(pos: Expression, len: Expression = Literal(Int.MaxValue)) = + Substring(expr, pos, len) + def substring(pos: Expression, len: Expression = Literal(Int.MaxValue)) = + Substring(expr, pos, len) + + def isNull = IsNull(expr) + def isNotNull = IsNotNull(expr) + + def getItem(ordinal: Expression) = GetItem(expr, ordinal) + def getField(fieldName: String) = GetField(expr, fieldName) + def cast(to: DataType) = Cast(expr, to) def asc = SortOrder(expr, Ascending) @@ -112,6 +129,7 @@ package object dsl { def sumDistinct(e: Expression) = SumDistinct(e) def count(e: Expression) = Count(e) def countDistinct(e: Expression*) = CountDistinct(e) + def approxCountDistinct(e: Expression, rsd: Double = 0.05) = ApproxCountDistinct(e, rsd) def avg(e: Expression) = Average(e) def first(e: Expression) = First(e) def min(e: Expression) = Min(e) @@ -163,6 +181,18 @@ package object dsl { /** Creates a new AttributeReference of type binary */ def binary = AttributeReference(s, BinaryType, nullable = true)() + + /** Creates a new AttributeReference of type array */ + def array(dataType: DataType) = AttributeReference(s, ArrayType(dataType), nullable = true)() + + /** Creates a new AttributeReference of type map */ + def map(keyType: DataType, valueType: DataType): AttributeReference = + map(MapType(keyType, valueType)) + def map(mapType: MapType) = AttributeReference(s, mapType, nullable = true)() + + /** Creates a new AttributeReference of type struct */ + def struct(fields: StructField*): AttributeReference = struct(StructType(fields)) + def struct(structType: StructType) = AttributeReference(s, structType, nullable = true)() } implicit class DslAttribute(a: AttributeReference) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 1f9716e385e9e..0ad2b30cf9c1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.Timestamp +import java.text.{DateFormat, SimpleDateFormat} import org.apache.spark.sql.catalyst.types._ @@ -41,6 +42,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { // UDFToString private[this] def castToString: Any => Any = child.dataType match { case BinaryType => buildCast[Array[Byte]](_, new String(_, "UTF-8")) + case 
TimestampType => buildCast[Timestamp](_, timestampToString) case _ => buildCast[Any](_, _.toString) } @@ -126,6 +128,18 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { ts.getTime / 1000 + ts.getNanos.toDouble / 1000000000 } + // Converts Timestamp to string according to Hive TimestampWritable convention + private[this] def timestampToString(ts: Timestamp): String = { + val timestampString = ts.toString + val formatted = Cast.threadLocalDateFormat.get.format(ts) + + if (timestampString.length > 19 && timestampString.substring(19) != ".0") { + formatted + timestampString.substring(19) + } else { + formatted + } + } + private[this] def castToLong: Any => Any = child.dataType match { case StringType => buildCast[String](_, s => try s.toLong catch { @@ -249,3 +263,12 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { if (evaluated == null) null else cast(evaluated) } } + +object Cast { + // `SimpleDateFormat` is not thread-safe. + private[sql] val threadLocalDateFormat = new ThreadLocal[DateFormat] { + override def initialValue() = { + new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index b63406b94a4a3..06b94a98d3cd0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -153,6 +153,22 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison } } +case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison { + def symbol = "<=>" + override def nullable = false + override def eval(input: Row): Any = { + val l = left.eval(input) + val r = right.eval(input) + if (l == null && r == null) { + true + } else if (l == null || r == null) { + false + } else { + l == r + } + } +} + case class LessThan(left: Expression, right: Expression) extends BinaryComparison { def symbol = "<" override def eval(input: Row): Any = c2(input, left, right, _.lt(_, _)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index b3850533c3736..97fc3a3b14b88 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -19,9 +19,11 @@ package org.apache.spark.sql.catalyst.expressions import java.util.regex.Pattern -import org.apache.spark.sql.catalyst.types.DataType -import org.apache.spark.sql.catalyst.types.StringType -import org.apache.spark.sql.catalyst.types.BooleanType +import scala.collection.IndexedSeqOptimized + + +import org.apache.spark.sql.catalyst.analysis.UnresolvedException +import org.apache.spark.sql.catalyst.types.{BinaryType, BooleanType, DataType, StringType} trait StringRegexExpression { self: BinaryExpression => @@ -205,3 +207,74 @@ case class EndsWith(left: Expression, right: Expression) extends BinaryExpression with StringComparison { def compare(l: String, r: String) = l.endsWith(r) } + +/** + * A function that takes a substring of its first argument starting at a given position. + * Defined for String and Binary types. 
+ */ +case class Substring(str: Expression, pos: Expression, len: Expression) extends Expression { + + type EvaluatedType = Any + + override def foldable = str.foldable && pos.foldable && len.foldable + + def nullable: Boolean = str.nullable || pos.nullable || len.nullable + def dataType: DataType = { + if (!resolved) { + throw new UnresolvedException(this, s"Cannot resolve since $children are not resolved") + } + if (str.dataType == BinaryType) str.dataType else StringType + } + + def references = children.flatMap(_.references).toSet + + override def children = str :: pos :: len :: Nil + + @inline + def slice[T, C <: Any](str: C, startPos: Int, sliceLen: Int) + (implicit ev: (C=>IndexedSeqOptimized[T,_])): Any = { + val len = str.length + // Hive and SQL use one-based indexing for SUBSTR arguments but also accept zero and + // negative indices for start positions. If a start index i is greater than 0, it + // refers to element i-1 in the sequence. If a start index i is less than 0, it refers + // to the -ith element before the end of the sequence. If a start index i is 0, it + // refers to the first element. + + val start = startPos match { + case pos if pos > 0 => pos - 1 + case neg if neg < 0 => len + neg + case _ => 0 + } + + val end = sliceLen match { + case max if max == Integer.MAX_VALUE => max + case x => start + x + } + + str.slice(start, end) + } + + override def eval(input: Row): Any = { + val string = str.eval(input) + + val po = pos.eval(input) + val ln = len.eval(input) + + if ((string == null) || (po == null) || (ln == null)) { + null + } else { + val start = po.asInstanceOf[Int] + val length = ln.asInstanceOf[Int] + + string match { + case ba: Array[Byte] => slice(ba, start, length) + case other => slice(other.toString, start, length) + } + } + } + + override def toString = len match { + case max if max == Integer.MAX_VALUE => s"SUBSTR($str, $pos)" + case _ => s"SUBSTR($str, $pos, $len)" + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index a142310c501b0..5f86d6047cb9c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -153,11 +153,15 @@ object NullPropagation extends Rule[LogicalPlan] { case e @ GetItem(Literal(null, _), _) => Literal(null, e.dataType) case e @ GetItem(_, Literal(null, _)) => Literal(null, e.dataType) case e @ GetField(Literal(null, _), _) => Literal(null, e.dataType) - case e @ Coalesce(children) => { - val newChildren = children.filter(c => c match { + case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r) + case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l) + + // For Coalesce, remove null literals. 
+ case e @ Coalesce(children) => + val newChildren = children.filter { case Literal(null, _) => false case _ => true - }) + } if (newChildren.length == 0) { Literal(null, e.dataType) } else if (newChildren.length == 1) { @@ -165,12 +169,11 @@ object NullPropagation extends Rule[LogicalPlan] { } else { Coalesce(newChildren) } - } - case e @ If(Literal(v, _), trueValue, falseValue) => if (v == true) trueValue else falseValue - case e @ In(Literal(v, _), list) if (list.exists(c => c match { - case Literal(candidate, _) if candidate == v => true - case _ => false - })) => Literal(true, BooleanType) + + case e @ Substring(Literal(null, _), _, _) => Literal(null, e.dataType) + case e @ Substring(_, Literal(null, _), _) => Literal(null, e.dataType) + case e @ Substring(_, _, Literal(null, _)) => Literal(null, e.dataType) + // Put exceptional cases above if any case e: BinaryArithmetic => e.children match { case Literal(null, _) :: right :: Nil => Literal(null, e.dataType) @@ -187,6 +190,11 @@ object NullPropagation extends Rule[LogicalPlan] { case left :: Literal(null, _) :: Nil => Literal(null, e.dataType) case _ => e } + case e: StringComparison => e.children match { + case Literal(null, _) :: right :: Nil => Literal(null, e.dataType) + case left :: Literal(null, _) :: Nil => Literal(null, e.dataType) + case _ => e + } } } } @@ -198,9 +206,19 @@ object NullPropagation extends Rule[LogicalPlan] { object ConstantFolding extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsDown { - // Skip redundant folding of literals. + // Skip redundant folding of literals. This rule is technically not necessary. Placing this + // here avoids running the next rule for Literal values, which would create a new Literal + // object and running eval unnecessarily. case l: Literal => l + + // Fold expressions that are foldable. case e if e.foldable => Literal(e.eval(null), e.dataType) + + // Fold "literal in (item1, item2, ..., literal, ...)" into true directly. + case In(Literal(v, _), list) if list.exists { + case Literal(candidate, _) if candidate == v => true + case _ => false + } => Literal(true, BooleanType) } } } @@ -230,6 +248,9 @@ object BooleanSimplification extends Rule[LogicalPlan] { case (l, Literal(false, BooleanType)) => l case (_, _) => or } + + // Turn "if (true) a else b" into "a", and if (false) a else b" into "b". + case e @ If(Literal(v, _), trueValue, falseValue) => if (v == true) trueValue else falseValue } } } @@ -251,12 +272,12 @@ object CombineFilters extends Rule[LogicalPlan] { */ object SimplifyFilters extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case Filter(Literal(true, BooleanType), child) => - child - case Filter(Literal(null, _), child) => - LocalRelation(child.output) - case Filter(Literal(false, BooleanType), child) => - LocalRelation(child.output) + // If the filter condition always evaluate to true, remove the filter. + case Filter(Literal(true, BooleanType), child) => child + // If the filter condition always evaluate to null or false, + // replace the input with an empty relation. 
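+      // For example, `WHERE false` and `WHERE null` both collapse to an empty
+      // LocalRelation over the child's output, so downstream operators see no rows.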
+ case Filter(Literal(null, _), child) => LocalRelation(child.output, data = Seq.empty) + case Filter(Literal(false, BooleanType), child) => LocalRelation(child.output, data = Seq.empty) } } @@ -298,7 +319,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { /** * Splits join condition expressions into three categories based on the attributes required * to evaluate them. - * @returns (canEvaluateInLeft, canEvaluateInRight, haveToEvaluateInBoth) + * @return (canEvaluateInLeft, canEvaluateInRight, haveToEvaluateInBoth) */ private def split(condition: Seq[Expression], left: LogicalPlan, right: LogicalPlan) = { val (leftEvaluateCondition, rest) = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index bb77bccf86176..cd4b5e9c1b529 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -19,17 +19,20 @@ package org.apache.spark.sql.catalyst.types import java.sql.Timestamp -import scala.util.parsing.combinator.RegexParsers - import scala.reflect.ClassTag import scala.reflect.runtime.universe.{typeTag, TypeTag, runtimeMirror} +import scala.util.parsing.combinator.RegexParsers import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} import org.apache.spark.util.Utils /** - * + * A JVM-global lock that should be used to prevent thread safety issues when using things in + * scala.reflect.*. Note that Scala Reflection API is made thread-safe in 2.11, but not yet for + * 2.10.* builds. See SI-6240 for more details. */ +protected[catalyst] object ScalaReflectionLock + object DataType extends RegexParsers { protected lazy val primitiveType: Parser[DataType] = "StringType" ^^^ StringType | @@ -62,7 +65,6 @@ object DataType extends RegexParsers { "true" ^^^ true | "false" ^^^ false - protected lazy val structType: Parser[DataType] = "StructType\\([A-zA-z]*\\(".r ~> repsep(structField, ",") <~ "))" ^^ { case fields => new StructType(fields) @@ -106,7 +108,7 @@ abstract class NativeType extends DataType { @transient val tag: TypeTag[JvmType] val ordering: Ordering[JvmType] - @transient val classTag = { + @transient val classTag = ScalaReflectionLock.synchronized { val mirror = runtimeMirror(Utils.getSparkClassLoader) ClassTag[JvmType](mirror.runtimeClass(tag.tpe)) } @@ -114,22 +116,24 @@ abstract class NativeType extends DataType { case object StringType extends NativeType with PrimitiveType { type JvmType = String - @transient lazy val tag = typeTag[JvmType] + @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } val ordering = implicitly[Ordering[JvmType]] } + case object BinaryType extends DataType with PrimitiveType { type JvmType = Array[Byte] } + case object BooleanType extends NativeType with PrimitiveType { type JvmType = Boolean - @transient lazy val tag = typeTag[JvmType] + @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } val ordering = implicitly[Ordering[JvmType]] } case object TimestampType extends NativeType { type JvmType = Timestamp - @transient lazy val tag = typeTag[JvmType] + @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } val ordering = new Ordering[JvmType] { def compare(x: Timestamp, y: Timestamp) = x.compareTo(y) @@ -159,7 +163,7 @@ abstract class IntegralType extends NumericType { 
case object LongType extends IntegralType { type JvmType = Long - @transient lazy val tag = typeTag[JvmType] + @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } val numeric = implicitly[Numeric[Long]] val integral = implicitly[Integral[Long]] val ordering = implicitly[Ordering[JvmType]] @@ -167,7 +171,7 @@ case object LongType extends IntegralType { case object IntegerType extends IntegralType { type JvmType = Int - @transient lazy val tag = typeTag[JvmType] + @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } val numeric = implicitly[Numeric[Int]] val integral = implicitly[Integral[Int]] val ordering = implicitly[Ordering[JvmType]] @@ -175,7 +179,7 @@ case object IntegerType extends IntegralType { case object ShortType extends IntegralType { type JvmType = Short - @transient lazy val tag = typeTag[JvmType] + @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } val numeric = implicitly[Numeric[Short]] val integral = implicitly[Integral[Short]] val ordering = implicitly[Ordering[JvmType]] @@ -183,7 +187,7 @@ case object ShortType extends IntegralType { case object ByteType extends IntegralType { type JvmType = Byte - @transient lazy val tag = typeTag[JvmType] + @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } val numeric = implicitly[Numeric[Byte]] val integral = implicitly[Integral[Byte]] val ordering = implicitly[Ordering[JvmType]] @@ -202,7 +206,7 @@ abstract class FractionalType extends NumericType { case object DecimalType extends FractionalType { type JvmType = BigDecimal - @transient lazy val tag = typeTag[JvmType] + @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } val numeric = implicitly[Numeric[BigDecimal]] val fractional = implicitly[Fractional[BigDecimal]] val ordering = implicitly[Ordering[JvmType]] @@ -210,7 +214,7 @@ case object DecimalType extends FractionalType { case object DoubleType extends FractionalType { type JvmType = Double - @transient lazy val tag = typeTag[JvmType] + @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } val numeric = implicitly[Numeric[Double]] val fractional = implicitly[Fractional[Double]] val ordering = implicitly[Ordering[JvmType]] @@ -218,7 +222,7 @@ case object DoubleType extends FractionalType { case object FloatType extends FractionalType { type JvmType = Float - @transient lazy val tag = typeTag[JvmType] + @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } val numeric = implicitly[Numeric[Float]] val fractional = implicitly[Fractional[Float]] val ordering = implicitly[Ordering[JvmType]] diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 84d72814778ba..58f8c341e6676 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -158,7 +158,7 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation("abc" like regEx, true, new GenericRow(Array[Any]("a%"))) checkEvaluation("abc" like regEx, false, new GenericRow(Array[Any]("b%"))) checkEvaluation("abc" like regEx, false, new GenericRow(Array[Any]("bc%"))) - + checkEvaluation(Literal(null, StringType) like regEx, null, new 
GenericRow(Array[Any]("bc%"))) } @@ -203,7 +203,7 @@ class ExpressionEvaluationSuite extends FunSuite { test("data type casting") { - val sts = "1970-01-01 00:00:01.0" + val sts = "1970-01-01 00:00:01.1" val ts = Timestamp.valueOf(sts) checkEvaluation("abdef" cast StringType, "abdef") @@ -293,7 +293,7 @@ class ExpressionEvaluationSuite extends FunSuite { // A test for higher precision than millis checkEvaluation(Cast(Cast(0.00000001, TimestampType), DoubleType), 0.00000001) } - + test("null checking") { val row = new GenericRow(Array[Any]("^Ba*n", null, true, null)) val c1 = 'a.string.at(0) @@ -301,18 +301,18 @@ class ExpressionEvaluationSuite extends FunSuite { val c3 = 'a.boolean.at(2) val c4 = 'a.boolean.at(3) - checkEvaluation(IsNull(c1), false, row) - checkEvaluation(IsNotNull(c1), true, row) + checkEvaluation(c1.isNull, false, row) + checkEvaluation(c1.isNotNull, true, row) - checkEvaluation(IsNull(c2), true, row) - checkEvaluation(IsNotNull(c2), false, row) + checkEvaluation(c2.isNull, true, row) + checkEvaluation(c2.isNotNull, false, row) - checkEvaluation(IsNull(Literal(1, ShortType)), false) - checkEvaluation(IsNotNull(Literal(1, ShortType)), true) + checkEvaluation(Literal(1, ShortType).isNull, false) + checkEvaluation(Literal(1, ShortType).isNotNull, true) + + checkEvaluation(Literal(null, ShortType).isNull, true) + checkEvaluation(Literal(null, ShortType).isNotNull, false) - checkEvaluation(IsNull(Literal(null, ShortType)), true) - checkEvaluation(IsNotNull(Literal(null, ShortType)), false) - checkEvaluation(Coalesce(c1 :: c2 :: Nil), "^Ba*n", row) checkEvaluation(Coalesce(Literal(null, StringType) :: Nil), null, row) checkEvaluation(Coalesce(Literal(null, StringType) :: c1 :: c2 :: Nil), "^Ba*n", row) @@ -323,14 +323,14 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(If(Literal(null, BooleanType), c2, c1), "^Ba*n", row) checkEvaluation(If(Literal(true, BooleanType), c1, c2), "^Ba*n", row) checkEvaluation(If(Literal(false, BooleanType), c2, c1), "^Ba*n", row) - checkEvaluation(If(Literal(false, BooleanType), + checkEvaluation(If(Literal(false, BooleanType), Literal("a", StringType), Literal("b", StringType)), "b", row) - checkEvaluation(In(c1, c1 :: c2 :: Nil), true, row) - checkEvaluation(In(Literal("^Ba*n", StringType), - Literal("^Ba*n", StringType) :: Nil), true, row) - checkEvaluation(In(Literal("^Ba*n", StringType), - Literal("^Ba*n", StringType) :: c2 :: Nil), true, row) + checkEvaluation(c1 in (c1, c2), true, row) + checkEvaluation( + Literal("^Ba*n", StringType) in (Literal("^Ba*n", StringType)), true, row) + checkEvaluation( + Literal("^Ba*n", StringType) in (Literal("^Ba*n", StringType), c2), true, row) } test("case when") { @@ -378,7 +378,7 @@ class ExpressionEvaluationSuite extends FunSuite { test("complex type") { val row = new GenericRow(Array[Any]( - "^Ba*n", // 0 + "^Ba*n", // 0 null.asInstanceOf[String], // 1 new GenericRow(Array[Any]("aa", "bb")), // 2 Map("aa"->"bb"), // 3 @@ -391,18 +391,18 @@ class ExpressionEvaluationSuite extends FunSuite { val typeMap = MapType(StringType, StringType) val typeArray = ArrayType(StringType) - checkEvaluation(GetItem(BoundReference(3, AttributeReference("c", typeMap)()), + checkEvaluation(GetItem(BoundReference(3, AttributeReference("c", typeMap)()), Literal("aa")), "bb", row) checkEvaluation(GetItem(Literal(null, typeMap), Literal("aa")), null, row) checkEvaluation(GetItem(Literal(null, typeMap), Literal(null, StringType)), null, row) - checkEvaluation(GetItem(BoundReference(3, 
AttributeReference("c", typeMap)()), + checkEvaluation(GetItem(BoundReference(3, AttributeReference("c", typeMap)()), Literal(null, StringType)), null, row) - checkEvaluation(GetItem(BoundReference(4, AttributeReference("c", typeArray)()), + checkEvaluation(GetItem(BoundReference(4, AttributeReference("c", typeArray)()), Literal(1)), "bb", row) checkEvaluation(GetItem(Literal(null, typeArray), Literal(1)), null, row) checkEvaluation(GetItem(Literal(null, typeArray), Literal(null, IntegerType)), null, row) - checkEvaluation(GetItem(BoundReference(4, AttributeReference("c", typeArray)()), + checkEvaluation(GetItem(BoundReference(4, AttributeReference("c", typeArray)()), Literal(null, IntegerType)), null, row) checkEvaluation(GetField(BoundReference(2, AttributeReference("c", typeS)()), "a"), "aa", row) @@ -420,6 +420,10 @@ class ExpressionEvaluationSuite extends FunSuite { assert(GetField(Literal(null, typeS), "a").nullable === true) assert(GetField(Literal(null, typeS_notNullable), "a").nullable === true) + + checkEvaluation('c.map(typeMap).at(3).getItem("aa"), "bb", row) + checkEvaluation('c.array(typeArray.elementType).at(4).getItem(1), "bb", row) + checkEvaluation('c.struct(typeS).at(2).getField("a"), "aa", row) } test("arithmetic") { @@ -447,11 +451,13 @@ class ExpressionEvaluationSuite extends FunSuite { } test("BinaryComparison") { - val row = new GenericRow(Array[Any](1, 2, 3, null)) + val row = new GenericRow(Array[Any](1, 2, 3, null, 3, null)) val c1 = 'a.int.at(0) val c2 = 'a.int.at(1) val c3 = 'a.int.at(2) val c4 = 'a.int.at(3) + val c5 = 'a.int.at(4) + val c6 = 'a.int.at(5) checkEvaluation(LessThan(c1, c4), null, row) checkEvaluation(LessThan(c1, c2), true, row) @@ -465,6 +471,93 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(c1 >= c2, false, row) checkEvaluation(c1 === c2, false, row) checkEvaluation(c1 !== c2, true, row) + checkEvaluation(c4 <=> c1, false, row) + checkEvaluation(c1 <=> c4, false, row) + checkEvaluation(c4 <=> c6, true, row) + checkEvaluation(c3 <=> c5, true, row) + checkEvaluation(Literal(true) <=> Literal(null, BooleanType), false, row) + checkEvaluation(Literal(null, BooleanType) <=> Literal(true), false, row) } -} + test("StringComparison") { + val row = new GenericRow(Array[Any]("abc", null)) + val c1 = 'a.string.at(0) + val c2 = 'a.string.at(1) + + checkEvaluation(c1 contains "b", true, row) + checkEvaluation(c1 contains "x", false, row) + checkEvaluation(c2 contains "b", null, row) + checkEvaluation(c1 contains Literal(null, StringType), null, row) + + checkEvaluation(c1 startsWith "a", true, row) + checkEvaluation(c1 startsWith "b", false, row) + checkEvaluation(c2 startsWith "a", null, row) + checkEvaluation(c1 startsWith Literal(null, StringType), null, row) + + checkEvaluation(c1 endsWith "c", true, row) + checkEvaluation(c1 endsWith "b", false, row) + checkEvaluation(c2 endsWith "b", null, row) + checkEvaluation(c1 endsWith Literal(null, StringType), null, row) + } + + test("Substring") { + val row = new GenericRow(Array[Any]("example", "example".toArray.map(_.toByte))) + + val s = 'a.string.at(0) + + // substring from zero position with less-than-full length + checkEvaluation(Substring(s, Literal(0, IntegerType), Literal(2, IntegerType)), "ex", row) + checkEvaluation(Substring(s, Literal(1, IntegerType), Literal(2, IntegerType)), "ex", row) + + // substring from zero position with full length + checkEvaluation(Substring(s, Literal(0, IntegerType), Literal(7, IntegerType)), "example", row) + checkEvaluation(Substring(s, 
Literal(1, IntegerType), Literal(7, IntegerType)), "example", row) + + // substring from zero position with greater-than-full length + checkEvaluation(Substring(s, Literal(0, IntegerType), Literal(100, IntegerType)), "example", row) + checkEvaluation(Substring(s, Literal(1, IntegerType), Literal(100, IntegerType)), "example", row) + + // substring from nonzero position with less-than-full length + checkEvaluation(Substring(s, Literal(2, IntegerType), Literal(2, IntegerType)), "xa", row) + + // substring from nonzero position with full length + checkEvaluation(Substring(s, Literal(2, IntegerType), Literal(6, IntegerType)), "xample", row) + + // substring from nonzero position with greater-than-full length + checkEvaluation(Substring(s, Literal(2, IntegerType), Literal(100, IntegerType)), "xample", row) + + // zero-length substring (within string bounds) + checkEvaluation(Substring(s, Literal(0, IntegerType), Literal(0, IntegerType)), "", row) + + // zero-length substring (beyond string bounds) + checkEvaluation(Substring(s, Literal(100, IntegerType), Literal(4, IntegerType)), "", row) + + // substring(null, _, _) -> null + checkEvaluation(Substring(s, Literal(100, IntegerType), Literal(4, IntegerType)), null, new GenericRow(Array[Any](null))) + + // substring(_, null, _) -> null + checkEvaluation(Substring(s, Literal(null, IntegerType), Literal(4, IntegerType)), null, row) + + // substring(_, _, null) -> null + checkEvaluation(Substring(s, Literal(100, IntegerType), Literal(null, IntegerType)), null, row) + + // 2-arg substring from zero position + checkEvaluation(Substring(s, Literal(0, IntegerType), Literal(Integer.MAX_VALUE, IntegerType)), "example", row) + checkEvaluation(Substring(s, Literal(1, IntegerType), Literal(Integer.MAX_VALUE, IntegerType)), "example", row) + + // 2-arg substring from nonzero position + checkEvaluation(Substring(s, Literal(2, IntegerType), Literal(Integer.MAX_VALUE, IntegerType)), "xample", row) + + val s_notNull = 'a.string.notNull.at(0) + + assert(Substring(s, Literal(0, IntegerType), Literal(2, IntegerType)).nullable === true) + assert(Substring(s_notNull, Literal(0, IntegerType), Literal(2, IntegerType)).nullable === false) + assert(Substring(s_notNull, Literal(null, IntegerType), Literal(2, IntegerType)).nullable === true) + assert(Substring(s_notNull, Literal(0, IntegerType), Literal(null, IntegerType)).nullable === true) + + checkEvaluation(s.substr(0, 2), "ex", row) + checkEvaluation(s.substr(0), "example", row) + checkEvaluation(s.substring(0, 2), "ex", row) + checkEvaluation(s.substring(0), "example", row) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index 0ff82064012a8..0a27cce337482 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -83,7 +83,7 @@ class ConstantFoldingSuite extends PlanTest { Literal(10) as Symbol("2*3+4"), Literal(14) as Symbol("2*(3+4)")) .where(Literal(true)) - .groupBy(Literal(3))(Literal(3) as Symbol("9/3")) + .groupBy(Literal(3.0))(Literal(3.0) as Symbol("9/3")) .analyze comparePlans(optimized, correctAnswer) @@ -201,7 +201,15 @@ class ConstantFoldingSuite extends PlanTest { Like(Literal(null, StringType), "abc") as 'c13, Like("abc", Literal(null, StringType)) as 'c14, - Upper(Literal(null, StringType)) 
as 'c15) + Upper(Literal(null, StringType)) as 'c15, + + Substring(Literal(null, StringType), 0, 1) as 'c16, + Substring("abc", Literal(null, IntegerType), 1) as 'c17, + Substring("abc", 0, Literal(null, IntegerType)) as 'c18, + + Contains(Literal(null, StringType), "abc") as 'c19, + Contains("abc", Literal(null, StringType)) as 'c20 + ) val optimized = Optimize(originalQuery.analyze) @@ -228,8 +236,15 @@ class ConstantFoldingSuite extends PlanTest { Literal(null, BooleanType) as 'c13, Literal(null, BooleanType) as 'c14, - Literal(null, StringType) as 'c15) - .analyze + Literal(null, StringType) as 'c15, + + Literal(null, StringType) as 'c16, + Literal(null, StringType) as 'c17, + Literal(null, StringType) as 'c18, + + Literal(null, BooleanType) as 'c19, + Literal(null, BooleanType) as 'c20 + ).analyze comparePlans(optimized, correctAnswer) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala new file mode 100644 index 0000000000000..b10577c8001e2 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.rules._ + +/* Implicit conversions */ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ + +class LikeSimplificationSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Like Simplification", Once, + LikeSimplification) :: Nil + } + + val testRelation = LocalRelation('a.string) + + test("simplify Like into StartsWith") { + val originalQuery = + testRelation + .where(('a like "abc%") || ('a like "abc\\%")) + + val optimized = Optimize(originalQuery.analyze) + val correctAnswer = testRelation + .where(StartsWith('a, "abc") || ('a like "abc\\%")) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("simplify Like into EndsWith") { + val originalQuery = + testRelation + .where('a like "%xyz") + + val optimized = Optimize(originalQuery.analyze) + val correctAnswer = testRelation + .where(EndsWith('a, "xyz")) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("simplify Like into Contains") { + val originalQuery = + testRelation + .where(('a like "%mn%") || ('a like "%mn\\%")) + + val optimized = Optimize(originalQuery.analyze) + val correctAnswer = testRelation + .where(Contains('a, "mn") || ('a like "%mn\\%")) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("simplify Like into EqualTo") { + val originalQuery = + testRelation + .where(('a like "") || ('a like "abc")) + + val optimized = Optimize(originalQuery.analyze) + val correctAnswer = testRelation + .where(('a === "") || ('a === "abc")) + .analyze + + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 0c95b668545f4..31d27bb4f0571 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -380,32 +380,32 @@ class SchemaRDD( val fields = structType.fields.map(field => (field.name, field.dataType)) val map: JMap[String, Any] = new java.util.HashMap row.zip(fields).foreach { - case (obj, (name, dataType)) => + case (obj, (attrName, dataType)) => dataType match { - case struct: StructType => map.put(name, rowToMap(obj.asInstanceOf[Row], struct)) + case struct: StructType => map.put(attrName, rowToMap(obj.asInstanceOf[Row], struct)) case array @ ArrayType(struct: StructType) => val arrayValues = obj match { case seq: Seq[Any] => seq.map(element => rowToMap(element.asInstanceOf[Row], struct)).asJava - case list: JList[Any] => + case list: JList[_] => list.map(element => rowToMap(element.asInstanceOf[Row], struct)) - case set: JSet[Any] => + case set: JSet[_] => set.map(element => rowToMap(element.asInstanceOf[Row], struct)) - case array if array != null && array.getClass.isArray => - array.asInstanceOf[Array[Any]].map { + case arr if arr != null && arr.getClass.isArray => + arr.asInstanceOf[Array[Any]].map { element => rowToMap(element.asInstanceOf[Row], struct) } case other => other } - map.put(name, arrayValues) + map.put(attrName, arrayValues) case array: ArrayType => { val arrayValues = obj match { case seq: Seq[Any] => seq.asJava case other => other } - map.put(name, arrayValues) + map.put(attrName, arrayValues) } - case 
other => map.put(name, obj) + case other => map.put(attrName, obj) } } @@ -430,7 +430,7 @@ class SchemaRDD( * @group schema */ private def applySchema(rdd: RDD[Row]): SchemaRDD = { - new SchemaRDD(sqlContext, SparkLogicalPlan(ExistingRdd(logicalPlan.output, rdd))) + new SchemaRDD(sqlContext, SparkLogicalPlan(ExistingRdd(queryExecution.analyzed.output, rdd))) } // ======================================================================= diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala index aff6ffe9f3478..8fbf13b8b0150 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.api.java +import java.util.{List => JList} + import org.apache.spark.Partitioner import org.apache.spark.api.java.{JavaRDDLike, JavaRDD} import org.apache.spark.api.java.function.{Function => JFunction} @@ -96,6 +98,20 @@ class JavaSchemaRDD( this } + // Overridden actions from JavaRDDLike. + + override def collect(): JList[Row] = { + import scala.collection.JavaConversions._ + val arr: java.util.Collection[Row] = baseSchemaRDD.collect().toSeq.map(new Row(_)) + new java.util.ArrayList(arr) + } + + override def take(num: Int): JList[Row] = { + import scala.collection.JavaConversions._ + val arr: java.util.Collection[Row] = baseSchemaRDD.take(num).toSeq.map(new Row(_)) + new java.util.ArrayList(arr) + } + // Transformations (return a new RDD) /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala index 3c39e1d350fa8..42a5a9a84f362 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala @@ -90,6 +90,9 @@ private[sql] class FloatColumnAccessor(buffer: ByteBuffer) private[sql] class StringColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, STRING) +private[sql] class TimestampColumnAccessor(buffer: ByteBuffer) + extends NativeColumnAccessor(buffer, TIMESTAMP) + private[sql] class BinaryColumnAccessor(buffer: ByteBuffer) extends BasicColumnAccessor[BinaryType.type, Array[Byte]](buffer, BINARY) with NullableColumnAccessor @@ -105,16 +108,17 @@ private[sql] object ColumnAccessor { val columnTypeId = dup.getInt() columnTypeId match { - case INT.typeId => new IntColumnAccessor(dup) - case LONG.typeId => new LongColumnAccessor(dup) - case FLOAT.typeId => new FloatColumnAccessor(dup) - case DOUBLE.typeId => new DoubleColumnAccessor(dup) - case BOOLEAN.typeId => new BooleanColumnAccessor(dup) - case BYTE.typeId => new ByteColumnAccessor(dup) - case SHORT.typeId => new ShortColumnAccessor(dup) - case STRING.typeId => new StringColumnAccessor(dup) - case BINARY.typeId => new BinaryColumnAccessor(dup) - case GENERIC.typeId => new GenericColumnAccessor(dup) + case INT.typeId => new IntColumnAccessor(dup) + case LONG.typeId => new LongColumnAccessor(dup) + case FLOAT.typeId => new FloatColumnAccessor(dup) + case DOUBLE.typeId => new DoubleColumnAccessor(dup) + case BOOLEAN.typeId => new BooleanColumnAccessor(dup) + case BYTE.typeId => new ByteColumnAccessor(dup) + case SHORT.typeId => new ShortColumnAccessor(dup) + case STRING.typeId => new StringColumnAccessor(dup) + case TIMESTAMP.typeId => new TimestampColumnAccessor(dup) + case 
BINARY.typeId => new BinaryColumnAccessor(dup) + case GENERIC.typeId => new GenericColumnAccessor(dup) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala index 4be048cd742d6..74f5630fbddf1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala @@ -109,6 +109,9 @@ private[sql] class FloatColumnBuilder extends NativeColumnBuilder(new FloatColum private[sql] class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, STRING) +private[sql] class TimestampColumnBuilder + extends NativeColumnBuilder(new TimestampColumnStats, TIMESTAMP) + private[sql] class BinaryColumnBuilder extends ComplexColumnBuilder(BINARY) // TODO (lian) Add support for array, struct and map diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index 95602d321dc6f..6502110e903fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -344,21 +344,52 @@ private[sql] class StringColumnStats extends BasicColumnStats(STRING) { } override def contains(row: Row, ordinal: Int) = { - !(upperBound eq null) && { + (upperBound ne null) && { val field = columnType.getField(row, ordinal) lowerBound.compareTo(field) <= 0 && field.compareTo(upperBound) <= 0 } } override def isAbove(row: Row, ordinal: Int) = { - !(upperBound eq null) && { + (upperBound ne null) && { val field = columnType.getField(row, ordinal) field.compareTo(upperBound) < 0 } } override def isBelow(row: Row, ordinal: Int) = { - !(lowerBound eq null) && { + (lowerBound ne null) && { + val field = columnType.getField(row, ordinal) + lowerBound.compareTo(field) < 0 + } + } +} + +private[sql] class TimestampColumnStats extends BasicColumnStats(TIMESTAMP) { + override def initialBounds = (null, null) + + override def gatherStats(row: Row, ordinal: Int) { + val field = columnType.getField(row, ordinal) + if ((upperBound eq null) || field.compareTo(upperBound) > 0) _upper = field + if ((lowerBound eq null) || field.compareTo(lowerBound) < 0) _lower = field + } + + override def contains(row: Row, ordinal: Int) = { + (upperBound ne null) && { + val field = columnType.getField(row, ordinal) + lowerBound.compareTo(field) <= 0 && field.compareTo(upperBound) <= 0 + } + } + + override def isAbove(row: Row, ordinal: Int) = { + (lowerBound ne null) && { + val field = columnType.getField(row, ordinal) + field.compareTo(upperBound) < 0 + } + } + + override def isBelow(row: Row, ordinal: Int) = { + (lowerBound ne null) && { val field = columnType.getField(row, ordinal) lowerBound.compareTo(field) < 0 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index 4cd52d8288137..794bc60d0e315 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -21,6 +21,8 @@ import java.nio.ByteBuffer import scala.reflect.runtime.universe.TypeTag +import java.sql.Timestamp + import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.MutableRow import org.apache.spark.sql.catalyst.types._ @@ -221,6 +223,26 @@ private[sql] 
object STRING extends NativeColumnType(StringType, 7, 8) { override def getField(row: Row, ordinal: Int) = row.getString(ordinal) } +private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 8, 12) { + override def extract(buffer: ByteBuffer) = { + val timestamp = new Timestamp(buffer.getLong()) + timestamp.setNanos(buffer.getInt()) + timestamp + } + + override def append(v: Timestamp, buffer: ByteBuffer) { + buffer.putLong(v.getTime).putInt(v.getNanos) + } + + override def getField(row: Row, ordinal: Int) = { + row(ordinal).asInstanceOf[Timestamp] + } + + override def setField(row: MutableRow, ordinal: Int, value: Timestamp) { + row(ordinal) = value + } +} + private[sql] sealed abstract class ByteArrayColumnType[T <: DataType]( typeId: Int, defaultSize: Int) @@ -240,7 +262,7 @@ private[sql] sealed abstract class ByteArrayColumnType[T <: DataType]( } } -private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](8, 16) { +private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](9, 16) { override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]) { row(ordinal) = value } @@ -251,7 +273,7 @@ private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](8, 16) { // Used to process generic objects (all types other than those listed above). Objects should be // serialized first before appending to the column `ByteBuffer`, and is also extracted as serialized // byte array. -private[sql] object GENERIC extends ByteArrayColumnType[DataType](9, 16) { +private[sql] object GENERIC extends ByteArrayColumnType[DataType](10, 16) { override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]) { row(ordinal) = SparkSqlSerializer.deserialize[Any](value) } @@ -262,16 +284,17 @@ private[sql] object GENERIC extends ByteArrayColumnType[DataType](9, 16) { private[sql] object ColumnType { def apply(dataType: DataType): ColumnType[_, _] = { dataType match { - case IntegerType => INT - case LongType => LONG - case FloatType => FLOAT - case DoubleType => DOUBLE - case BooleanType => BOOLEAN - case ByteType => BYTE - case ShortType => SHORT - case StringType => STRING - case BinaryType => BINARY - case _ => GENERIC + case IntegerType => INT + case LongType => LONG + case FloatType => FLOAT + case DoubleType => DOUBLE + case BooleanType => BOOLEAN + case ByteType => BYTE + case ShortType => SHORT + case StringType => STRING + case BinaryType => BINARY + case TimestampType => TIMESTAMP + case _ => GENERIC } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index ff7f664d8b529..88901debbb4e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -96,7 +96,11 @@ private[sql] case class InMemoryColumnarTableScan( new Iterator[Row] { // Find the ordinals of the requested columns. If none are requested, use the first. 
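        // Requested attributes are matched against relation.output by exprId (via indexWhere
        // below) rather than by full attribute equality.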
val requestedColumns = - if (attributes.isEmpty) Seq(0) else attributes.map(relation.output.indexOf(_)) + if (attributes.isEmpty) { + Seq(0) + } else { + attributes.map(a => relation.output.indexWhere(_.exprId == a.exprId)) + } val columnAccessors = requestedColumns.map(columnBuffers(_)).map(ColumnAccessor(_)) val nextRow = new GenericMutableRow(columnAccessors.length) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala index 34b355e906695..34654447a5f4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala @@ -24,10 +24,10 @@ import scala.reflect.ClassTag import com.clearspring.analytics.stream.cardinality.HyperLogLog import com.esotericsoftware.kryo.io.{Input, Output} import com.esotericsoftware.kryo.{Serializer, Kryo} -import com.twitter.chill.AllScalaRegistrar +import com.twitter.chill.{AllScalaRegistrar, ResourcePool} import org.apache.spark.{SparkEnv, SparkConf} -import org.apache.spark.serializer.KryoSerializer +import org.apache.spark.serializer.{SerializerInstance, KryoSerializer} import org.apache.spark.util.MutablePair import org.apache.spark.util.Utils @@ -48,22 +48,41 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co } } -private[sql] object SparkSqlSerializer { - // TODO (lian) Using KryoSerializer here is workaround, needs further investigation - // Using SparkSqlSerializer here makes BasicQuerySuite to fail because of Kryo serialization - // related error. - @transient lazy val ser: KryoSerializer = { +private[execution] class KryoResourcePool(size: Int) + extends ResourcePool[SerializerInstance](size) { + + val ser: KryoSerializer = { val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) + // TODO (lian) Using KryoSerializer here is workaround, needs further investigation + // Using SparkSqlSerializer here makes BasicQuerySuite to fail because of Kryo serialization + // related error. 
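+    // This single KryoSerializer backs the pooled SerializerInstances created by newInstance();
+    // SparkSqlSerializer borrows and releases them around each (de)serialization call.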
new KryoSerializer(sparkConf) } - def serialize[T: ClassTag](o: T): Array[Byte] = { - ser.newInstance().serialize(o).array() - } + def newInstance() = ser.newInstance() +} - def deserialize[T: ClassTag](bytes: Array[Byte]): T = { - ser.newInstance().deserialize[T](ByteBuffer.wrap(bytes)) +private[sql] object SparkSqlSerializer { + @transient lazy val resourcePool = new KryoResourcePool(30) + + private[this] def acquireRelease[O](fn: SerializerInstance => O): O = { + val kryo = resourcePool.borrow + try { + fn(kryo) + } finally { + resourcePool.release(kryo) + } } + + def serialize[T: ClassTag](o: T): Array[Byte] = + acquireRelease { k => + k.serialize(o).array() + } + + def deserialize[T: ClassTag](bytes: Array[Byte]): T = + acquireRelease { k => + k.deserialize[T](ByteBuffer.wrap(bytes)) + } } private[sql] class BigDecimalSerializer extends Serializer[BigDecimal] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index f6cbca96483e2..b48c70ee73a27 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.json -import scala.collection.JavaConversions._ +import scala.collection.convert.Wrappers.{JMapWrapper, JListWrapper} import scala.math.BigDecimal import com.fasterxml.jackson.databind.ObjectMapper @@ -204,18 +204,18 @@ private[sql] object JsonRDD extends Logging { case (key, value) => (s"`$key`", value) }.toSet keyValuePairs.flatMap { - case (key: String, struct: Map[String, Any]) => { - // The value associted with the key is an JSON object. - allKeysWithValueTypes(struct).map { + case (key: String, struct: Map[_, _]) => { + // The value associated with the key is an JSON object. + allKeysWithValueTypes(struct.asInstanceOf[Map[String, Any]]).map { case (k, dataType) => (s"$key.$k", dataType) } ++ Set((key, StructType(Nil))) } - case (key: String, array: List[Any]) => { - // The value associted with the key is an array. + case (key: String, array: Seq[_]) => { + // The value associated with the key is an array. typeOfArray(array) match { case ArrayType(StructType(Nil)) => { // The elements of this arrays are structs. - array.asInstanceOf[List[Map[String, Any]]].flatMap { + array.asInstanceOf[Seq[Map[String, Any]]].flatMap { element => allKeysWithValueTypes(element) }.map { case (k, dataType) => (s"$key.$k", dataType) @@ -229,19 +229,19 @@ private[sql] object JsonRDD extends Logging { } /** - * Converts a Java Map/List to a Scala Map/List. + * Converts a Java Map/List to a Scala Map/Seq. * We do not use Jackson's scala module at here because * DefaultScalaModule in jackson-module-scala will make * the parsing very slow. */ private def scalafy(obj: Any): Any = obj match { - case map: java.util.Map[String, Object] => + case map: java.util.Map[_, _] => // .map(identity) is used as a workaround of non-serializable Map // generated by .mapValues. 
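      // (mapValues returns a lazy, non-serializable view; map(identity) forces it into a
      // plain materialized Map.)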
// This issue is documented at https://issues.scala-lang.org/browse/SI-7005 - map.toMap.mapValues(scalafy).map(identity) - case list: java.util.List[Object] => - list.toList.map(scalafy) + JMapWrapper(map).mapValues(scalafy).map(identity) + case list: java.util.List[_] => + JListWrapper(list).map(scalafy) case atom => atom } @@ -320,8 +320,8 @@ private[sql] object JsonRDD extends Logging { private def toString(value: Any): String = { value match { - case value: Map[String, Any] => toJsonObjectString(value) - case value: Seq[Any] => toJsonArrayString(value) + case value: Map[_, _] => toJsonObjectString(value.asInstanceOf[Map[String, Any]]) + case value: Seq[_] => toJsonArrayString(value) case value => Option(value).map(_.toString).orNull } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index ade823b51c9cd..ea74320d06c86 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -17,27 +17,34 @@ package org.apache.spark.sql.parquet +import scala.collection.JavaConversions._ +import scala.collection.mutable +import scala.util.Try + import java.io.IOException +import java.lang.{Long => JLong} import java.text.SimpleDateFormat -import java.util.Date +import java.util.{Date, List => JList} import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} -import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat, FileOutputCommitter} +import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat} +import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter -import parquet.hadoop.{ParquetRecordReader, ParquetInputFormat, ParquetOutputFormat} -import parquet.hadoop.api.ReadSupport +import parquet.hadoop._ +import parquet.hadoop.api.{InitContext, ReadSupport} +import parquet.hadoop.metadata.GlobalMetaData import parquet.hadoop.util.ContextUtil -import parquet.io.InvalidRecordException +import parquet.io.ParquetDecodingException import parquet.schema.MessageType -import org.apache.spark.{Logging, SerializableWritable, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row} import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode} +import org.apache.spark.{Logging, SerializableWritable, TaskContext} /** * Parquet table scan operator. 
Imports the file that backs the given @@ -55,16 +62,14 @@ case class ParquetTableScan( override def execute(): RDD[Row] = { val sc = sqlContext.sparkContext val job = new Job(sc.hadoopConfiguration) - ParquetInputFormat.setReadSupportClass( - job, - classOf[org.apache.spark.sql.parquet.RowReadSupport]) + ParquetInputFormat.setReadSupportClass(job, classOf[RowReadSupport]) + val conf: Configuration = ContextUtil.getConfiguration(job) - val fileList = FileSystemHelper.listFiles(relation.path, conf) - // add all paths in the directory but skip "hidden" ones such - // as "_SUCCESS" and "_metadata" - for (path <- fileList if !path.getName.startsWith("_")) { - NewFileInputFormat.addInputPath(job, path) + val qualifiedPath = { + val path = new Path(relation.path) + path.getFileSystem(conf).makeQualified(path) } + NewFileInputFormat.addInputPath(job, qualifiedPath) // Store both requested and original schema in `Configuration` conf.set( @@ -87,7 +92,7 @@ case class ParquetTableScan( sc.newAPIHadoopRDD( conf, - classOf[org.apache.spark.sql.parquet.FilteringParquetRowInputFormat], + classOf[FilteringParquetRowInputFormat], classOf[Void], classOf[Row]) .map(_._2) @@ -122,14 +127,7 @@ case class ParquetTableScan( private def validateProjection(projection: Seq[Attribute]): Boolean = { val original: MessageType = relation.parquetSchema val candidate: MessageType = ParquetTypesConverter.convertFromAttributes(projection) - try { - original.checkContains(candidate) - true - } catch { - case e: InvalidRecordException => { - false - } - } + Try(original.checkContains(candidate)).isSuccess } } @@ -302,6 +300,11 @@ private[parquet] class AppendingParquetOutputFormat(offset: Int) */ private[parquet] class FilteringParquetRowInputFormat extends parquet.hadoop.ParquetInputFormat[Row] with Logging { + + private var footers: JList[Footer] = _ + + private var fileStatuses= Map.empty[Path, FileStatus] + override def createRecordReader( inputSplit: InputSplit, taskAttemptContext: TaskAttemptContext): RecordReader[Void, Row] = { @@ -318,6 +321,70 @@ private[parquet] class FilteringParquetRowInputFormat new ParquetRecordReader[Row](readSupport) } } + + override def getFooters(jobContext: JobContext): JList[Footer] = { + if (footers eq null) { + val statuses = listStatus(jobContext) + fileStatuses = statuses.map(file => file.getPath -> file).toMap + footers = getFooters(ContextUtil.getConfiguration(jobContext), statuses) + } + + footers + } + + // TODO Remove this method and related code once PARQUET-16 is fixed + // This method together with the `getFooters` method and the `fileStatuses` field are just used + // to mimic this PR: https://github.com/apache/incubator-parquet-mr/pull/17 + override def getSplits( + configuration: Configuration, + footers: JList[Footer]): JList[ParquetInputSplit] = { + + val maxSplitSize: JLong = configuration.getLong("mapred.max.split.size", Long.MaxValue) + val minSplitSize: JLong = + Math.max(getFormatMinSplitSize(), configuration.getLong("mapred.min.split.size", 0L)) + if (maxSplitSize < 0 || minSplitSize < 0) { + throw new ParquetDecodingException( + s"maxSplitSize or minSplitSie should not be negative: maxSplitSize = $maxSplitSize;" + + s" minSplitSize = $minSplitSize") + } + + val getGlobalMetaData = + classOf[ParquetFileWriter].getDeclaredMethod("getGlobalMetaData", classOf[JList[Footer]]) + getGlobalMetaData.setAccessible(true) + val globalMetaData = getGlobalMetaData.invoke(null, footers).asInstanceOf[GlobalMetaData] + + val readContext = getReadSupport(configuration).init( + new 
InitContext(configuration, + globalMetaData.getKeyValueMetaData(), + globalMetaData.getSchema())) + + val generateSplits = + classOf[ParquetInputFormat[_]].getDeclaredMethods.find(_.getName == "generateSplits").get + generateSplits.setAccessible(true) + + val splits = mutable.ArrayBuffer.empty[ParquetInputSplit] + for (footer <- footers) { + val fs = footer.getFile.getFileSystem(configuration) + val file = footer.getFile + val fileStatus = fileStatuses.getOrElse(file, fs.getFileStatus(file)) + val parquetMetaData = footer.getParquetMetadata + val blocks = parquetMetaData.getBlocks + val fileBlockLocations = fs.getFileBlockLocations(fileStatus, 0, fileStatus.getLen) + splits.addAll( + generateSplits.invoke( + null, + blocks, + fileBlockLocations, + fileStatus, + parquetMetaData.getFileMetaData, + readContext.getRequestedSchema.toString, + readContext.getReadSupportMetadata, + minSplitSize, + maxSplitSize).asInstanceOf[JList[ParquetInputSplit]]) + } + + splits + } } private[parquet] object FileSystemHelper { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index f1953a008a49b..39294a3f4bf5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -17,20 +17,19 @@ package org.apache.spark.sql.parquet -import org.apache.hadoop.conf.Configuration +import java.util.{HashMap => JHashMap} +import org.apache.hadoop.conf.Configuration import parquet.column.ParquetProperties import parquet.hadoop.ParquetOutputFormat import parquet.hadoop.api.ReadSupport.ReadContext import parquet.hadoop.api.{ReadSupport, WriteSupport} import parquet.io.api._ -import parquet.schema.{MessageType, MessageTypeParser} +import parquet.schema.MessageType import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions.{Attribute, Row} import org.apache.spark.sql.catalyst.types._ -import org.apache.spark.sql.execution.SparkSqlSerializer -import com.google.common.io.BaseEncoding /** * A `parquet.io.api.RecordMaterializer` for Rows. 
@@ -93,8 +92,8 @@ private[parquet] class RowReadSupport extends ReadSupport[Row] with Logging { configuration: Configuration, keyValueMetaData: java.util.Map[String, String], fileSchema: MessageType): ReadContext = { - var parquetSchema: MessageType = fileSchema - var metadata: java.util.Map[String, String] = new java.util.HashMap[String, String]() + var parquetSchema = fileSchema + val metadata = new JHashMap[String, String]() val requestedAttributes = RowReadSupport.getRequestedSchema(configuration) if (requestedAttributes != null) { @@ -109,7 +108,7 @@ private[parquet] class RowReadSupport extends ReadSupport[Row] with Logging { metadata.put(RowReadSupport.SPARK_METADATA_KEY, origAttributesStr) } - return new ReadSupport.ReadContext(parquetSchema, metadata) + new ReadSupport.ReadContext(parquetSchema, metadata) } } @@ -132,13 +131,17 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { private[parquet] var attributes: Seq[Attribute] = null override def init(configuration: Configuration): WriteSupport.WriteContext = { - attributes = if (attributes == null) RowWriteSupport.getSchema(configuration) else attributes - + val origAttributesStr: String = configuration.get(RowWriteSupport.SPARK_ROW_SCHEMA) + val metadata = new JHashMap[String, String]() + metadata.put(RowReadSupport.SPARK_METADATA_KEY, origAttributesStr) + + if (attributes == null) { + attributes = ParquetTypesConverter.convertFromString(origAttributesStr) + } + log.debug(s"write support initialized for requested schema $attributes") ParquetRelation.enableLogForwarding() - new WriteSupport.WriteContext( - ParquetTypesConverter.convertFromAttributes(attributes), - new java.util.HashMap[java.lang.String, java.lang.String]()) + new WriteSupport.WriteContext(ParquetTypesConverter.convertFromAttributes(attributes), metadata) } override def prepareForWrite(recordConsumer: RecordConsumer): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index 7f6ad908f78ed..58370b955a5ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -22,6 +22,7 @@ import java.io.IOException import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.mapreduce.Job +import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter import parquet.hadoop.{ParquetFileReader, Footer, ParquetFileWriter} import parquet.hadoop.metadata.{ParquetMetadata, FileMetaData} @@ -367,20 +368,24 @@ private[parquet] object ParquetTypesConverter extends Logging { s"Expected $path for be a directory with Parquet files/metadata") } ParquetRelation.enableLogForwarding() - val metadataPath = new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE) - // if this is a new table that was just created we will find only the metadata file - if (fs.exists(metadataPath) && fs.isFile(metadataPath)) { - ParquetFileReader.readFooter(conf, metadataPath) - } else { - // there may be one or more Parquet files in the given directory - val footers = ParquetFileReader.readFooters(conf, fs.getFileStatus(path)) - // TODO: for now we assume that all footers (if there is more than one) have identical - // metadata; we may want to add a check here at some point - if (footers.size() == 0) { - throw new IllegalArgumentException(s"Could not find Parquet metadata at path $path") - } - 
footers(0).getParquetMetadata + + val children = fs.listStatus(path).filterNot { + _.getPath.getName == FileOutputCommitter.SUCCEEDED_FILE_NAME } + + // NOTE (lian): Parquet "_metadata" file can be very slow if the file consists of lots of row + // groups. Since Parquet schema is replicated among all row groups, we only need to touch a + // single row group to read schema related metadata. Notice that we are making assumptions that + // all data in a single Parquet file have the same schema, which is normally true. + children + // Try any non-"_metadata" file first... + .find(_.getPath.getName != ParquetFileWriter.PARQUET_METADATA_FILE) + // ... and fallback to "_metadata" if no such file exists (which implies the Parquet file is + // empty, thus normally the "_metadata" file is expected to be fairly small). + .orElse(children.find(_.getPath.getName == ParquetFileWriter.PARQUET_METADATA_FILE)) + .map(ParquetFileReader.readFooter(conf, _)) + .getOrElse( + throw new IllegalArgumentException(s"Could not find Parquet metadata at path $path")) } /** @@ -403,7 +408,7 @@ private[parquet] object ParquetTypesConverter extends Logging { } else { val attributes = convertToAttributes( readMetaData(origPath, conf).getFileMetaData.getSchema) - log.warn(s"Falling back to schema conversion from Parquet types; result: $attributes") + log.info(s"Falling back to schema conversion from Parquet types; result: $attributes") attributes } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala index 68dae58728a2a..1a6a6c17473a3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.test._ /* Implicits */ @@ -33,17 +32,23 @@ class DslQuerySuite extends QueryTest { testData.collect().toSeq) } + test("repartition") { + checkAnswer( + testData.select('key).repartition(10).select('key), + testData.select('key).collect().toSeq) + } + test("agg") { checkAnswer( - testData2.groupBy('a)('a, Sum('b)), + testData2.groupBy('a)('a, sum('b)), Seq((1,3),(2,3),(3,3)) ) checkAnswer( - testData2.groupBy('a)('a, Sum('b) as 'totB).aggregate(Sum('totB)), + testData2.groupBy('a)('a, sum('b) as 'totB).aggregate(sum('totB)), 9 ) checkAnswer( - testData2.aggregate(Sum('b)), + testData2.aggregate(sum('b)), 9 ) } @@ -98,19 +103,19 @@ class DslQuerySuite extends QueryTest { Seq((3,1), (3,2), (2,1), (2,2), (1,1), (1,2))) checkAnswer( - arrayData.orderBy(GetItem('data, 0).asc), + arrayData.orderBy('data.getItem(0).asc), arrayData.collect().sortBy(_.data(0)).toSeq) checkAnswer( - arrayData.orderBy(GetItem('data, 0).desc), + arrayData.orderBy('data.getItem(0).desc), arrayData.collect().sortBy(_.data(0)).reverse.toSeq) checkAnswer( - mapData.orderBy(GetItem('data, 1).asc), + mapData.orderBy('data.getItem(1).asc), mapData.collect().sortBy(_.data(1)).toSeq) checkAnswer( - mapData.orderBy(GetItem('data, 1).desc), + mapData.orderBy('data.getItem(1).desc), mapData.collect().sortBy(_.data(1)).reverse.toSeq) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 0743cfe8cff0f..6736189c96d4b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -36,6 +36,21 @@ class SQLQuerySuite extends QueryTest { "test") } + test("SPARK-2407 Added Parser of SQL SUBSTR()") { + checkAnswer( + sql("SELECT substr(tableName, 1, 2) FROM tableName"), + "te") + checkAnswer( + sql("SELECT substr(tableName, 3) FROM tableName"), + "st") + checkAnswer( + sql("SELECT substring(tableName, 1, 2) FROM tableName"), + "te") + checkAnswer( + sql("SELECT substring(tableName, 3) FROM tableName"), + "st") + } + test("index into array") { checkAnswer( sql("SELECT data, data[0], data[0] + data[1], data[0 + 1] FROM arrayData"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index 6f0d46d816266..5f61fb5e16ea3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -22,14 +22,15 @@ import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.types._ class ColumnStatsSuite extends FunSuite { - testColumnStats(classOf[BooleanColumnStats], BOOLEAN) - testColumnStats(classOf[ByteColumnStats], BYTE) - testColumnStats(classOf[ShortColumnStats], SHORT) - testColumnStats(classOf[IntColumnStats], INT) - testColumnStats(classOf[LongColumnStats], LONG) - testColumnStats(classOf[FloatColumnStats], FLOAT) - testColumnStats(classOf[DoubleColumnStats], DOUBLE) - testColumnStats(classOf[StringColumnStats], STRING) + testColumnStats(classOf[BooleanColumnStats], BOOLEAN) + testColumnStats(classOf[ByteColumnStats], BYTE) + testColumnStats(classOf[ShortColumnStats], SHORT) + testColumnStats(classOf[IntColumnStats], INT) + testColumnStats(classOf[LongColumnStats], LONG) + testColumnStats(classOf[FloatColumnStats], FLOAT) + testColumnStats(classOf[DoubleColumnStats], DOUBLE) + testColumnStats(classOf[StringColumnStats], STRING) + testColumnStats(classOf[TimestampColumnStats], TIMESTAMP) def testColumnStats[T <: NativeType, U <: NativeColumnStats[T]]( columnStatsClass: Class[U], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index 314b7d317ed75..829342215e691 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer +import java.sql.Timestamp import org.scalatest.FunSuite @@ -32,7 +33,7 @@ class ColumnTypeSuite extends FunSuite with Logging { test("defaultSize") { val checks = Map( INT -> 4, SHORT -> 2, LONG -> 8, BYTE -> 1, DOUBLE -> 8, FLOAT -> 4, - BOOLEAN -> 1, STRING -> 8, BINARY -> 16, GENERIC -> 16) + BOOLEAN -> 1, STRING -> 8, TIMESTAMP -> 12, BINARY -> 16, GENERIC -> 16) checks.foreach { case (columnType, expectedSize) => assertResult(expectedSize, s"Wrong defaultSize for $columnType") { @@ -52,14 +53,15 @@ class ColumnTypeSuite extends FunSuite with Logging { } } - checkActualSize(INT, Int.MaxValue, 4) - checkActualSize(SHORT, Short.MaxValue, 2) - checkActualSize(LONG, Long.MaxValue, 8) - checkActualSize(BYTE, Byte.MaxValue, 1) - checkActualSize(DOUBLE, Double.MaxValue, 8) - checkActualSize(FLOAT, Float.MaxValue, 4) - checkActualSize(BOOLEAN, true, 1) - checkActualSize(STRING, "hello", 4 + "hello".getBytes("utf-8").length) + checkActualSize(INT, Int.MaxValue, 
4) + checkActualSize(SHORT, Short.MaxValue, 2) + checkActualSize(LONG, Long.MaxValue, 8) + checkActualSize(BYTE, Byte.MaxValue, 1) + checkActualSize(DOUBLE, Double.MaxValue, 8) + checkActualSize(FLOAT, Float.MaxValue, 4) + checkActualSize(BOOLEAN, true, 1) + checkActualSize(STRING, "hello", 4 + "hello".getBytes("utf-8").length) + checkActualSize(TIMESTAMP, new Timestamp(0L), 12) val binary = Array.fill[Byte](4)(0: Byte) checkActualSize(BINARY, binary, 4 + 4) @@ -188,17 +190,7 @@ class ColumnTypeSuite extends FunSuite with Logging { } private def hexDump(value: Any): String = { - if (value.isInstanceOf[String]) { - val sb = new StringBuilder() - for (ch <- value.asInstanceOf[String].toCharArray) { - sb.append(Integer.toHexString(ch & 0xffff)).append(' ') - } - if (! sb.isEmpty) sb.setLength(sb.length - 1) - sb.toString() - } else { - // for now .. - hexDump(value.toString) - } + value.toString.map(ch => Integer.toHexString(ch & 0xffff)).mkString(" ") } private def dumpBuffer(buff: ByteBuffer): Any = { @@ -207,7 +199,7 @@ class ColumnTypeSuite extends FunSuite with Logging { val b = buff.get() sb.append(Integer.toHexString(b & 0xff)).append(' ') } - if (! sb.isEmpty) sb.setLength(sb.length - 1) + if (sb.nonEmpty) sb.setLength(sb.length - 1) sb.toString() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala index 04bdc43d95328..38b04dd959f70 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.columnar import scala.collection.immutable.HashSet import scala.util.Random +import java.sql.Timestamp + import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.catalyst.types.{DataType, NativeType} @@ -39,15 +41,19 @@ object ColumnarTestUtils { } (columnType match { - case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte - case SHORT => (Random.nextInt(Short.MaxValue * 2) - Short.MaxValue).toShort - case INT => Random.nextInt() - case LONG => Random.nextLong() - case FLOAT => Random.nextFloat() - case DOUBLE => Random.nextDouble() - case STRING => Random.nextString(Random.nextInt(32)) - case BOOLEAN => Random.nextBoolean() - case BINARY => randomBytes(Random.nextInt(32)) + case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte + case SHORT => (Random.nextInt(Short.MaxValue * 2) - Short.MaxValue).toShort + case INT => Random.nextInt() + case LONG => Random.nextLong() + case FLOAT => Random.nextFloat() + case DOUBLE => Random.nextDouble() + case STRING => Random.nextString(Random.nextInt(32)) + case BOOLEAN => Random.nextBoolean() + case BINARY => randomBytes(Random.nextInt(32)) + case TIMESTAMP => + val timestamp = new Timestamp(Random.nextLong()) + timestamp.setNanos(Random.nextInt(999999999)) + timestamp case _ => // Using a random one-element map instead of an arbitrary object Map(Random.nextInt() -> Random.nextString(Random.nextInt(32))) @@ -96,5 +102,4 @@ object ColumnarTestUtils { (values, rows) } - } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala similarity index 94% rename from 
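The TIMESTAMP column type exercised above reports a default size of 12 bytes, and the tests feed it java.sql.Timestamp values carrying an explicit nanosecond component. One layout consistent with those tests is 8 bytes of epoch milliseconds followed by a 4-byte nanosecond field; the sketch below only illustrates that assumption and is not taken from the ColumnType implementation itself:

import java.nio.ByteBuffer
import java.sql.Timestamp

// Hypothetical 12-byte encoding: putLong writes the millisecond value,
// putInt carries the sub-second nanoseconds that Timestamp tracks separately.
def appendTimestamp(v: Timestamp, buffer: ByteBuffer): Unit =
  buffer.putLong(v.getTime).putInt(v.getNanos)

def extractTimestamp(buffer: ByteBuffer): Timestamp = {
  val timestamp = new Timestamp(buffer.getLong())
  timestamp.setNanos(buffer.getInt())
  timestamp
}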
sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala rename to sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index cdfc2d0c17384..c69e93ba2b9ba 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive.execution import java.io.File +import java.util.TimeZone import org.scalatest.BeforeAndAfter @@ -31,14 +32,20 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { lazy val hiveQueryDir = TestHive.getHiveFile("ql" + File.separator + "src" + File.separator + "test" + File.separator + "queries" + File.separator + "clientpositive") + var originalTimeZone: TimeZone = _ + def testCases = hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f) override def beforeAll() { TestHive.cacheTables = true + // Timezone is fixed to America/Los_Angeles for those timezone-sensitive tests (timestamp_*) + originalTimeZone = TimeZone.getDefault + TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) } override def afterAll() { TestHive.cacheTables = false + TimeZone.setDefault(originalTimeZone) } /** A list of tests deemed out of scope currently and thus completely disregarded. */ @@ -84,13 +91,16 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_java_method", "create_merge_compressed", + // DFS commands + "symlink_text_input_format", + // Weird DDL differences result in failures on jenkins. "create_like2", "create_view_translate", "partitions_json", - // Timezone specific test answers. - "udf_unix_timestamp", + // This test is totally fine except that it includes wrong queries and expects errors, but error + message formats in Hive and Spark SQL differ. Should work around this later. "udf_to_unix_timestamp", // Cant run without local map/reduce. @@ -186,7 +196,10 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // Hive returns the results of describe as plain text. Comments with multiple lines // introduce extra lines in the Hive results, which make the result comparison fail. - "describe_comment_indent" + "describe_comment_indent", + + // Limit clause without an ordering, which causes failure.
+ "orc_predicate_pushdown" ) /** @@ -278,7 +291,11 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "compute_stats_string", "compute_stats_table", "convert_enum_to_string", + "correlationoptimizer1", + "correlationoptimizer10", "correlationoptimizer11", + "correlationoptimizer13", + "correlationoptimizer14", "correlationoptimizer15", "correlationoptimizer2", "correlationoptimizer3", @@ -286,6 +303,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "correlationoptimizer6", "correlationoptimizer7", "correlationoptimizer8", + "correlationoptimizer9", "count", "cp_mj_rc", "create_insert_outputformat", @@ -296,6 +314,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "ct_case_insensitive", "database_location", "database_properties", + "decimal_1", "decimal_4", "decimal_join", "default_partition_name", @@ -304,6 +323,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "describe_formatted_view_partitioned", "diff_part_input_formats", "disable_file_format_check", + "disallow_incompatible_type_change_off", "drop_function", "drop_index", "drop_multi_partitions", @@ -359,8 +379,10 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "groupby_map_ppr", "groupby_multi_insert_common_distinct", "groupby_multi_single_reducer2", + "groupby_multi_single_reducer3", "groupby_mutli_insert_common_distinct", "groupby_neg_float", + "groupby_ppd", "groupby_ppr", "groupby_sort_10", "groupby_sort_2", @@ -372,6 +394,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "groupby_sort_8", "groupby_sort_9", "groupby_sort_test_1", + "having", + "having1", "implicit_cast1", "innerjoin", "inoutdriver", @@ -400,6 +424,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "input4", "input40", "input41", + "input49", "input4_cb_delim", "input6", "input7", @@ -481,6 +506,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "join_hive_626", "join_map_ppr", "join_nulls", + "join_nullsafe", "join_rc", "join_reorder2", "join_reorder3", @@ -491,6 +517,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "lateral_view_ppd", "leftsemijoin", "leftsemijoin_mr", + "limit_pushdown_negative", "lineage1", "literal_double", "literal_ints", @@ -598,6 +625,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "reduce_deduplicate", "reduce_deduplicate_exclude_gby", "reduce_deduplicate_exclude_join", + "reduce_deduplicate_extended", "reducesink_dedup", "rename_column", "router_join_ppr", @@ -646,7 +674,13 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "stats_publisher_error_1", "subq2", "tablename_with_select", + "timestamp_1", + "timestamp_2", + "timestamp_3", "timestamp_comparison", + "timestamp_lazy", + "timestamp_null", + "timestamp_udf", "touch", "transform_ppr1", "transform_ppr2", @@ -704,6 +738,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_double", "udf_E", "udf_elt", + "udf_equal", "udf_exp", "udf_field", "udf_find_in_set", @@ -786,6 +821,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_translate", "udf_trim", "udf_ucase", + "udf_unix_timestamp", "udf_upper", "udf_var_pop", "udf_var_samp", diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index f30ae28b81e06..1699ffe06ce15 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -102,6 
+102,36 @@ test + + + + hive + + + + org.codehaus.mojo + build-helper-maven-plugin + + + add-scala-test-sources + generate-test-sources + + add-test-source + + + + src/test/scala + compatibility/src/test/scala + + + + + + + + + + target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 7aedfcd74189b..201c85f3d501e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive import java.io.{BufferedReader, File, InputStreamReader, PrintStream} +import java.sql.Timestamp import java.util.{ArrayList => JArrayList} import scala.collection.JavaConversions._ @@ -28,6 +29,7 @@ import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.hive.serde2.io.TimestampWritable import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD @@ -251,7 +253,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { protected val primitiveTypes = Seq(StringType, IntegerType, LongType, DoubleType, FloatType, BooleanType, ByteType, - ShortType, DecimalType, TimestampType) + ShortType, DecimalType, TimestampType, BinaryType) protected def toHiveString(a: (Any, DataType)): String = a match { case (struct: Row, StructType(fields)) => @@ -266,6 +268,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) }.toSeq.sorted.mkString("{", ",", "}") case (null, _) => "NULL" + case (t: Timestamp, TimestampType) => new TimestampWritable(t).toString + case (bin: Array[Byte], BinaryType) => new String(bin, "UTF-8") case (other, tpe) if primitiveTypes contains tpe => other.toString } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala new file mode 100644 index 0000000000000..ad7dc0ecdb1bf --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive + +import org.apache.hadoop.hive.common.`type`.HiveDecimal +import org.apache.hadoop.hive.serde2.objectinspector._ +import org.apache.hadoop.hive.serde2.objectinspector.primitive._ +import org.apache.hadoop.hive.serde2.{io => hiveIo} +import org.apache.hadoop.{io => hadoopIo} + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.types +import org.apache.spark.sql.catalyst.types._ + +/* Implicit conversions */ +import scala.collection.JavaConversions._ + +private[hive] trait HiveInspectors { + + def javaClassToDataType(clz: Class[_]): DataType = clz match { + // writable + case c: Class[_] if c == classOf[hadoopIo.DoubleWritable] => DoubleType + case c: Class[_] if c == classOf[hiveIo.DoubleWritable] => DoubleType + case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType + case c: Class[_] if c == classOf[hiveIo.ByteWritable] => ByteType + case c: Class[_] if c == classOf[hiveIo.ShortWritable] => ShortType + case c: Class[_] if c == classOf[hiveIo.TimestampWritable] => TimestampType + case c: Class[_] if c == classOf[hadoopIo.Text] => StringType + case c: Class[_] if c == classOf[hadoopIo.IntWritable] => IntegerType + case c: Class[_] if c == classOf[hadoopIo.LongWritable] => LongType + case c: Class[_] if c == classOf[hadoopIo.FloatWritable] => FloatType + case c: Class[_] if c == classOf[hadoopIo.BooleanWritable] => BooleanType + case c: Class[_] if c == classOf[hadoopIo.BytesWritable] => BinaryType + + // java class + case c: Class[_] if c == classOf[java.lang.String] => StringType + case c: Class[_] if c == classOf[java.sql.Timestamp] => TimestampType + case c: Class[_] if c == classOf[HiveDecimal] => DecimalType + case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType + case c: Class[_] if c == classOf[Array[Byte]] => BinaryType + case c: Class[_] if c == classOf[java.lang.Short] => ShortType + case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType + case c: Class[_] if c == classOf[java.lang.Long] => LongType + case c: Class[_] if c == classOf[java.lang.Double] => DoubleType + case c: Class[_] if c == classOf[java.lang.Byte] => ByteType + case c: Class[_] if c == classOf[java.lang.Float] => FloatType + case c: Class[_] if c == classOf[java.lang.Boolean] => BooleanType + + // primitive type + case c: Class[_] if c == java.lang.Short.TYPE => ShortType + case c: Class[_] if c == java.lang.Integer.TYPE => IntegerType + case c: Class[_] if c == java.lang.Long.TYPE => LongType + case c: Class[_] if c == java.lang.Double.TYPE => DoubleType + case c: Class[_] if c == java.lang.Byte.TYPE => ByteType + case c: Class[_] if c == java.lang.Float.TYPE => FloatType + case c: Class[_] if c == java.lang.Boolean.TYPE => BooleanType + + case c: Class[_] if c.isArray => ArrayType(javaClassToDataType(c.getComponentType)) + } + + /** Converts hive types to native catalyst types. 
*/ + def unwrap(a: Any): Any = a match { + case null => null + case i: hadoopIo.IntWritable => i.get + case t: hadoopIo.Text => t.toString + case l: hadoopIo.LongWritable => l.get + case d: hadoopIo.DoubleWritable => d.get + case d: hiveIo.DoubleWritable => d.get + case s: hiveIo.ShortWritable => s.get + case b: hadoopIo.BooleanWritable => b.get + case b: hiveIo.ByteWritable => b.get + case b: hadoopIo.FloatWritable => b.get + case b: hadoopIo.BytesWritable => { + val bytes = new Array[Byte](b.getLength) + System.arraycopy(b.getBytes(), 0, bytes, 0, b.getLength) + bytes + } + case t: hiveIo.TimestampWritable => t.getTimestamp + case b: hiveIo.HiveDecimalWritable => BigDecimal(b.getHiveDecimal().bigDecimalValue()) + case list: java.util.List[_] => list.map(unwrap) + case map: java.util.Map[_,_] => map.map { case (k, v) => (unwrap(k), unwrap(v)) }.toMap + case array: Array[_] => array.map(unwrap).toSeq + case p: java.lang.Short => p + case p: java.lang.Long => p + case p: java.lang.Float => p + case p: java.lang.Integer => p + case p: java.lang.Double => p + case p: java.lang.Byte => p + case p: java.lang.Boolean => p + case str: String => str + case p: java.math.BigDecimal => p + case p: Array[Byte] => p + case p: java.sql.Timestamp => p + } + + def unwrapData(data: Any, oi: ObjectInspector): Any = oi match { + case hvoi: HiveVarcharObjectInspector => + if (data == null) null else hvoi.getPrimitiveJavaObject(data).getValue + case hdoi: HiveDecimalObjectInspector => + if (data == null) null else BigDecimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue()) + case pi: PrimitiveObjectInspector => pi.getPrimitiveJavaObject(data) + case li: ListObjectInspector => + Option(li.getList(data)) + .map(_.map(unwrapData(_, li.getListElementObjectInspector)).toSeq) + .orNull + case mi: MapObjectInspector => + Option(mi.getMap(data)).map( + _.map { + case (k,v) => + (unwrapData(k, mi.getMapKeyObjectInspector), + unwrapData(v, mi.getMapValueObjectInspector)) + }.toMap).orNull + case si: StructObjectInspector => + val allRefs = si.getAllStructFieldRefs + new GenericRow( + allRefs.map(r => + unwrapData(si.getStructFieldData(data,r), r.getFieldObjectInspector)).toArray) + } + + /** Converts native catalyst types to the types expected by Hive */ + def wrap(a: Any): AnyRef = a match { + case s: String => new hadoopIo.Text(s) // TODO why should be Text? 
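The unwrap method above normalizes Hadoop and Hive Writable wrappers (and boxed Java values) into plain Scala values, copying the bytes out of a reusable BytesWritable buffer rather than aliasing it. A small illustrative use; the demo object is hypothetical and assumes it lives in the org.apache.spark.sql.hive package, since the trait is package-private:

import org.apache.hadoop.{io => hadoopIo}

// Illustrative only: a few Writables passed through unwrap.
object UnwrapDemo extends HiveInspectors {
  val anInt: Any = unwrap(new hadoopIo.IntWritable(42))        // 42
  val aString: Any = unwrap(new hadoopIo.Text("spark"))        // "spark"
  val bytes: Any = unwrap(new hadoopIo.BytesWritable(Array[Byte](1, 2, 3))) // copied Array[Byte]
}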
+ case i: Int => i: java.lang.Integer + case b: Boolean => b: java.lang.Boolean + case f: Float => f: java.lang.Float + case d: Double => d: java.lang.Double + case l: Long => l: java.lang.Long + case l: Short => l: java.lang.Short + case l: Byte => l: java.lang.Byte + case b: BigDecimal => b.bigDecimal + case b: Array[Byte] => b + case t: java.sql.Timestamp => t + case s: Seq[_] => seqAsJavaList(s.map(wrap)) + case m: Map[_,_] => + mapAsJavaMap(m.map { case (k, v) => wrap(k) -> wrap(v) }) + case null => null + } + + def toInspector(dataType: DataType): ObjectInspector = dataType match { + case ArrayType(tpe) => ObjectInspectorFactory.getStandardListObjectInspector(toInspector(tpe)) + case MapType(keyType, valueType) => + ObjectInspectorFactory.getStandardMapObjectInspector( + toInspector(keyType), toInspector(valueType)) + case StringType => PrimitiveObjectInspectorFactory.javaStringObjectInspector + case IntegerType => PrimitiveObjectInspectorFactory.javaIntObjectInspector + case DoubleType => PrimitiveObjectInspectorFactory.javaDoubleObjectInspector + case BooleanType => PrimitiveObjectInspectorFactory.javaBooleanObjectInspector + case LongType => PrimitiveObjectInspectorFactory.javaLongObjectInspector + case FloatType => PrimitiveObjectInspectorFactory.javaFloatObjectInspector + case ShortType => PrimitiveObjectInspectorFactory.javaShortObjectInspector + case ByteType => PrimitiveObjectInspectorFactory.javaByteObjectInspector + case NullType => PrimitiveObjectInspectorFactory.javaVoidObjectInspector + case BinaryType => PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector + case TimestampType => PrimitiveObjectInspectorFactory.javaTimestampObjectInspector + case DecimalType => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector + case StructType(fields) => + ObjectInspectorFactory.getStandardStructObjectInspector( + fields.map(f => f.name), fields.map(f => toInspector(f.dataType))) + } + + def inspectorToDataType(inspector: ObjectInspector): DataType = inspector match { + case s: StructObjectInspector => + StructType(s.getAllStructFieldRefs.map(f => { + types.StructField( + f.getFieldName, inspectorToDataType(f.getFieldObjectInspector), nullable = true) + })) + case l: ListObjectInspector => ArrayType(inspectorToDataType(l.getListElementObjectInspector)) + case m: MapObjectInspector => + MapType( + inspectorToDataType(m.getMapKeyObjectInspector), + inspectorToDataType(m.getMapValueObjectInspector)) + case _: WritableStringObjectInspector => StringType + case _: JavaStringObjectInspector => StringType + case _: WritableIntObjectInspector => IntegerType + case _: JavaIntObjectInspector => IntegerType + case _: WritableDoubleObjectInspector => DoubleType + case _: JavaDoubleObjectInspector => DoubleType + case _: WritableBooleanObjectInspector => BooleanType + case _: JavaBooleanObjectInspector => BooleanType + case _: WritableLongObjectInspector => LongType + case _: JavaLongObjectInspector => LongType + case _: WritableShortObjectInspector => ShortType + case _: JavaShortObjectInspector => ShortType + case _: WritableByteObjectInspector => ByteType + case _: JavaByteObjectInspector => ByteType + case _: WritableFloatObjectInspector => FloatType + case _: JavaFloatObjectInspector => FloatType + case _: WritableBinaryObjectInspector => BinaryType + case _: JavaBinaryObjectInspector => BinaryType + case _: WritableHiveDecimalObjectInspector => DecimalType + case _: JavaHiveDecimalObjectInspector => DecimalType + case _: WritableTimestampObjectInspector => TimestampType + 
case _: JavaTimestampObjectInspector => TimestampType + } + + implicit class typeInfoConversions(dt: DataType) { + import org.apache.hadoop.hive.serde2.typeinfo._ + import TypeInfoFactory._ + + def toTypeInfo: TypeInfo = dt match { + case BinaryType => binaryTypeInfo + case BooleanType => booleanTypeInfo + case ByteType => byteTypeInfo + case DoubleType => doubleTypeInfo + case FloatType => floatTypeInfo + case IntegerType => intTypeInfo + case LongType => longTypeInfo + case ShortType => shortTypeInfo + case StringType => stringTypeInfo + case DecimalType => decimalTypeInfo + case TimestampType => timestampTypeInfo + case NullType => voidTypeInfo + } + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index f83068860701f..156b090712df2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -43,14 +43,15 @@ import scala.collection.JavaConversions._ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with Logging { import HiveMetastoreTypes._ - val client = Hive.get(hive.hiveconf) + /** Connection to hive metastore. Usages should lock on `this`. */ + protected[hive] val client = Hive.get(hive.hiveconf) val caseSensitive: Boolean = false def lookupRelation( db: Option[String], tableName: String, - alias: Option[String]): LogicalPlan = { + alias: Option[String]): LogicalPlan = synchronized { val (dbName, tblName) = processDatabaseAndTableName(db, tableName) val databaseName = dbName.getOrElse(hive.sessionState.getCurrentDatabase) val table = client.getTable(databaseName, tblName) @@ -257,7 +258,7 @@ private[hive] case class MetastoreRelation // org.apache.hadoop.hive.ql.metadata.Partition will cause a NotSerializableException // which indicates the SerDe we used is not Serializable. - def hiveQlTable = new Table(table) + @transient lazy val hiveQlTable = new Table(table) def hiveQlPartitions = partitions.map { p => new Partition(hiveQlTable, p) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index b70104dd5be5a..e6ab68b563f8d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -42,8 +42,6 @@ private[hive] case class ShellCommand(cmd: String) extends Command private[hive] case class SourceCommand(filePath: String) extends Command -private[hive] case class AddJar(jarPath: String) extends Command - private[hive] case class AddFile(filePath: String) extends Command /** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. 
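HiveMetastoreCatalog above also turns hiveQlTable from a def into a @transient lazy val: the Hive Table (and the Partition objects built from it) are not serializable, so the field must never ride along when the relation is shipped, yet it should still be computed only once. A hedged sketch of that pattern with made-up names:

import java.io.{ByteArrayOutputStream, ObjectOutputStream}

// Sketch: a @transient lazy member is skipped by Java serialization, so writing the
// enclosing object no longer fails with NotSerializableException even though the
// handle itself is not serializable.
class CatalogEntry(val tableName: String) extends Serializable {
  @transient lazy val nativeHandle: AnyRef = new Object() // stand-in for a Hive Table
}

object TransientDemo extends App {
  val entry = new CatalogEntry("src")
  entry.nativeHandle // force initialization; the field is still never written out
  new ObjectOutputStream(new ByteArrayOutputStream()).writeObject(entry)
}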
*/ @@ -229,7 +227,7 @@ private[hive] object HiveQl { } else if (sql.trim.toLowerCase.startsWith("uncache table")) { CacheCommand(sql.trim.drop(14).trim, false) } else if (sql.trim.toLowerCase.startsWith("add jar")) { - AddJar(sql.trim.drop(8)) + NativeCommand(sql) } else if (sql.trim.toLowerCase.startsWith("add file")) { AddFile(sql.trim.drop(9)) } else if (sql.trim.toLowerCase.startsWith("dfs")) { @@ -860,6 +858,7 @@ private[hive] object HiveQl { val BETWEEN = "(?i)BETWEEN".r val WHEN = "(?i)WHEN".r val CASE = "(?i)CASE".r + val SUBSTR = "(?i)SUBSTR(?:ING)?".r protected def nodeToExpr(node: Node): Expression = node match { /* Attribute References */ @@ -870,10 +869,7 @@ private[hive] object HiveQl { nodeToExpr(qualifier) match { case UnresolvedAttribute(qualifierName) => UnresolvedAttribute(qualifierName + "." + cleanIdentifier(attr)) - // The precidence for . seems to be wrong, so [] binds tighter an we need to go inside to - // find the underlying attribute references. - case GetItem(UnresolvedAttribute(qualifierName), ordinal) => - GetItem(UnresolvedAttribute(qualifierName + "." + cleanIdentifier(attr)), ordinal) + case other => GetField(other, attr) } /* Stars (*) */ @@ -929,11 +925,14 @@ private[hive] object HiveQl { case Token("-", left :: right:: Nil) => Subtract(nodeToExpr(left), nodeToExpr(right)) case Token("*", left :: right:: Nil) => Multiply(nodeToExpr(left), nodeToExpr(right)) case Token("/", left :: right:: Nil) => Divide(nodeToExpr(left), nodeToExpr(right)) - case Token(DIV(), left :: right:: Nil) => Divide(nodeToExpr(left), nodeToExpr(right)) + case Token(DIV(), left :: right:: Nil) => + Cast(Divide(nodeToExpr(left), nodeToExpr(right)), LongType) case Token("%", left :: right:: Nil) => Remainder(nodeToExpr(left), nodeToExpr(right)) /* Comparisons */ case Token("=", left :: right:: Nil) => EqualTo(nodeToExpr(left), nodeToExpr(right)) + case Token("==", left :: right:: Nil) => EqualTo(nodeToExpr(left), nodeToExpr(right)) + case Token("<=>", left :: right:: Nil) => EqualNullSafe(nodeToExpr(left), nodeToExpr(right)) case Token("!=", left :: right:: Nil) => Not(EqualTo(nodeToExpr(left), nodeToExpr(right))) case Token("<>", left :: right:: Nil) => Not(EqualTo(nodeToExpr(left), nodeToExpr(right))) case Token(">", left :: right:: Nil) => GreaterThan(nodeToExpr(left), nodeToExpr(right)) @@ -987,6 +986,10 @@ private[hive] object HiveQl { /* Other functions */ case Token("TOK_FUNCTION", Token(RAND(), Nil) :: Nil) => Rand + case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: Nil) => + Substring(nodeToExpr(string), nodeToExpr(pos), Literal(Integer.MAX_VALUE, IntegerType)) + case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: length :: Nil) => + Substring(nodeToExpr(string), nodeToExpr(pos), nodeToExpr(length)) /* UDFs - Must be last otherwise will preempt built in functions */ case Token("TOK_FUNCTION", Token(name, Nil) :: args) => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala index ef8bae74530ec..e7016fa16eea9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala @@ -96,19 +96,9 @@ case class HiveTableScan( .getOrElse(sys.error(s"Can't find attribute $a")) val fieldObjectInspector = ref.getFieldObjectInspector - val unwrapHiveData = fieldObjectInspector match { - case _: HiveVarcharObjectInspector => - 
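The two SUBSTR cases above accept both substr and substring, and when the length argument is omitted they pass Literal(Integer.MAX_VALUE, IntegerType) so the slice runs to the end of the string. A plain-Scala model of that 1-based behaviour, matching the SPARK-2407 tests added earlier; it is illustrative only, ignores negative positions and other Hive corner cases, and is not the Catalyst Substring implementation:

object SubstrModel {
  // 1-based start position; a missing length defaults to "the rest of the string".
  def sqlSubstr(s: String, pos: Int, len: Int = Integer.MAX_VALUE): String = {
    val start = math.max(pos - 1, 0)
    val end = math.min(start.toLong + len, s.length.toLong).toInt
    if (start >= s.length) "" else s.substring(start, end)
  }

  assert(sqlSubstr("test", 3) == "st")     // SELECT substr('test', 3)
  assert(sqlSubstr("test", 1, 2) == "te")  // SELECT substring('test', 1, 2)
}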
(value: Any) => value.asInstanceOf[HiveVarchar].getValue - case _: HiveDecimalObjectInspector => - (value: Any) => BigDecimal(value.asInstanceOf[HiveDecimal].bigDecimalValue()) - case _ => - identity[Any] _ - } - (row: Any, _: Array[String]) => { val data = objectInspector.getStructFieldData(row, ref) - val hiveData = unwrapData(data, fieldObjectInspector) - if (hiveData != null) unwrapHiveData(hiveData) else null + unwrapData(data, fieldObjectInspector) } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 9b105308ab7cf..057eb60a02612 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -24,22 +24,19 @@ import org.apache.hadoop.hive.ql.exec.UDF import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry} import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType} import org.apache.hadoop.hive.ql.udf.generic._ -import org.apache.hadoop.hive.serde2.objectinspector._ -import org.apache.hadoop.hive.serde2.objectinspector.primitive._ -import org.apache.hadoop.hive.serde2.{io => hiveIo} -import org.apache.hadoop.{io => hadoopIo} import org.apache.spark.sql.Logging import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.types import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.util.Utils.getContextOrSparkClassLoader /* Implicit conversions */ import scala.collection.JavaConversions._ -private[hive] object HiveFunctionRegistry - extends analysis.FunctionRegistry with HiveFunctionFactory with HiveInspectors { +private[hive] object HiveFunctionRegistry extends analysis.FunctionRegistry with HiveInspectors { + + def getFunctionInfo(name: String) = FunctionRegistry.getFunctionInfo(name) def lookupFunction(name: String, children: Seq[Expression]): Expression = { // We only look it up to see if it exists, but do not include it in the HiveUDF since it is @@ -47,111 +44,37 @@ private[hive] object HiveFunctionRegistry val functionInfo: FunctionInfo = Option(FunctionRegistry.getFunctionInfo(name)).getOrElse( sys.error(s"Couldn't find function $name")) + val functionClassName = functionInfo.getFunctionClass.getName() + if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) { - val function = createFunction[UDF](name) + val function = functionInfo.getFunctionClass.newInstance().asInstanceOf[UDF] val method = function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo)) lazy val expectedDataTypes = method.getParameterTypes.map(javaClassToDataType) HiveSimpleUdf( - name, + functionClassName, children.zip(expectedDataTypes).map { case (e, t) => Cast(e, t) } ) } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUdf(name, children) + HiveGenericUdf(functionClassName, children) } else if ( classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUdaf(name, children) + HiveGenericUdaf(functionClassName, children) } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUdtf(name, Nil, children) + HiveGenericUdtf(functionClassName, Nil, children) } else { sys.error(s"No handler for udf ${functionInfo.getFunctionClass}") } } - - def javaClassToDataType(clz: Class[_]): DataType = clz match { - // writable - case c: Class[_] if c == classOf[hadoopIo.DoubleWritable] => 
DoubleType - case c: Class[_] if c == classOf[hiveIo.DoubleWritable] => DoubleType - case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType - case c: Class[_] if c == classOf[hiveIo.ByteWritable] => ByteType - case c: Class[_] if c == classOf[hiveIo.ShortWritable] => ShortType - case c: Class[_] if c == classOf[hiveIo.TimestampWritable] => TimestampType - case c: Class[_] if c == classOf[hadoopIo.Text] => StringType - case c: Class[_] if c == classOf[hadoopIo.IntWritable] => IntegerType - case c: Class[_] if c == classOf[hadoopIo.LongWritable] => LongType - case c: Class[_] if c == classOf[hadoopIo.FloatWritable] => FloatType - case c: Class[_] if c == classOf[hadoopIo.BooleanWritable] => BooleanType - case c: Class[_] if c == classOf[hadoopIo.BytesWritable] => BinaryType - - // java class - case c: Class[_] if c == classOf[java.lang.String] => StringType - case c: Class[_] if c == classOf[java.sql.Timestamp] => TimestampType - case c: Class[_] if c == classOf[HiveDecimal] => DecimalType - case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType - case c: Class[_] if c == classOf[Array[Byte]] => BinaryType - case c: Class[_] if c == classOf[java.lang.Short] => ShortType - case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType - case c: Class[_] if c == classOf[java.lang.Long] => LongType - case c: Class[_] if c == classOf[java.lang.Double] => DoubleType - case c: Class[_] if c == classOf[java.lang.Byte] => ByteType - case c: Class[_] if c == classOf[java.lang.Float] => FloatType - case c: Class[_] if c == classOf[java.lang.Boolean] => BooleanType - - // primitive type - case c: Class[_] if c == java.lang.Short.TYPE => ShortType - case c: Class[_] if c == java.lang.Integer.TYPE => IntegerType - case c: Class[_] if c == java.lang.Long.TYPE => LongType - case c: Class[_] if c == java.lang.Double.TYPE => DoubleType - case c: Class[_] if c == java.lang.Byte.TYPE => ByteType - case c: Class[_] if c == java.lang.Float.TYPE => FloatType - case c: Class[_] if c == java.lang.Boolean.TYPE => BooleanType - - case c: Class[_] if c.isArray => ArrayType(javaClassToDataType(c.getComponentType)) - } } private[hive] trait HiveFunctionFactory { - def getFunctionInfo(name: String) = FunctionRegistry.getFunctionInfo(name) - def getFunctionClass(name: String) = getFunctionInfo(name).getFunctionClass - def createFunction[UDFType](name: String) = - getFunctionClass(name).newInstance.asInstanceOf[UDFType] - - /** Converts hive types to native catalyst types. 
*/ - def unwrap(a: Any): Any = a match { - case null => null - case i: hadoopIo.IntWritable => i.get - case t: hadoopIo.Text => t.toString - case l: hadoopIo.LongWritable => l.get - case d: hadoopIo.DoubleWritable => d.get - case d: hiveIo.DoubleWritable => d.get - case s: hiveIo.ShortWritable => s.get - case b: hadoopIo.BooleanWritable => b.get - case b: hiveIo.ByteWritable => b.get - case b: hadoopIo.FloatWritable => b.get - case b: hadoopIo.BytesWritable => { - val bytes = new Array[Byte](b.getLength) - System.arraycopy(b.getBytes(), 0, bytes, 0, b.getLength) - bytes - } - case t: hiveIo.TimestampWritable => t.getTimestamp - case b: hiveIo.HiveDecimalWritable => BigDecimal(b.getHiveDecimal().bigDecimalValue()) - case list: java.util.List[_] => list.map(unwrap) - case map: java.util.Map[_,_] => map.map { case (k, v) => (unwrap(k), unwrap(v)) }.toMap - case array: Array[_] => array.map(unwrap).toSeq - case p: java.lang.Short => p - case p: java.lang.Long => p - case p: java.lang.Float => p - case p: java.lang.Integer => p - case p: java.lang.Double => p - case p: java.lang.Byte => p - case p: java.lang.Boolean => p - case str: String => str - case p: java.math.BigDecimal => p - case p: Array[Byte] => p - case p: java.sql.Timestamp => p - } + val functionClassName: String + + def createFunction[UDFType]() = + getContextOrSparkClassLoader.loadClass(functionClassName).newInstance.asInstanceOf[UDFType] } private[hive] abstract class HiveUdf extends Expression with Logging with HiveFunctionFactory { @@ -160,19 +83,17 @@ private[hive] abstract class HiveUdf extends Expression with Logging with HiveFu type UDFType type EvaluatedType = Any - val name: String - def nullable = true def references = children.flatMap(_.references).toSet - // FunctionInfo is not serializable so we must look it up here again. 
- lazy val functionInfo = getFunctionInfo(name) - lazy val function = createFunction[UDFType](name) + lazy val function = createFunction[UDFType]() - override def toString = s"$nodeName#${functionInfo.getDisplayName}(${children.mkString(",")})" + override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})" } -private[hive] case class HiveSimpleUdf(name: String, children: Seq[Expression]) extends HiveUdf { +private[hive] case class HiveSimpleUdf(functionClassName: String, children: Seq[Expression]) + extends HiveUdf { + import org.apache.spark.sql.hive.HiveFunctionRegistry._ type UDFType = UDF @@ -226,7 +147,7 @@ private[hive] case class HiveSimpleUdf(name: String, children: Seq[Expression]) } } -private[hive] case class HiveGenericUdf(name: String, children: Seq[Expression]) +private[hive] case class HiveGenericUdf(functionClassName: String, children: Seq[Expression]) extends HiveUdf with HiveInspectors { import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._ @@ -277,127 +198,8 @@ private[hive] case class HiveGenericUdf(name: String, children: Seq[Expression]) } } -private[hive] trait HiveInspectors { - - def unwrapData(data: Any, oi: ObjectInspector): Any = oi match { - case pi: PrimitiveObjectInspector => pi.getPrimitiveJavaObject(data) - case li: ListObjectInspector => - Option(li.getList(data)) - .map(_.map(unwrapData(_, li.getListElementObjectInspector)).toSeq) - .orNull - case mi: MapObjectInspector => - Option(mi.getMap(data)).map( - _.map { - case (k,v) => - (unwrapData(k, mi.getMapKeyObjectInspector), - unwrapData(v, mi.getMapValueObjectInspector)) - }.toMap).orNull - case si: StructObjectInspector => - val allRefs = si.getAllStructFieldRefs - new GenericRow( - allRefs.map(r => - unwrapData(si.getStructFieldData(data,r), r.getFieldObjectInspector)).toArray) - } - - /** Converts native catalyst types to the types expected by Hive */ - def wrap(a: Any): AnyRef = a match { - case s: String => new hadoopIo.Text(s) // TODO why should be Text? 
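The refactor above replaces the UDF name (and the FunctionInfo that had to be looked up again because it is not serializable) with the function's class name, and rebuilds the function reflectively through getContextOrSparkClassLoader so that classes from jars added at runtime can also be resolved on executors. A minimal sketch of that instantiation path; the helper name and the example class in the comment are illustrative:

import org.apache.spark.util.Utils.getContextOrSparkClassLoader

// Rebuild a UDF instance from its class name alone, using the thread context class
// loader (falling back to Spark's own class loader) so runtime-added jars are visible.
def instantiate[UDFType](functionClassName: String): UDFType =
  getContextOrSparkClassLoader
    .loadClass(functionClassName)
    .newInstance()
    .asInstanceOf[UDFType]

// e.g. instantiate[org.apache.hadoop.hive.ql.exec.UDF]("org.apache.hadoop.hive.ql.udf.UDFSubstr")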
- case i: Int => i: java.lang.Integer - case b: Boolean => b: java.lang.Boolean - case f: Float => f: java.lang.Float - case d: Double => d: java.lang.Double - case l: Long => l: java.lang.Long - case l: Short => l: java.lang.Short - case l: Byte => l: java.lang.Byte - case b: BigDecimal => b.bigDecimal - case b: Array[Byte] => b - case t: java.sql.Timestamp => t - case s: Seq[_] => seqAsJavaList(s.map(wrap)) - case m: Map[_,_] => - mapAsJavaMap(m.map { case (k, v) => wrap(k) -> wrap(v) }) - case null => null - } - - def toInspector(dataType: DataType): ObjectInspector = dataType match { - case ArrayType(tpe) => ObjectInspectorFactory.getStandardListObjectInspector(toInspector(tpe)) - case MapType(keyType, valueType) => - ObjectInspectorFactory.getStandardMapObjectInspector( - toInspector(keyType), toInspector(valueType)) - case StringType => PrimitiveObjectInspectorFactory.javaStringObjectInspector - case IntegerType => PrimitiveObjectInspectorFactory.javaIntObjectInspector - case DoubleType => PrimitiveObjectInspectorFactory.javaDoubleObjectInspector - case BooleanType => PrimitiveObjectInspectorFactory.javaBooleanObjectInspector - case LongType => PrimitiveObjectInspectorFactory.javaLongObjectInspector - case FloatType => PrimitiveObjectInspectorFactory.javaFloatObjectInspector - case ShortType => PrimitiveObjectInspectorFactory.javaShortObjectInspector - case ByteType => PrimitiveObjectInspectorFactory.javaByteObjectInspector - case NullType => PrimitiveObjectInspectorFactory.javaVoidObjectInspector - case BinaryType => PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector - case TimestampType => PrimitiveObjectInspectorFactory.javaTimestampObjectInspector - case DecimalType => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector - case StructType(fields) => - ObjectInspectorFactory.getStandardStructObjectInspector( - fields.map(f => f.name), fields.map(f => toInspector(f.dataType))) - } - - def inspectorToDataType(inspector: ObjectInspector): DataType = inspector match { - case s: StructObjectInspector => - StructType(s.getAllStructFieldRefs.map(f => { - types.StructField( - f.getFieldName, inspectorToDataType(f.getFieldObjectInspector), nullable = true) - })) - case l: ListObjectInspector => ArrayType(inspectorToDataType(l.getListElementObjectInspector)) - case m: MapObjectInspector => - MapType( - inspectorToDataType(m.getMapKeyObjectInspector), - inspectorToDataType(m.getMapValueObjectInspector)) - case _: WritableStringObjectInspector => StringType - case _: JavaStringObjectInspector => StringType - case _: WritableIntObjectInspector => IntegerType - case _: JavaIntObjectInspector => IntegerType - case _: WritableDoubleObjectInspector => DoubleType - case _: JavaDoubleObjectInspector => DoubleType - case _: WritableBooleanObjectInspector => BooleanType - case _: JavaBooleanObjectInspector => BooleanType - case _: WritableLongObjectInspector => LongType - case _: JavaLongObjectInspector => LongType - case _: WritableShortObjectInspector => ShortType - case _: JavaShortObjectInspector => ShortType - case _: WritableByteObjectInspector => ByteType - case _: JavaByteObjectInspector => ByteType - case _: WritableFloatObjectInspector => FloatType - case _: JavaFloatObjectInspector => FloatType - case _: WritableBinaryObjectInspector => BinaryType - case _: JavaBinaryObjectInspector => BinaryType - case _: WritableHiveDecimalObjectInspector => DecimalType - case _: JavaHiveDecimalObjectInspector => DecimalType - case _: WritableTimestampObjectInspector => TimestampType - 
case _: JavaTimestampObjectInspector => TimestampType - } - - implicit class typeInfoConversions(dt: DataType) { - import org.apache.hadoop.hive.serde2.typeinfo._ - import TypeInfoFactory._ - - def toTypeInfo: TypeInfo = dt match { - case BinaryType => binaryTypeInfo - case BooleanType => booleanTypeInfo - case ByteType => byteTypeInfo - case DoubleType => doubleTypeInfo - case FloatType => floatTypeInfo - case IntegerType => intTypeInfo - case LongType => longTypeInfo - case ShortType => shortTypeInfo - case StringType => stringTypeInfo - case DecimalType => decimalTypeInfo - case TimestampType => timestampTypeInfo - case NullType => voidTypeInfo - } - } -} - private[hive] case class HiveGenericUdaf( - name: String, + functionClassName: String, children: Seq[Expression]) extends AggregateExpression with HiveInspectors with HiveFunctionFactory { @@ -405,7 +207,7 @@ private[hive] case class HiveGenericUdaf( type UDFType = AbstractGenericUDAFResolver @transient - protected lazy val resolver: AbstractGenericUDAFResolver = createFunction(name) + protected lazy val resolver: AbstractGenericUDAFResolver = createFunction() @transient protected lazy val objectInspector = { @@ -422,9 +224,9 @@ private[hive] case class HiveGenericUdaf( def references: Set[Attribute] = children.map(_.references).flatten.toSet - override def toString = s"$nodeName#$name(${children.mkString(",")})" + override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})" - def newInstance() = new HiveUdafFunction(name, children, this) + def newInstance() = new HiveUdafFunction(functionClassName, children, this) } /** @@ -439,7 +241,7 @@ private[hive] case class HiveGenericUdaf( * user defined aggregations, which have clean semantics even in a partitioned execution. */ private[hive] case class HiveGenericUdtf( - name: String, + functionClassName: String, aliasNames: Seq[String], children: Seq[Expression]) extends Generator with HiveInspectors with HiveFunctionFactory { @@ -447,7 +249,7 @@ private[hive] case class HiveGenericUdtf( override def references = children.flatMap(_.references).toSet @transient - protected lazy val function: GenericUDTF = createFunction(name) + protected lazy val function: GenericUDTF = createFunction() protected lazy val inputInspectors = children.map(_.dataType).map(toInspector) @@ -502,11 +304,11 @@ private[hive] case class HiveGenericUdtf( } } - override def toString = s"$nodeName#$name(${children.mkString(",")})" + override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})" } private[hive] case class HiveUdafFunction( - functionName: String, + functionClassName: String, exprs: Seq[Expression], base: AggregateExpression) extends AggregateFunction @@ -515,7 +317,7 @@ private[hive] case class HiveUdafFunction( def this() = this(null, null, null) - private val resolver = createFunction[AbstractGenericUDAFResolver](functionName) + private val resolver = createFunction[AbstractGenericUDAFResolver]() private val inspectors = exprs.map(_.dataType).map(toInspector).toArray diff --git a/sql/hive/src/test/resources/golden/boolean = number-0-6b6975fa1892cc48edd87dc0df48a7c0 b/sql/hive/src/test/resources/golden/boolean = number-0-6b6975fa1892cc48edd87dc0df48a7c0 new file mode 100644 index 0000000000000..4d1ebdcde2c71 --- /dev/null +++ b/sql/hive/src/test/resources/golden/boolean = number-0-6b6975fa1892cc48edd87dc0df48a7c0 @@ -0,0 +1 @@ +true true true true true true false false false false false false false false false false false false true true true true true true 
false false false false false false false false false false false false diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-0-b1e2ade89ae898650f0be4f796d8947b b/sql/hive/src/test/resources/golden/correlationoptimizer1-0-b1e2ade89ae898650f0be4f796d8947b new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-0-b1e2ade89ae898650f0be4f796d8947b @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-1-b9d963d24994c47c3776dda6f7d3881f b/sql/hive/src/test/resources/golden/correlationoptimizer1-1-b9d963d24994c47c3776dda6f7d3881f new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-1-b9d963d24994c47c3776dda6f7d3881f @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-10-5d712a42dcc1c4f7c797dabda5eb7b3a b/sql/hive/src/test/resources/golden/correlationoptimizer1-10-5d712a42dcc1c4f7c797dabda5eb7b3a new file mode 100644 index 0000000000000..58010b0040b74 --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-10-5d712a42dcc1c4f7c797dabda5eb7b3a @@ -0,0 +1 @@ +3556 37 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-11-b1e2ade89ae898650f0be4f796d8947b b/sql/hive/src/test/resources/golden/correlationoptimizer1-11-b1e2ade89ae898650f0be4f796d8947b new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-11-b1e2ade89ae898650f0be4f796d8947b @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-12-b9d963d24994c47c3776dda6f7d3881f b/sql/hive/src/test/resources/golden/correlationoptimizer1-12-b9d963d24994c47c3776dda6f7d3881f new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-12-b9d963d24994c47c3776dda6f7d3881f @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-13-f9f839aedb3a350719c0cbc53a06ace5 b/sql/hive/src/test/resources/golden/correlationoptimizer1-13-f9f839aedb3a350719c0cbc53a06ace5 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-14-dae4256e08d595317f8e09a56354a3d9 b/sql/hive/src/test/resources/golden/correlationoptimizer1-14-dae4256e08d595317f8e09a56354a3d9 new file mode 100644 index 0000000000000..235736a2807b6 --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-14-dae4256e08d595317f8e09a56354a3d9 @@ -0,0 +1 @@ +3556 15 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-15-777edd9d575f3480ca6cebe4be57b1f6 b/sql/hive/src/test/resources/golden/correlationoptimizer1-15-777edd9d575f3480ca6cebe4be57b1f6 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-15-777edd9d575f3480ca6cebe4be57b1f6 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-16-43f356b36962f2bade5706d8cf5ae6b4 b/sql/hive/src/test/resources/golden/correlationoptimizer1-16-43f356b36962f2bade5706d8cf5ae6b4 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-17-dae4256e08d595317f8e09a56354a3d9 b/sql/hive/src/test/resources/golden/correlationoptimizer1-17-dae4256e08d595317f8e09a56354a3d9 new file mode 100644 index 
0000000000000..235736a2807b6 --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-17-dae4256e08d595317f8e09a56354a3d9 @@ -0,0 +1 @@ +3556 15 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-18-b1e2ade89ae898650f0be4f796d8947b b/sql/hive/src/test/resources/golden/correlationoptimizer1-18-b1e2ade89ae898650f0be4f796d8947b new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-18-b1e2ade89ae898650f0be4f796d8947b @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-19-b9d963d24994c47c3776dda6f7d3881f b/sql/hive/src/test/resources/golden/correlationoptimizer1-19-b9d963d24994c47c3776dda6f7d3881f new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-19-b9d963d24994c47c3776dda6f7d3881f @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-2-42a0eedc3751f792ad5438b2c64d3897 b/sql/hive/src/test/resources/golden/correlationoptimizer1-2-42a0eedc3751f792ad5438b2c64d3897 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-20-b4c1e27a7c1f61a3e9ae07c80e6e2973 b/sql/hive/src/test/resources/golden/correlationoptimizer1-20-b4c1e27a7c1f61a3e9ae07c80e6e2973 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-21-16c57348be42ca3cc2f80f7f92265696 b/sql/hive/src/test/resources/golden/correlationoptimizer1-21-16c57348be42ca3cc2f80f7f92265696 new file mode 100644 index 0000000000000..76c4941de407d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-21-16c57348be42ca3cc2f80f7f92265696 @@ -0,0 +1 @@ +3556 47 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-22-777edd9d575f3480ca6cebe4be57b1f6 b/sql/hive/src/test/resources/golden/correlationoptimizer1-22-777edd9d575f3480ca6cebe4be57b1f6 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-22-777edd9d575f3480ca6cebe4be57b1f6 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-23-2cdc77fd60449f3547cf95d8eb09a2 b/sql/hive/src/test/resources/golden/correlationoptimizer1-23-2cdc77fd60449f3547cf95d8eb09a2 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-24-16c57348be42ca3cc2f80f7f92265696 b/sql/hive/src/test/resources/golden/correlationoptimizer1-24-16c57348be42ca3cc2f80f7f92265696 new file mode 100644 index 0000000000000..76c4941de407d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-24-16c57348be42ca3cc2f80f7f92265696 @@ -0,0 +1 @@ +3556 47 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-25-b9d963d24994c47c3776dda6f7d3881f b/sql/hive/src/test/resources/golden/correlationoptimizer1-25-b9d963d24994c47c3776dda6f7d3881f new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-25-b9d963d24994c47c3776dda6f7d3881f @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-26-8bcdcc5f01508f576d7bd6422c939225 b/sql/hive/src/test/resources/golden/correlationoptimizer1-26-8bcdcc5f01508f576d7bd6422c939225 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git 
a/sql/hive/src/test/resources/golden/correlationoptimizer1-27-d31433f229e853e8b8440b4ddc63c80e b/sql/hive/src/test/resources/golden/correlationoptimizer1-27-d31433f229e853e8b8440b4ddc63c80e new file mode 100644 index 0000000000000..76c4941de407d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-27-d31433f229e853e8b8440b4ddc63c80e @@ -0,0 +1 @@ +3556 47 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-28-777edd9d575f3480ca6cebe4be57b1f6 b/sql/hive/src/test/resources/golden/correlationoptimizer1-28-777edd9d575f3480ca6cebe4be57b1f6 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-28-777edd9d575f3480ca6cebe4be57b1f6 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-29-941ecfef9448ecff56cc16bcfb233ee4 b/sql/hive/src/test/resources/golden/correlationoptimizer1-29-941ecfef9448ecff56cc16bcfb233ee4 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-3-5d712a42dcc1c4f7c797dabda5eb7b3a b/sql/hive/src/test/resources/golden/correlationoptimizer1-3-5d712a42dcc1c4f7c797dabda5eb7b3a new file mode 100644 index 0000000000000..58010b0040b74 --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-3-5d712a42dcc1c4f7c797dabda5eb7b3a @@ -0,0 +1 @@ +3556 37 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-30-d31433f229e853e8b8440b4ddc63c80e b/sql/hive/src/test/resources/golden/correlationoptimizer1-30-d31433f229e853e8b8440b4ddc63c80e new file mode 100644 index 0000000000000..76c4941de407d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-30-d31433f229e853e8b8440b4ddc63c80e @@ -0,0 +1 @@ +3556 47 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-31-b9d963d24994c47c3776dda6f7d3881f b/sql/hive/src/test/resources/golden/correlationoptimizer1-31-b9d963d24994c47c3776dda6f7d3881f new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-31-b9d963d24994c47c3776dda6f7d3881f @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-32-ef6502d6b282c8a6d228bba395b24724 b/sql/hive/src/test/resources/golden/correlationoptimizer1-32-ef6502d6b282c8a6d228bba395b24724 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-33-ea87e76dba02a46cb958148333e397b7 b/sql/hive/src/test/resources/golden/correlationoptimizer1-33-ea87e76dba02a46cb958148333e397b7 new file mode 100644 index 0000000000000..5aa2d482094af --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-33-ea87e76dba02a46cb958148333e397b7 @@ -0,0 +1 @@ +79136 500 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-34-777edd9d575f3480ca6cebe4be57b1f6 b/sql/hive/src/test/resources/golden/correlationoptimizer1-34-777edd9d575f3480ca6cebe4be57b1f6 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-34-777edd9d575f3480ca6cebe4be57b1f6 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-35-b79b220859c09354e23b533c105ccbab b/sql/hive/src/test/resources/golden/correlationoptimizer1-35-b79b220859c09354e23b533c105ccbab new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git 
a/sql/hive/src/test/resources/golden/correlationoptimizer1-36-ea87e76dba02a46cb958148333e397b7 b/sql/hive/src/test/resources/golden/correlationoptimizer1-36-ea87e76dba02a46cb958148333e397b7 new file mode 100644 index 0000000000000..5aa2d482094af --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-36-ea87e76dba02a46cb958148333e397b7 @@ -0,0 +1 @@ +79136 500 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-37-b9d963d24994c47c3776dda6f7d3881f b/sql/hive/src/test/resources/golden/correlationoptimizer1-37-b9d963d24994c47c3776dda6f7d3881f new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-37-b9d963d24994c47c3776dda6f7d3881f @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-38-638e5300f4c892c2bf27bd91a8f81b64 b/sql/hive/src/test/resources/golden/correlationoptimizer1-38-638e5300f4c892c2bf27bd91a8f81b64 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-39-66010469a9cdb66851da9a727ef9fdad b/sql/hive/src/test/resources/golden/correlationoptimizer1-39-66010469a9cdb66851da9a727ef9fdad new file mode 100644 index 0000000000000..b4a3a9d327f47 --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-39-66010469a9cdb66851da9a727ef9fdad @@ -0,0 +1 @@ +3556 500 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-4-777edd9d575f3480ca6cebe4be57b1f6 b/sql/hive/src/test/resources/golden/correlationoptimizer1-4-777edd9d575f3480ca6cebe4be57b1f6 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-4-777edd9d575f3480ca6cebe4be57b1f6 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-40-777edd9d575f3480ca6cebe4be57b1f6 b/sql/hive/src/test/resources/golden/correlationoptimizer1-40-777edd9d575f3480ca6cebe4be57b1f6 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-40-777edd9d575f3480ca6cebe4be57b1f6 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-41-3514c74c7f68f2d70cc6d51ac46c20 b/sql/hive/src/test/resources/golden/correlationoptimizer1-41-3514c74c7f68f2d70cc6d51ac46c20 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-42-66010469a9cdb66851da9a727ef9fdad b/sql/hive/src/test/resources/golden/correlationoptimizer1-42-66010469a9cdb66851da9a727ef9fdad new file mode 100644 index 0000000000000..b4a3a9d327f47 --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-42-66010469a9cdb66851da9a727ef9fdad @@ -0,0 +1 @@ +3556 500 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-43-b9d963d24994c47c3776dda6f7d3881f b/sql/hive/src/test/resources/golden/correlationoptimizer1-43-b9d963d24994c47c3776dda6f7d3881f new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-43-b9d963d24994c47c3776dda6f7d3881f @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-44-7490df6719cd7e47aa08dbcbc3266a92 b/sql/hive/src/test/resources/golden/correlationoptimizer1-44-7490df6719cd7e47aa08dbcbc3266a92 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git 
a/sql/hive/src/test/resources/golden/correlationoptimizer1-45-e71195e7d9f557e2abc7f03462d22dba b/sql/hive/src/test/resources/golden/correlationoptimizer1-45-e71195e7d9f557e2abc7f03462d22dba new file mode 100644 index 0000000000000..bb564e0fd06eb --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-45-e71195e7d9f557e2abc7f03462d22dba @@ -0,0 +1 @@ +3556 510 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-46-777edd9d575f3480ca6cebe4be57b1f6 b/sql/hive/src/test/resources/golden/correlationoptimizer1-46-777edd9d575f3480ca6cebe4be57b1f6 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-46-777edd9d575f3480ca6cebe4be57b1f6 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-47-73da9fe2b0c2ee26c021ec3f2fa27272 b/sql/hive/src/test/resources/golden/correlationoptimizer1-47-73da9fe2b0c2ee26c021ec3f2fa27272 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-48-e71195e7d9f557e2abc7f03462d22dba b/sql/hive/src/test/resources/golden/correlationoptimizer1-48-e71195e7d9f557e2abc7f03462d22dba new file mode 100644 index 0000000000000..bb564e0fd06eb --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-48-e71195e7d9f557e2abc7f03462d22dba @@ -0,0 +1 @@ +3556 510 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-49-b1e2ade89ae898650f0be4f796d8947b b/sql/hive/src/test/resources/golden/correlationoptimizer1-49-b1e2ade89ae898650f0be4f796d8947b new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-49-b1e2ade89ae898650f0be4f796d8947b @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-5-a1c80c68b9a7597096c2580c3766f7f7 b/sql/hive/src/test/resources/golden/correlationoptimizer1-5-a1c80c68b9a7597096c2580c3766f7f7 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-50-b9d963d24994c47c3776dda6f7d3881f b/sql/hive/src/test/resources/golden/correlationoptimizer1-50-b9d963d24994c47c3776dda6f7d3881f new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-50-b9d963d24994c47c3776dda6f7d3881f @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-51-fcf9bcb522f542637ccdea863b408448 b/sql/hive/src/test/resources/golden/correlationoptimizer1-51-fcf9bcb522f542637ccdea863b408448 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-52-3070366869308907e54797927805603 b/sql/hive/src/test/resources/golden/correlationoptimizer1-52-3070366869308907e54797927805603 new file mode 100644 index 0000000000000..edd216f16b190 --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-52-3070366869308907e54797927805603 @@ -0,0 +1 @@ +3556 661329102 37 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-53-777edd9d575f3480ca6cebe4be57b1f6 b/sql/hive/src/test/resources/golden/correlationoptimizer1-53-777edd9d575f3480ca6cebe4be57b1f6 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-53-777edd9d575f3480ca6cebe4be57b1f6 @@ -0,0 +1 @@ +0 diff --git 
a/sql/hive/src/test/resources/golden/correlationoptimizer1-54-dad56e1f06c808b29e5dc8fb0c49efb2 b/sql/hive/src/test/resources/golden/correlationoptimizer1-54-dad56e1f06c808b29e5dc8fb0c49efb2 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-55-3070366869308907e54797927805603 b/sql/hive/src/test/resources/golden/correlationoptimizer1-55-3070366869308907e54797927805603 new file mode 100644 index 0000000000000..edd216f16b190 --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-55-3070366869308907e54797927805603 @@ -0,0 +1 @@ +3556 661329102 37 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-56-b9d963d24994c47c3776dda6f7d3881f b/sql/hive/src/test/resources/golden/correlationoptimizer1-56-b9d963d24994c47c3776dda6f7d3881f new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-56-b9d963d24994c47c3776dda6f7d3881f @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-57-3cd3fbbbd8ee5c274fe3d6a45126cef4 b/sql/hive/src/test/resources/golden/correlationoptimizer1-57-3cd3fbbbd8ee5c274fe3d6a45126cef4 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-58-a6bba6d9b422adb386b35c62cecb548 b/sql/hive/src/test/resources/golden/correlationoptimizer1-58-a6bba6d9b422adb386b35c62cecb548 new file mode 100644 index 0000000000000..9fde6099b4f08 --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-58-a6bba6d9b422adb386b35c62cecb548 @@ -0,0 +1 @@ +2835 29 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-59-777edd9d575f3480ca6cebe4be57b1f6 b/sql/hive/src/test/resources/golden/correlationoptimizer1-59-777edd9d575f3480ca6cebe4be57b1f6 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-59-777edd9d575f3480ca6cebe4be57b1f6 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-6-5d712a42dcc1c4f7c797dabda5eb7b3a b/sql/hive/src/test/resources/golden/correlationoptimizer1-6-5d712a42dcc1c4f7c797dabda5eb7b3a new file mode 100644 index 0000000000000..58010b0040b74 --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-6-5d712a42dcc1c4f7c797dabda5eb7b3a @@ -0,0 +1 @@ +3556 37 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-60-d6bbaf0d40010159095e4cac025c50c5 b/sql/hive/src/test/resources/golden/correlationoptimizer1-60-d6bbaf0d40010159095e4cac025c50c5 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-61-a6bba6d9b422adb386b35c62cecb548 b/sql/hive/src/test/resources/golden/correlationoptimizer1-61-a6bba6d9b422adb386b35c62cecb548 new file mode 100644 index 0000000000000..9fde6099b4f08 --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-61-a6bba6d9b422adb386b35c62cecb548 @@ -0,0 +1 @@ +2835 29 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-7-24ca942f094b14b92086305cc125e833 b/sql/hive/src/test/resources/golden/correlationoptimizer1-7-24ca942f094b14b92086305cc125e833 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-7-24ca942f094b14b92086305cc125e833 @@ -0,0 +1 @@ +0 diff --git 
a/sql/hive/src/test/resources/golden/correlationoptimizer1-8-777edd9d575f3480ca6cebe4be57b1f6 b/sql/hive/src/test/resources/golden/correlationoptimizer1-8-777edd9d575f3480ca6cebe4be57b1f6 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-8-777edd9d575f3480ca6cebe4be57b1f6 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-9-d5bea91b4edb8be0428a336ff9c21dde b/sql/hive/src/test/resources/golden/correlationoptimizer1-9-d5bea91b4edb8be0428a336ff9c21dde new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer10-0-b1e2ade89ae898650f0be4f796d8947b b/sql/hive/src/test/resources/golden/correlationoptimizer10-0-b1e2ade89ae898650f0be4f796d8947b new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer10-0-b1e2ade89ae898650f0be4f796d8947b @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer10-1-b9d963d24994c47c3776dda6f7d3881f b/sql/hive/src/test/resources/golden/correlationoptimizer10-1-b9d963d24994c47c3776dda6f7d3881f new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer10-1-b9d963d24994c47c3776dda6f7d3881f @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer10-10-777edd9d575f3480ca6cebe4be57b1f6 b/sql/hive/src/test/resources/golden/correlationoptimizer10-10-777edd9d575f3480ca6cebe4be57b1f6 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer10-10-777edd9d575f3480ca6cebe4be57b1f6 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer10-11-a1b7af95dfd01783c07aa23208d6160 b/sql/hive/src/test/resources/golden/correlationoptimizer10-11-a1b7af95dfd01783c07aa23208d6160 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer10-12-1322cff0bdf29aab32e638ad48c71ff9 b/sql/hive/src/test/resources/golden/correlationoptimizer10-12-1322cff0bdf29aab32e638ad48c71ff9 new file mode 100644 index 0000000000000..9ea431d9f5d18 --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer10-12-1322cff0bdf29aab32e638ad48c71ff9 @@ -0,0 +1,5 @@ +66 val_66 +98 val_98 +128 +146 val_146 +150 val_150 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer10-13-b9d963d24994c47c3776dda6f7d3881f b/sql/hive/src/test/resources/golden/correlationoptimizer10-13-b9d963d24994c47c3776dda6f7d3881f new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer10-13-b9d963d24994c47c3776dda6f7d3881f @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer10-14-934b668c11600dc9c013c2ddc4c0d68c b/sql/hive/src/test/resources/golden/correlationoptimizer10-14-934b668c11600dc9c013c2ddc4c0d68c new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer10-15-430ff20a144fb3dbf526232d9cb2baa b/sql/hive/src/test/resources/golden/correlationoptimizer10-15-430ff20a144fb3dbf526232d9cb2baa new file mode 100644 index 0000000000000..5abecc5df25fd --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer10-15-430ff20a144fb3dbf526232d9cb2baa @@ -0,0 +1,23 @@ +181 val_181 +183 val_183 
+186 val_186 +187 val_187 +187 val_187 +187 val_187 +189 val_189 +190 val_190 +191 val_191 +191 val_191 +192 val_192 +193 val_193 +193 val_193 +193 val_193 +194 val_194 +195 val_195 +195 val_195 +196 val_196 +197 val_197 +197 val_197 +199 val_199 +199 val_199 +199 val_199 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer10-16-777edd9d575f3480ca6cebe4be57b1f6 b/sql/hive/src/test/resources/golden/correlationoptimizer10-16-777edd9d575f3480ca6cebe4be57b1f6 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer10-16-777edd9d575f3480ca6cebe4be57b1f6 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer10-17-9d5521ecef1353d23fd2d4f7d78e7006 b/sql/hive/src/test/resources/golden/correlationoptimizer10-17-9d5521ecef1353d23fd2d4f7d78e7006 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer10-18-430ff20a144fb3dbf526232d9cb2baa b/sql/hive/src/test/resources/golden/correlationoptimizer10-18-430ff20a144fb3dbf526232d9cb2baa new file mode 100644 index 0000000000000..5abecc5df25fd --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer10-18-430ff20a144fb3dbf526232d9cb2baa @@ -0,0 +1,23 @@ +181 val_181 +183 val_183 +186 val_186 +187 val_187 +187 val_187 +187 val_187 +189 val_189 +190 val_190 +191 val_191 +191 val_191 +192 val_192 +193 val_193 +193 val_193 +193 val_193 +194 val_194 +195 val_195 +195 val_195 +196 val_196 +197 val_197 +197 val_197 +199 val_199 +199 val_199 +199 val_199 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer10-2-f9de06a4184ab1f42793327c1497437 b/sql/hive/src/test/resources/golden/correlationoptimizer10-2-f9de06a4184ab1f42793327c1497437 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer10-3-6a01aa7ca94cda4268af894b4fd852ea b/sql/hive/src/test/resources/golden/correlationoptimizer10-3-6a01aa7ca94cda4268af894b4fd852ea new file mode 100644 index 0000000000000..d00aeb4be0340 --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer10-3-6a01aa7ca94cda4268af894b4fd852ea @@ -0,0 +1,15 @@ +66 1 +98 1 +128 1 +146 1 +150 1 +213 1 +224 1 +238 1 +255 1 +273 1 +278 1 +311 1 +369 1 +401 1 +406 1 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer10-4-777edd9d575f3480ca6cebe4be57b1f6 b/sql/hive/src/test/resources/golden/correlationoptimizer10-4-777edd9d575f3480ca6cebe4be57b1f6 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer10-4-777edd9d575f3480ca6cebe4be57b1f6 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer10-5-6f191930802d659058465b2e6de08dd3 b/sql/hive/src/test/resources/golden/correlationoptimizer10-5-6f191930802d659058465b2e6de08dd3 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer10-6-6a01aa7ca94cda4268af894b4fd852ea b/sql/hive/src/test/resources/golden/correlationoptimizer10-6-6a01aa7ca94cda4268af894b4fd852ea new file mode 100644 index 0000000000000..d00aeb4be0340 --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer10-6-6a01aa7ca94cda4268af894b4fd852ea @@ -0,0 +1,15 @@ +66 1 +98 1 +128 1 +146 1 +150 1 +213 1 +224 1 +238 1 +255 1 +273 1 +278 1 +311 1 +369 1 +401 1 +406 1 diff --git 
a/sql/hive/src/test/resources/golden/correlationoptimizer10-7-b9d963d24994c47c3776dda6f7d3881f b/sql/hive/src/test/resources/golden/correlationoptimizer10-7-b9d963d24994c47c3776dda6f7d3881f new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer10-7-b9d963d24994c47c3776dda6f7d3881f @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer10-8-b2c8d0056b9f1b41cdaeaab4a1a5c9ec b/sql/hive/src/test/resources/golden/correlationoptimizer10-8-b2c8d0056b9f1b41cdaeaab4a1a5c9ec new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer10-9-1322cff0bdf29aab32e638ad48c71ff9 b/sql/hive/src/test/resources/golden/correlationoptimizer10-9-1322cff0bdf29aab32e638ad48c71ff9 new file mode 100644 index 0000000000000..9ea431d9f5d18 --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer10-9-1322cff0bdf29aab32e638ad48c71ff9 @@ -0,0 +1,5 @@ +66 val_66 +98 val_98 +128 +146 val_146 +150 val_150 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer13-0-efd135a811fa94760736a761d220b82 b/sql/hive/src/test/resources/golden/correlationoptimizer13-0-efd135a811fa94760736a761d220b82 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer13-1-32a82500cc28465fac6f64dde0c431c6 b/sql/hive/src/test/resources/golden/correlationoptimizer13-1-32a82500cc28465fac6f64dde0c431c6 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer13-2-777edd9d575f3480ca6cebe4be57b1f6 b/sql/hive/src/test/resources/golden/correlationoptimizer13-2-777edd9d575f3480ca6cebe4be57b1f6 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer13-2-777edd9d575f3480ca6cebe4be57b1f6 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer13-3-bb61d9292434f37bd386e5bff683764d b/sql/hive/src/test/resources/golden/correlationoptimizer13-3-bb61d9292434f37bd386e5bff683764d new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer9-0-efd135a811fa94760736a761d220b82 b/sql/hive/src/test/resources/golden/correlationoptimizer9-0-efd135a811fa94760736a761d220b82 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer9-1-b1e2ade89ae898650f0be4f796d8947b b/sql/hive/src/test/resources/golden/correlationoptimizer9-1-b1e2ade89ae898650f0be4f796d8947b new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer9-1-b1e2ade89ae898650f0be4f796d8947b @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer9-10-1190d82f88f7fb1f91968f6e2e03772a b/sql/hive/src/test/resources/golden/correlationoptimizer9-10-1190d82f88f7fb1f91968f6e2e03772a new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer9-11-bc2ae88b17ac2bdbd288e07194a40168 b/sql/hive/src/test/resources/golden/correlationoptimizer9-11-bc2ae88b17ac2bdbd288e07194a40168 new file mode 100644 index 0000000000000..17c838bb62b3b --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer9-11-bc2ae88b17ac2bdbd288e07194a40168 @@ -0,0 +1,9 @@ +103 val_103 103 val_103 4 4 +104 val_104 104 
val_104 4 4 +105 val_105 105 val_105 1 1 +111 val_111 111 val_111 1 1 +113 val_113 113 val_113 4 4 +114 val_114 114 val_114 1 1 +116 val_116 116 val_116 1 1 +118 val_118 118 val_118 4 4 +119 val_119 119 val_119 9 9 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer9-12-777edd9d575f3480ca6cebe4be57b1f6 b/sql/hive/src/test/resources/golden/correlationoptimizer9-12-777edd9d575f3480ca6cebe4be57b1f6 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer9-12-777edd9d575f3480ca6cebe4be57b1f6 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer9-13-1190d82f88f7fb1f91968f6e2e03772a b/sql/hive/src/test/resources/golden/correlationoptimizer9-13-1190d82f88f7fb1f91968f6e2e03772a new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer9-14-bc2ae88b17ac2bdbd288e07194a40168 b/sql/hive/src/test/resources/golden/correlationoptimizer9-14-bc2ae88b17ac2bdbd288e07194a40168 new file mode 100644 index 0000000000000..17c838bb62b3b --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer9-14-bc2ae88b17ac2bdbd288e07194a40168 @@ -0,0 +1,9 @@ +103 val_103 103 val_103 4 4 +104 val_104 104 val_104 4 4 +105 val_105 105 val_105 1 1 +111 val_111 111 val_111 1 1 +113 val_113 113 val_113 4 4 +114 val_114 114 val_114 1 1 +116 val_116 116 val_116 1 1 +118 val_118 118 val_118 4 4 +119 val_119 119 val_119 9 9 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer9-2-32a82500cc28465fac6f64dde0c431c6 b/sql/hive/src/test/resources/golden/correlationoptimizer9-2-32a82500cc28465fac6f64dde0c431c6 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer9-3-b9d963d24994c47c3776dda6f7d3881f b/sql/hive/src/test/resources/golden/correlationoptimizer9-3-b9d963d24994c47c3776dda6f7d3881f new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer9-3-b9d963d24994c47c3776dda6f7d3881f @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer9-4-ec131bcf578dba99f20b16a7dc6b9b b/sql/hive/src/test/resources/golden/correlationoptimizer9-4-ec131bcf578dba99f20b16a7dc6b9b new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer9-5-b4e378104bb5ab8d8ba5f905aa1ff450 b/sql/hive/src/test/resources/golden/correlationoptimizer9-5-b4e378104bb5ab8d8ba5f905aa1ff450 new file mode 100644 index 0000000000000..248a14f1f4a9f --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer9-5-b4e378104bb5ab8d8ba5f905aa1ff450 @@ -0,0 +1,9 @@ +103 103 4 4 +104 104 4 4 +105 105 1 1 +111 111 1 1 +113 113 4 4 +114 114 1 1 +116 116 1 1 +118 118 4 4 +119 119 9 9 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer9-6-777edd9d575f3480ca6cebe4be57b1f6 b/sql/hive/src/test/resources/golden/correlationoptimizer9-6-777edd9d575f3480ca6cebe4be57b1f6 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer9-6-777edd9d575f3480ca6cebe4be57b1f6 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer9-7-f952899d70bd718cbdbc44a5290938c9 b/sql/hive/src/test/resources/golden/correlationoptimizer9-7-f952899d70bd718cbdbc44a5290938c9 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git 
a/sql/hive/src/test/resources/golden/correlationoptimizer9-8-b4e378104bb5ab8d8ba5f905aa1ff450 b/sql/hive/src/test/resources/golden/correlationoptimizer9-8-b4e378104bb5ab8d8ba5f905aa1ff450 new file mode 100644 index 0000000000000..248a14f1f4a9f --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer9-8-b4e378104bb5ab8d8ba5f905aa1ff450 @@ -0,0 +1,9 @@ +103 103 4 4 +104 104 4 4 +105 105 1 1 +111 111 1 1 +113 113 4 4 +114 114 1 1 +116 116 1 1 +118 118 4 4 +119 119 9 9 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer9-9-b9d963d24994c47c3776dda6f7d3881f b/sql/hive/src/test/resources/golden/correlationoptimizer9-9-b9d963d24994c47c3776dda6f7d3881f new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer9-9-b9d963d24994c47c3776dda6f7d3881f @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/div-0-3760f9b354ddacd7c7b01b28791d4585 b/sql/hive/src/test/resources/golden/div-0-3760f9b354ddacd7c7b01b28791d4585 new file mode 100644 index 0000000000000..17ba0bea723c6 --- /dev/null +++ b/sql/hive/src/test/resources/golden/div-0-3760f9b354ddacd7c7b01b28791d4585 @@ -0,0 +1 @@ +0 0 0 1 2 diff --git a/sql/hive/src/test/resources/golden/division-0-63b19f8a22471c8ba0415c1d3bc276f7 b/sql/hive/src/test/resources/golden/division-0-63b19f8a22471c8ba0415c1d3bc276f7 new file mode 100644 index 0000000000000..7b7a9175114ce --- /dev/null +++ b/sql/hive/src/test/resources/golden/division-0-63b19f8a22471c8ba0415c1d3bc276f7 @@ -0,0 +1 @@ +2.0 0.5 0.3333333333333333 0.002 diff --git a/sql/hive/src/test/resources/golden/groupby_ppd-0-4b116ec2d8fb55c52b7b0d248c616ae2 b/sql/hive/src/test/resources/golden/groupby_ppd-0-4b116ec2d8fb55c52b7b0d248c616ae2 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/groupby_ppd-1-db7b1db8f5e61f0fa78d2874e4d72d9d b/sql/hive/src/test/resources/golden/groupby_ppd-1-db7b1db8f5e61f0fa78d2874e4d72d9d new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/groupby_ppd-2-d286410aa1d5f5c8d91b863a6d6e29c5 b/sql/hive/src/test/resources/golden/groupby_ppd-2-d286410aa1d5f5c8d91b863a6d6e29c5 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/having-0-57f3f26c0203c29c2a91a7cca557ce55 b/sql/hive/src/test/resources/golden/having-0-57f3f26c0203c29c2a91a7cca557ce55 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/having-1-ef81808faeab6d212c3cf32abfc0d873 b/sql/hive/src/test/resources/golden/having-1-ef81808faeab6d212c3cf32abfc0d873 new file mode 100644 index 0000000000000..704f1e62f14c5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/having-1-ef81808faeab6d212c3cf32abfc0d873 @@ -0,0 +1,10 @@ +4 +4 +5 +4 +5 +5 +4 +4 +5 +4 diff --git a/sql/hive/src/test/resources/golden/having-2-a2b4f52cb92f730ddb912b063636d6c1 b/sql/hive/src/test/resources/golden/having-2-a2b4f52cb92f730ddb912b063636d6c1 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/having-3-3fa6387b6a4ece110ac340c7b893964e b/sql/hive/src/test/resources/golden/having-3-3fa6387b6a4ece110ac340c7b893964e new file mode 100644 index 0000000000000..b56757a60f780 --- /dev/null +++ b/sql/hive/src/test/resources/golden/having-3-3fa6387b6a4ece110ac340c7b893964e @@ -0,0 +1,308 @@ +0 val_0 +2 val_2 +4 val_4 +5 val_5 +8 val_8 +9 val_9 +10 val_10 +11 val_11 +12 val_12 +15 val_15 +17 
val_17 +18 val_18 +19 val_19 +20 val_20 +24 val_24 +26 val_26 +27 val_27 +28 val_28 +30 val_30 +33 val_33 +34 val_34 +35 val_35 +37 val_37 +41 val_41 +42 val_42 +43 val_43 +44 val_44 +47 val_47 +51 val_51 +53 val_53 +54 val_54 +57 val_57 +58 val_58 +64 val_64 +65 val_65 +66 val_66 +67 val_67 +69 val_69 +70 val_70 +72 val_72 +74 val_74 +76 val_76 +77 val_77 +78 val_78 +80 val_80 +82 val_82 +83 val_83 +84 val_84 +85 val_85 +86 val_86 +87 val_87 +90 val_90 +92 val_92 +95 val_95 +96 val_96 +97 val_97 +98 val_98 +100 val_100 +103 val_103 +104 val_104 +105 val_105 +111 val_111 +113 val_113 +114 val_114 +116 val_116 +118 val_118 +119 val_119 +120 val_120 +125 val_125 +126 val_126 +128 val_128 +129 val_129 +131 val_131 +133 val_133 +134 val_134 +136 val_136 +137 val_137 +138 val_138 +143 val_143 +145 val_145 +146 val_146 +149 val_149 +150 val_150 +152 val_152 +153 val_153 +155 val_155 +156 val_156 +157 val_157 +158 val_158 +160 val_160 +162 val_162 +163 val_163 +164 val_164 +165 val_165 +166 val_166 +167 val_167 +168 val_168 +169 val_169 +170 val_170 +172 val_172 +174 val_174 +175 val_175 +176 val_176 +177 val_177 +178 val_178 +179 val_179 +180 val_180 +181 val_181 +183 val_183 +186 val_186 +187 val_187 +189 val_189 +190 val_190 +191 val_191 +192 val_192 +193 val_193 +194 val_194 +195 val_195 +196 val_196 +197 val_197 +199 val_199 +200 val_200 +201 val_201 +202 val_202 +203 val_203 +205 val_205 +207 val_207 +208 val_208 +209 val_209 +213 val_213 +214 val_214 +216 val_216 +217 val_217 +218 val_218 +219 val_219 +221 val_221 +222 val_222 +223 val_223 +224 val_224 +226 val_226 +228 val_228 +229 val_229 +230 val_230 +233 val_233 +235 val_235 +237 val_237 +238 val_238 +239 val_239 +241 val_241 +242 val_242 +244 val_244 +247 val_247 +248 val_248 +249 val_249 +252 val_252 +255 val_255 +256 val_256 +257 val_257 +258 val_258 +260 val_260 +262 val_262 +263 val_263 +265 val_265 +266 val_266 +272 val_272 +273 val_273 +274 val_274 +275 val_275 +277 val_277 +278 val_278 +280 val_280 +281 val_281 +282 val_282 +283 val_283 +284 val_284 +285 val_285 +286 val_286 +287 val_287 +288 val_288 +289 val_289 +291 val_291 +292 val_292 +296 val_296 +298 val_298 +305 val_305 +306 val_306 +307 val_307 +308 val_308 +309 val_309 +310 val_310 +311 val_311 +315 val_315 +316 val_316 +317 val_317 +318 val_318 +321 val_321 +322 val_322 +323 val_323 +325 val_325 +327 val_327 +331 val_331 +332 val_332 +333 val_333 +335 val_335 +336 val_336 +338 val_338 +339 val_339 +341 val_341 +342 val_342 +344 val_344 +345 val_345 +348 val_348 +351 val_351 +353 val_353 +356 val_356 +360 val_360 +362 val_362 +364 val_364 +365 val_365 +366 val_366 +367 val_367 +368 val_368 +369 val_369 +373 val_373 +374 val_374 +375 val_375 +377 val_377 +378 val_378 +379 val_379 +382 val_382 +384 val_384 +386 val_386 +389 val_389 +392 val_392 +393 val_393 +394 val_394 +395 val_395 +396 val_396 +397 val_397 +399 val_399 +400 val_400 +401 val_401 +402 val_402 +403 val_403 +404 val_404 +406 val_406 +407 val_407 +409 val_409 +411 val_411 +413 val_413 +414 val_414 +417 val_417 +418 val_418 +419 val_419 +421 val_421 +424 val_424 +427 val_427 +429 val_429 +430 val_430 +431 val_431 +432 val_432 +435 val_435 +436 val_436 +437 val_437 +438 val_438 +439 val_439 +443 val_443 +444 val_444 +446 val_446 +448 val_448 +449 val_449 +452 val_452 +453 val_453 +454 val_454 +455 val_455 +457 val_457 +458 val_458 +459 val_459 +460 val_460 +462 val_462 +463 val_463 +466 val_466 +467 val_467 +468 val_468 +469 val_469 +470 val_470 +472 val_472 +475 val_475 +477 val_477 +478 val_478 +479 val_479 
+480 val_480 +481 val_481 +482 val_482 +483 val_483 +484 val_484 +485 val_485 +487 val_487 +489 val_489 +490 val_490 +491 val_491 +492 val_492 +493 val_493 +494 val_494 +495 val_495 +496 val_496 +497 val_497 +498 val_498 diff --git a/sql/hive/src/test/resources/golden/having-4-e9918bd385cb35db4ebcbd4e398547f4 b/sql/hive/src/test/resources/golden/having-4-e9918bd385cb35db4ebcbd4e398547f4 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/having-5-4a0c4e521b8a6f6146151c13a2715ff b/sql/hive/src/test/resources/golden/having-5-4a0c4e521b8a6f6146151c13a2715ff new file mode 100644 index 0000000000000..2d7022e386303 --- /dev/null +++ b/sql/hive/src/test/resources/golden/having-5-4a0c4e521b8a6f6146151c13a2715ff @@ -0,0 +1,199 @@ +4 +5 +8 +9 +26 +27 +28 +30 +33 +34 +35 +37 +41 +42 +43 +44 +47 +51 +53 +54 +57 +58 +64 +65 +66 +67 +69 +70 +72 +74 +76 +77 +78 +80 +82 +83 +84 +85 +86 +87 +90 +92 +95 +96 +97 +98 +256 +257 +258 +260 +262 +263 +265 +266 +272 +273 +274 +275 +277 +278 +280 +281 +282 +283 +284 +285 +286 +287 +288 +289 +291 +292 +296 +298 +302 +305 +306 +307 +308 +309 +310 +311 +315 +316 +317 +318 +321 +322 +323 +325 +327 +331 +332 +333 +335 +336 +338 +339 +341 +342 +344 +345 +348 +351 +353 +356 +360 +362 +364 +365 +366 +367 +368 +369 +373 +374 +375 +377 +378 +379 +382 +384 +386 +389 +392 +393 +394 +395 +396 +397 +399 +400 +401 +402 +403 +404 +406 +407 +409 +411 +413 +414 +417 +418 +419 +421 +424 +427 +429 +430 +431 +432 +435 +436 +437 +438 +439 +443 +444 +446 +448 +449 +452 +453 +454 +455 +457 +458 +459 +460 +462 +463 +466 +467 +468 +469 +470 +472 +475 +477 +478 +479 +480 +481 +482 +483 +484 +485 +487 +489 +490 +491 +492 +493 +494 +495 +496 +497 +498 diff --git a/sql/hive/src/test/resources/golden/having-6-9f50df5b5f31c7166b0396ab434dc095 b/sql/hive/src/test/resources/golden/having-6-9f50df5b5f31c7166b0396ab434dc095 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/having-7-5ad96cb287df02080da1e2594f08d83e b/sql/hive/src/test/resources/golden/having-7-5ad96cb287df02080da1e2594f08d83e new file mode 100644 index 0000000000000..bd545ccf7430c --- /dev/null +++ b/sql/hive/src/test/resources/golden/having-7-5ad96cb287df02080da1e2594f08d83e @@ -0,0 +1,125 @@ +302 +305 +306 +307 +308 +309 +310 +311 +315 +316 +317 +318 +321 +322 +323 +325 +327 +331 +332 +333 +335 +336 +338 +339 +341 +342 +344 +345 +348 +351 +353 +356 +360 +362 +364 +365 +366 +367 +368 +369 +373 +374 +375 +377 +378 +379 +382 +384 +386 +389 +392 +393 +394 +395 +396 +397 +399 +400 +401 +402 +403 +404 +406 +407 +409 +411 +413 +414 +417 +418 +419 +421 +424 +427 +429 +430 +431 +432 +435 +436 +437 +438 +439 +443 +444 +446 +448 +449 +452 +453 +454 +455 +457 +458 +459 +460 +462 +463 +466 +467 +468 +469 +470 +472 +475 +477 +478 +479 +480 +481 +482 +483 +484 +485 +487 +489 +490 +491 +492 +493 +494 +495 +496 +497 +498 diff --git a/sql/hive/src/test/resources/golden/having-8-4aa7197e20b5a64461ca670a79488103 b/sql/hive/src/test/resources/golden/having-8-4aa7197e20b5a64461ca670a79488103 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/having-9-a79743372d86d77b0ff53a71adcb1cff b/sql/hive/src/test/resources/golden/having-9-a79743372d86d77b0ff53a71adcb1cff new file mode 100644 index 0000000000000..d77586c12b6af --- /dev/null +++ b/sql/hive/src/test/resources/golden/having-9-a79743372d86d77b0ff53a71adcb1cff @@ -0,0 +1,199 @@ +4 val_4 +5 val_5 +8 val_8 +9 val_9 +26 val_26 +27 val_27 +28 val_28 
+30 val_30 +33 val_33 +34 val_34 +35 val_35 +37 val_37 +41 val_41 +42 val_42 +43 val_43 +44 val_44 +47 val_47 +51 val_51 +53 val_53 +54 val_54 +57 val_57 +58 val_58 +64 val_64 +65 val_65 +66 val_66 +67 val_67 +69 val_69 +70 val_70 +72 val_72 +74 val_74 +76 val_76 +77 val_77 +78 val_78 +80 val_80 +82 val_82 +83 val_83 +84 val_84 +85 val_85 +86 val_86 +87 val_87 +90 val_90 +92 val_92 +95 val_95 +96 val_96 +97 val_97 +98 val_98 +256 val_256 +257 val_257 +258 val_258 +260 val_260 +262 val_262 +263 val_263 +265 val_265 +266 val_266 +272 val_272 +273 val_273 +274 val_274 +275 val_275 +277 val_277 +278 val_278 +280 val_280 +281 val_281 +282 val_282 +283 val_283 +284 val_284 +285 val_285 +286 val_286 +287 val_287 +288 val_288 +289 val_289 +291 val_291 +292 val_292 +296 val_296 +298 val_298 +302 val_302 +305 val_305 +306 val_306 +307 val_307 +308 val_308 +309 val_309 +310 val_310 +311 val_311 +315 val_315 +316 val_316 +317 val_317 +318 val_318 +321 val_321 +322 val_322 +323 val_323 +325 val_325 +327 val_327 +331 val_331 +332 val_332 +333 val_333 +335 val_335 +336 val_336 +338 val_338 +339 val_339 +341 val_341 +342 val_342 +344 val_344 +345 val_345 +348 val_348 +351 val_351 +353 val_353 +356 val_356 +360 val_360 +362 val_362 +364 val_364 +365 val_365 +366 val_366 +367 val_367 +368 val_368 +369 val_369 +373 val_373 +374 val_374 +375 val_375 +377 val_377 +378 val_378 +379 val_379 +382 val_382 +384 val_384 +386 val_386 +389 val_389 +392 val_392 +393 val_393 +394 val_394 +395 val_395 +396 val_396 +397 val_397 +399 val_399 +400 val_400 +401 val_401 +402 val_402 +403 val_403 +404 val_404 +406 val_406 +407 val_407 +409 val_409 +411 val_411 +413 val_413 +414 val_414 +417 val_417 +418 val_418 +419 val_419 +421 val_421 +424 val_424 +427 val_427 +429 val_429 +430 val_430 +431 val_431 +432 val_432 +435 val_435 +436 val_436 +437 val_437 +438 val_438 +439 val_439 +443 val_443 +444 val_444 +446 val_446 +448 val_448 +449 val_449 +452 val_452 +453 val_453 +454 val_454 +455 val_455 +457 val_457 +458 val_458 +459 val_459 +460 val_460 +462 val_462 +463 val_463 +466 val_466 +467 val_467 +468 val_468 +469 val_469 +470 val_470 +472 val_472 +475 val_475 +477 val_477 +478 val_478 +479 val_479 +480 val_480 +481 val_481 +482 val_482 +483 val_483 +484 val_484 +485 val_485 +487 val_487 +489 val_489 +490 val_490 +491 val_491 +492 val_492 +493 val_493 +494 val_494 +495 val_495 +496 val_496 +497 val_497 +498 val_498 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-0-869726b703f160eabdb7763700b53e60 b/sql/hive/src/test/resources/golden/join_nullsafe-0-869726b703f160eabdb7763700b53e60 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-0-869726b703f160eabdb7763700b53e60 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-1-5644ab44e5ba9f2941216b8d5dc33a99 b/sql/hive/src/test/resources/golden/join_nullsafe-1-5644ab44e5ba9f2941216b8d5dc33a99 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-10-b6de4e85dcc1d1949c7431d39fa1b919 b/sql/hive/src/test/resources/golden/join_nullsafe-10-b6de4e85dcc1d1949c7431d39fa1b919 new file mode 100644 index 0000000000000..31c409082cc2f --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-10-b6de4e85dcc1d1949c7431d39fa1b919 @@ -0,0 +1,2 @@ +NULL 10 10 NULL NULL 10 +100 100 100 100 100 100 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-11-3aa243002a5363b84556736ef71613b1 
b/sql/hive/src/test/resources/golden/join_nullsafe-11-3aa243002a5363b84556736ef71613b1 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-12-3cc55b14e8256d2c51361b61986c291e b/sql/hive/src/test/resources/golden/join_nullsafe-12-3cc55b14e8256d2c51361b61986c291e new file mode 100644 index 0000000000000..9b77d13cbaab2 --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-12-3cc55b14e8256d2c51361b61986c291e @@ -0,0 +1,4 @@ +NULL NULL NULL NULL NULL NULL +NULL 10 10 NULL NULL 10 +10 NULL NULL 10 10 NULL +100 100 100 100 100 100 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-13-69d94e229191e7b9b1a3e7eae46eb993 b/sql/hive/src/test/resources/golden/join_nullsafe-13-69d94e229191e7b9b1a3e7eae46eb993 new file mode 100644 index 0000000000000..47c0709d39851 --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-13-69d94e229191e7b9b1a3e7eae46eb993 @@ -0,0 +1,12 @@ +NULL NULL NULL NULL +NULL NULL 10 NULL +NULL NULL 48 NULL +NULL 10 NULL NULL +NULL 10 10 NULL +NULL 10 48 NULL +NULL 35 NULL NULL +NULL 35 10 NULL +NULL 35 48 NULL +10 NULL NULL 10 +48 NULL NULL NULL +100 100 100 100 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-14-cf9ff6ee72a701a8e2f3e7fb0667903c b/sql/hive/src/test/resources/golden/join_nullsafe-14-cf9ff6ee72a701a8e2f3e7fb0667903c new file mode 100644 index 0000000000000..36ba48516b658 --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-14-cf9ff6ee72a701a8e2f3e7fb0667903c @@ -0,0 +1,12 @@ +NULL NULL NULL NULL +NULL NULL NULL 35 +NULL NULL 10 NULL +NULL NULL 48 NULL +NULL 10 NULL NULL +NULL 10 10 NULL +NULL 10 48 NULL +NULL 35 NULL NULL +NULL 35 10 NULL +NULL 35 48 NULL +10 NULL NULL 10 +100 100 100 100 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-15-507d0fa6d7ce39e2d9921555cea6f8da b/sql/hive/src/test/resources/golden/join_nullsafe-15-507d0fa6d7ce39e2d9921555cea6f8da new file mode 100644 index 0000000000000..fc1fd198cf8be --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-15-507d0fa6d7ce39e2d9921555cea6f8da @@ -0,0 +1,13 @@ +NULL NULL NULL NULL +NULL NULL NULL 35 +NULL NULL 10 NULL +NULL NULL 48 NULL +NULL 10 NULL NULL +NULL 10 10 NULL +NULL 10 48 NULL +NULL 35 NULL NULL +NULL 35 10 NULL +NULL 35 48 NULL +10 NULL NULL 10 +48 NULL NULL NULL +100 100 100 100 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-16-1c714fc339304de4db630530e5d1ce97 b/sql/hive/src/test/resources/golden/join_nullsafe-16-1c714fc339304de4db630530e5d1ce97 new file mode 100644 index 0000000000000..1cc70524f9d6d --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-16-1c714fc339304de4db630530e5d1ce97 @@ -0,0 +1,11 @@ +NULL NULL NULL NULL +NULL NULL 10 NULL +NULL NULL 48 NULL +NULL 10 NULL NULL +NULL 10 10 NULL +NULL 10 48 NULL +NULL 35 NULL NULL +NULL 35 10 NULL +NULL 35 48 NULL +10 NULL NULL 10 +100 100 100 100 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-17-8a4b0dc781a28ad11a0db9805fe03aa8 b/sql/hive/src/test/resources/golden/join_nullsafe-17-8a4b0dc781a28ad11a0db9805fe03aa8 new file mode 100644 index 0000000000000..1cc70524f9d6d --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-17-8a4b0dc781a28ad11a0db9805fe03aa8 @@ -0,0 +1,11 @@ +NULL NULL NULL NULL +NULL NULL 10 NULL +NULL NULL 48 NULL +NULL 10 NULL NULL +NULL 10 10 NULL +NULL 10 48 NULL +NULL 35 NULL NULL +NULL 35 10 NULL +NULL 35 48 NULL +10 NULL NULL 10 +100 100 100 100 diff --git 
a/sql/hive/src/test/resources/golden/join_nullsafe-18-10b2051e65cac50ee1ea1c138ec192c8 b/sql/hive/src/test/resources/golden/join_nullsafe-18-10b2051e65cac50ee1ea1c138ec192c8 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-19-23ab7ac8229a53d391195be7ca092429 b/sql/hive/src/test/resources/golden/join_nullsafe-19-23ab7ac8229a53d391195be7ca092429 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-2-793e288c9e0971f0bf3f37493f76dc7 b/sql/hive/src/test/resources/golden/join_nullsafe-2-793e288c9e0971f0bf3f37493f76dc7 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-20-d6fc260320c577eec9a5db0d4135d224 b/sql/hive/src/test/resources/golden/join_nullsafe-20-d6fc260320c577eec9a5db0d4135d224 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-21-a60dae725ffc543f805242611d99de4e b/sql/hive/src/test/resources/golden/join_nullsafe-21-a60dae725ffc543f805242611d99de4e new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-22-24c80d0f9e3d72c48d947770fa184985 b/sql/hive/src/test/resources/golden/join_nullsafe-22-24c80d0f9e3d72c48d947770fa184985 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-23-3fe6ae20cab3417759dcc654a3a26746 b/sql/hive/src/test/resources/golden/join_nullsafe-23-3fe6ae20cab3417759dcc654a3a26746 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-24-2db30531137611e06fdba478ca7a8412 b/sql/hive/src/test/resources/golden/join_nullsafe-24-2db30531137611e06fdba478ca7a8412 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-24-2db30531137611e06fdba478ca7a8412 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-25-e58b2754e8d9c56a473557a549d0d2b9 b/sql/hive/src/test/resources/golden/join_nullsafe-25-e58b2754e8d9c56a473557a549d0d2b9 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-25-e58b2754e8d9c56a473557a549d0d2b9 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-26-64cabe5164130a94f387288f37b62d71 b/sql/hive/src/test/resources/golden/join_nullsafe-26-64cabe5164130a94f387288f37b62d71 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-26-64cabe5164130a94f387288f37b62d71 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-27-e8ed4a1b574a6ca70cbfb3f7b9980aa6 b/sql/hive/src/test/resources/golden/join_nullsafe-27-e8ed4a1b574a6ca70cbfb3f7b9980aa6 new file mode 100644 index 0000000000000..66482299904bb --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-27-e8ed4a1b574a6ca70cbfb3f7b9980aa6 @@ -0,0 +1,42 @@ +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL 10 +NULL NULL NULL 10 +NULL NULL NULL 35 +NULL NULL NULL 35 +NULL NULL NULL 110 +NULL NULL NULL 110 +NULL NULL NULL 135 +NULL NULL NULL 135 +NULL 10 NULL NULL +NULL 10 NULL NULL +NULL 10 NULL 10 +NULL 10 NULL 35 +NULL 10 NULL 110 +NULL 10 NULL 135 +NULL 35 NULL NULL +NULL 35 NULL NULL +NULL 35 NULL 10 +NULL 35 NULL 35 +NULL 35 
NULL 110 +NULL 35 NULL 135 +NULL 110 NULL NULL +NULL 110 NULL NULL +NULL 110 NULL 10 +NULL 110 NULL 35 +NULL 110 NULL 110 +NULL 110 NULL 135 +NULL 135 NULL NULL +NULL 135 NULL NULL +NULL 135 NULL 10 +NULL 135 NULL 35 +NULL 135 NULL 110 +NULL 135 NULL 135 +10 NULL 10 NULL +48 NULL 48 NULL +100 100 100 100 +110 NULL 110 NULL +148 NULL 148 NULL +200 200 200 200 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-28-5a0c946cd7033857ca99e5fb800f8525 b/sql/hive/src/test/resources/golden/join_nullsafe-28-5a0c946cd7033857ca99e5fb800f8525 new file mode 100644 index 0000000000000..2efbef0484452 --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-28-5a0c946cd7033857ca99e5fb800f8525 @@ -0,0 +1,14 @@ +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL 10 NULL 10 +NULL 35 NULL 35 +NULL 110 NULL 110 +NULL 135 NULL 135 +10 NULL 10 NULL +48 NULL 48 NULL +100 100 100 100 +110 NULL 110 NULL +148 NULL 148 NULL +200 200 200 200 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-29-514043c2ddaf6ea8f16a764adc92d1cf b/sql/hive/src/test/resources/golden/join_nullsafe-29-514043c2ddaf6ea8f16a764adc92d1cf new file mode 100644 index 0000000000000..66482299904bb --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-29-514043c2ddaf6ea8f16a764adc92d1cf @@ -0,0 +1,42 @@ +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL 10 +NULL NULL NULL 10 +NULL NULL NULL 35 +NULL NULL NULL 35 +NULL NULL NULL 110 +NULL NULL NULL 110 +NULL NULL NULL 135 +NULL NULL NULL 135 +NULL 10 NULL NULL +NULL 10 NULL NULL +NULL 10 NULL 10 +NULL 10 NULL 35 +NULL 10 NULL 110 +NULL 10 NULL 135 +NULL 35 NULL NULL +NULL 35 NULL NULL +NULL 35 NULL 10 +NULL 35 NULL 35 +NULL 35 NULL 110 +NULL 35 NULL 135 +NULL 110 NULL NULL +NULL 110 NULL NULL +NULL 110 NULL 10 +NULL 110 NULL 35 +NULL 110 NULL 110 +NULL 110 NULL 135 +NULL 135 NULL NULL +NULL 135 NULL NULL +NULL 135 NULL 10 +NULL 135 NULL 35 +NULL 135 NULL 110 +NULL 135 NULL 135 +10 NULL 10 NULL +48 NULL 48 NULL +100 100 100 100 +110 NULL 110 NULL +148 NULL 148 NULL +200 200 200 200 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-3-ae378fc0f875a21884e58fa35a6d52cd b/sql/hive/src/test/resources/golden/join_nullsafe-3-ae378fc0f875a21884e58fa35a6d52cd new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-30-fcbf92cb1b85ab01102fbbc6caba9a88 b/sql/hive/src/test/resources/golden/join_nullsafe-30-fcbf92cb1b85ab01102fbbc6caba9a88 new file mode 100644 index 0000000000000..66482299904bb --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-30-fcbf92cb1b85ab01102fbbc6caba9a88 @@ -0,0 +1,42 @@ +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL 10 +NULL NULL NULL 10 +NULL NULL NULL 35 +NULL NULL NULL 35 +NULL NULL NULL 110 +NULL NULL NULL 110 +NULL NULL NULL 135 +NULL NULL NULL 135 +NULL 10 NULL NULL +NULL 10 NULL NULL +NULL 10 NULL 10 +NULL 10 NULL 35 +NULL 10 NULL 110 +NULL 10 NULL 135 +NULL 35 NULL NULL +NULL 35 NULL NULL +NULL 35 NULL 10 +NULL 35 NULL 35 +NULL 35 NULL 110 +NULL 35 NULL 135 +NULL 110 NULL NULL +NULL 110 NULL NULL +NULL 110 NULL 10 +NULL 110 NULL 35 +NULL 110 NULL 110 +NULL 110 NULL 135 +NULL 135 NULL NULL +NULL 135 NULL NULL +NULL 135 NULL 10 +NULL 135 NULL 35 +NULL 135 NULL 110 +NULL 135 NULL 135 +10 NULL 10 NULL +48 NULL 48 NULL +100 100 100 100 +110 NULL 110 NULL +148 NULL 148 NULL +200 200 200 200 diff --git 
a/sql/hive/src/test/resources/golden/join_nullsafe-31-1cb03e1106f79d14f22bc89d386cedcf b/sql/hive/src/test/resources/golden/join_nullsafe-31-1cb03e1106f79d14f22bc89d386cedcf new file mode 100644 index 0000000000000..66482299904bb --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-31-1cb03e1106f79d14f22bc89d386cedcf @@ -0,0 +1,42 @@ +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL 10 +NULL NULL NULL 10 +NULL NULL NULL 35 +NULL NULL NULL 35 +NULL NULL NULL 110 +NULL NULL NULL 110 +NULL NULL NULL 135 +NULL NULL NULL 135 +NULL 10 NULL NULL +NULL 10 NULL NULL +NULL 10 NULL 10 +NULL 10 NULL 35 +NULL 10 NULL 110 +NULL 10 NULL 135 +NULL 35 NULL NULL +NULL 35 NULL NULL +NULL 35 NULL 10 +NULL 35 NULL 35 +NULL 35 NULL 110 +NULL 35 NULL 135 +NULL 110 NULL NULL +NULL 110 NULL NULL +NULL 110 NULL 10 +NULL 110 NULL 35 +NULL 110 NULL 110 +NULL 110 NULL 135 +NULL 135 NULL NULL +NULL 135 NULL NULL +NULL 135 NULL 10 +NULL 135 NULL 35 +NULL 135 NULL 110 +NULL 135 NULL 135 +10 NULL 10 NULL +48 NULL 48 NULL +100 100 100 100 +110 NULL 110 NULL +148 NULL 148 NULL +200 200 200 200 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-32-6a0bf6127d4b042e67ae8ee15125fb87 b/sql/hive/src/test/resources/golden/join_nullsafe-32-6a0bf6127d4b042e67ae8ee15125fb87 new file mode 100644 index 0000000000000..ea001a222f357 --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-32-6a0bf6127d4b042e67ae8ee15125fb87 @@ -0,0 +1,40 @@ +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL 10 NULL +NULL NULL 10 NULL +NULL NULL 48 NULL +NULL NULL 48 NULL +NULL NULL 110 NULL +NULL NULL 110 NULL +NULL NULL 148 NULL +NULL NULL 148 NULL +NULL 10 NULL NULL +NULL 10 NULL NULL +NULL 10 10 NULL +NULL 10 48 NULL +NULL 10 110 NULL +NULL 10 148 NULL +NULL 35 NULL NULL +NULL 35 NULL NULL +NULL 35 10 NULL +NULL 35 48 NULL +NULL 35 110 NULL +NULL 35 148 NULL +NULL 110 NULL NULL +NULL 110 NULL NULL +NULL 110 10 NULL +NULL 110 48 NULL +NULL 110 110 NULL +NULL 110 148 NULL +NULL 135 NULL NULL +NULL 135 NULL NULL +NULL 135 10 NULL +NULL 135 48 NULL +NULL 135 110 NULL +NULL 135 148 NULL +10 NULL NULL 10 +100 100 100 100 +110 NULL NULL 110 +200 200 200 200 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-33-63157d43422fcedadba408537ccecd5c b/sql/hive/src/test/resources/golden/join_nullsafe-33-63157d43422fcedadba408537ccecd5c new file mode 100644 index 0000000000000..ea001a222f357 --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-33-63157d43422fcedadba408537ccecd5c @@ -0,0 +1,40 @@ +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL 10 NULL +NULL NULL 10 NULL +NULL NULL 48 NULL +NULL NULL 48 NULL +NULL NULL 110 NULL +NULL NULL 110 NULL +NULL NULL 148 NULL +NULL NULL 148 NULL +NULL 10 NULL NULL +NULL 10 NULL NULL +NULL 10 10 NULL +NULL 10 48 NULL +NULL 10 110 NULL +NULL 10 148 NULL +NULL 35 NULL NULL +NULL 35 NULL NULL +NULL 35 10 NULL +NULL 35 48 NULL +NULL 35 110 NULL +NULL 35 148 NULL +NULL 110 NULL NULL +NULL 110 NULL NULL +NULL 110 10 NULL +NULL 110 48 NULL +NULL 110 110 NULL +NULL 110 148 NULL +NULL 135 NULL NULL +NULL 135 NULL NULL +NULL 135 10 NULL +NULL 135 48 NULL +NULL 135 110 NULL +NULL 135 148 NULL +10 NULL NULL 10 +100 100 100 100 +110 NULL NULL 110 +200 200 200 200 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-34-9265f806b71c03061f93f9fbc88aa223 
b/sql/hive/src/test/resources/golden/join_nullsafe-34-9265f806b71c03061f93f9fbc88aa223 new file mode 100644 index 0000000000000..1093bd89f6e3f --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-34-9265f806b71c03061f93f9fbc88aa223 @@ -0,0 +1,42 @@ +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL 10 NULL +NULL NULL 10 NULL +NULL NULL 48 NULL +NULL NULL 48 NULL +NULL NULL 110 NULL +NULL NULL 110 NULL +NULL NULL 148 NULL +NULL NULL 148 NULL +NULL 10 NULL NULL +NULL 10 NULL NULL +NULL 10 10 NULL +NULL 10 48 NULL +NULL 10 110 NULL +NULL 10 148 NULL +NULL 35 NULL NULL +NULL 35 NULL NULL +NULL 35 10 NULL +NULL 35 48 NULL +NULL 35 110 NULL +NULL 35 148 NULL +NULL 110 NULL NULL +NULL 110 NULL NULL +NULL 110 10 NULL +NULL 110 48 NULL +NULL 110 110 NULL +NULL 110 148 NULL +NULL 135 NULL NULL +NULL 135 NULL NULL +NULL 135 10 NULL +NULL 135 48 NULL +NULL 135 110 NULL +NULL 135 148 NULL +10 NULL NULL 10 +48 NULL NULL NULL +100 100 100 100 +110 NULL NULL 110 +148 NULL NULL NULL +200 200 200 200 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-35-95815bafb81cccb8129c20d399a446fc b/sql/hive/src/test/resources/golden/join_nullsafe-35-95815bafb81cccb8129c20d399a446fc new file mode 100644 index 0000000000000..9cf0036674d6e --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-35-95815bafb81cccb8129c20d399a446fc @@ -0,0 +1,42 @@ +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL 35 +NULL NULL NULL 135 +NULL NULL 10 NULL +NULL NULL 10 NULL +NULL NULL 48 NULL +NULL NULL 48 NULL +NULL NULL 110 NULL +NULL NULL 110 NULL +NULL NULL 148 NULL +NULL NULL 148 NULL +NULL 10 NULL NULL +NULL 10 NULL NULL +NULL 10 10 NULL +NULL 10 48 NULL +NULL 10 110 NULL +NULL 10 148 NULL +NULL 35 NULL NULL +NULL 35 NULL NULL +NULL 35 10 NULL +NULL 35 48 NULL +NULL 35 110 NULL +NULL 35 148 NULL +NULL 110 NULL NULL +NULL 110 NULL NULL +NULL 110 10 NULL +NULL 110 48 NULL +NULL 110 110 NULL +NULL 110 148 NULL +NULL 135 NULL NULL +NULL 135 NULL NULL +NULL 135 10 NULL +NULL 135 48 NULL +NULL 135 110 NULL +NULL 135 148 NULL +10 NULL NULL 10 +100 100 100 100 +110 NULL NULL 110 +200 200 200 200 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-36-c4762c60cc93236b7647ebd32a40ce57 b/sql/hive/src/test/resources/golden/join_nullsafe-36-c4762c60cc93236b7647ebd32a40ce57 new file mode 100644 index 0000000000000..77f6a8ddd7c28 --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-36-c4762c60cc93236b7647ebd32a40ce57 @@ -0,0 +1,42 @@ +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL 10 NULL +NULL NULL 10 NULL +NULL NULL 48 NULL +NULL NULL 48 NULL +NULL NULL 110 NULL +NULL NULL 110 NULL +NULL NULL 148 NULL +NULL NULL 148 NULL +NULL 10 NULL 10 +NULL 35 NULL 35 +NULL 110 NULL 110 +NULL 135 NULL 135 +10 NULL NULL NULL +10 NULL NULL NULL +10 NULL 10 NULL +10 NULL 48 NULL +10 NULL 110 NULL +10 NULL 148 NULL +48 NULL NULL NULL +48 NULL NULL NULL +48 NULL 10 NULL +48 NULL 48 NULL +48 NULL 110 NULL +48 NULL 148 NULL +100 100 100 100 +110 NULL NULL NULL +110 NULL NULL NULL +110 NULL 10 NULL +110 NULL 48 NULL +110 NULL 110 NULL +110 NULL 148 NULL +148 NULL NULL NULL +148 NULL NULL NULL +148 NULL 10 NULL +148 NULL 48 NULL +148 NULL 110 NULL +148 NULL 148 NULL +200 200 200 200 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-37-a87893adfc73c9cc63ceab200bb56245 b/sql/hive/src/test/resources/golden/join_nullsafe-37-a87893adfc73c9cc63ceab200bb56245 
new file mode 100644 index 0000000000000..77f6a8ddd7c28 --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-37-a87893adfc73c9cc63ceab200bb56245 @@ -0,0 +1,42 @@ +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL 10 NULL +NULL NULL 10 NULL +NULL NULL 48 NULL +NULL NULL 48 NULL +NULL NULL 110 NULL +NULL NULL 110 NULL +NULL NULL 148 NULL +NULL NULL 148 NULL +NULL 10 NULL 10 +NULL 35 NULL 35 +NULL 110 NULL 110 +NULL 135 NULL 135 +10 NULL NULL NULL +10 NULL NULL NULL +10 NULL 10 NULL +10 NULL 48 NULL +10 NULL 110 NULL +10 NULL 148 NULL +48 NULL NULL NULL +48 NULL NULL NULL +48 NULL 10 NULL +48 NULL 48 NULL +48 NULL 110 NULL +48 NULL 148 NULL +100 100 100 100 +110 NULL NULL NULL +110 NULL NULL NULL +110 NULL 10 NULL +110 NULL 48 NULL +110 NULL 110 NULL +110 NULL 148 NULL +148 NULL NULL NULL +148 NULL NULL NULL +148 NULL 10 NULL +148 NULL 48 NULL +148 NULL 110 NULL +148 NULL 148 NULL +200 200 200 200 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-38-e3dfe0044b44c8a49414479521acf762 b/sql/hive/src/test/resources/golden/join_nullsafe-38-e3dfe0044b44c8a49414479521acf762 new file mode 100644 index 0000000000000..77f6a8ddd7c28 --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-38-e3dfe0044b44c8a49414479521acf762 @@ -0,0 +1,42 @@ +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL 10 NULL +NULL NULL 10 NULL +NULL NULL 48 NULL +NULL NULL 48 NULL +NULL NULL 110 NULL +NULL NULL 110 NULL +NULL NULL 148 NULL +NULL NULL 148 NULL +NULL 10 NULL 10 +NULL 35 NULL 35 +NULL 110 NULL 110 +NULL 135 NULL 135 +10 NULL NULL NULL +10 NULL NULL NULL +10 NULL 10 NULL +10 NULL 48 NULL +10 NULL 110 NULL +10 NULL 148 NULL +48 NULL NULL NULL +48 NULL NULL NULL +48 NULL 10 NULL +48 NULL 48 NULL +48 NULL 110 NULL +48 NULL 148 NULL +100 100 100 100 +110 NULL NULL NULL +110 NULL NULL NULL +110 NULL 10 NULL +110 NULL 48 NULL +110 NULL 110 NULL +110 NULL 148 NULL +148 NULL NULL NULL +148 NULL NULL NULL +148 NULL 10 NULL +148 NULL 48 NULL +148 NULL 110 NULL +148 NULL 148 NULL +200 200 200 200 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-39-9a7e1f373b9c02e632d6c7c550b908ec b/sql/hive/src/test/resources/golden/join_nullsafe-39-9a7e1f373b9c02e632d6c7c550b908ec new file mode 100644 index 0000000000000..77f6a8ddd7c28 --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-39-9a7e1f373b9c02e632d6c7c550b908ec @@ -0,0 +1,42 @@ +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL 10 NULL +NULL NULL 10 NULL +NULL NULL 48 NULL +NULL NULL 48 NULL +NULL NULL 110 NULL +NULL NULL 110 NULL +NULL NULL 148 NULL +NULL NULL 148 NULL +NULL 10 NULL 10 +NULL 35 NULL 35 +NULL 110 NULL 110 +NULL 135 NULL 135 +10 NULL NULL NULL +10 NULL NULL NULL +10 NULL 10 NULL +10 NULL 48 NULL +10 NULL 110 NULL +10 NULL 148 NULL +48 NULL NULL NULL +48 NULL NULL NULL +48 NULL 10 NULL +48 NULL 48 NULL +48 NULL 110 NULL +48 NULL 148 NULL +100 100 100 100 +110 NULL NULL NULL +110 NULL NULL NULL +110 NULL 10 NULL +110 NULL 48 NULL +110 NULL 110 NULL +110 NULL 148 NULL +148 NULL NULL NULL +148 NULL NULL NULL +148 NULL 10 NULL +148 NULL 48 NULL +148 NULL 110 NULL +148 NULL 148 NULL +200 200 200 200 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-4-644c616d87ae426eb2f8c71638045185 b/sql/hive/src/test/resources/golden/join_nullsafe-4-644c616d87ae426eb2f8c71638045185 new file mode 100644 index 0000000000000..1cc70524f9d6d --- /dev/null +++ 
b/sql/hive/src/test/resources/golden/join_nullsafe-4-644c616d87ae426eb2f8c71638045185 @@ -0,0 +1,11 @@ +NULL NULL NULL NULL +NULL NULL 10 NULL +NULL NULL 48 NULL +NULL 10 NULL NULL +NULL 10 10 NULL +NULL 10 48 NULL +NULL 35 NULL NULL +NULL 35 10 NULL +NULL 35 48 NULL +10 NULL NULL 10 +100 100 100 100 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-40-3c868718e4c120cb9a72ab7318c75be3 b/sql/hive/src/test/resources/golden/join_nullsafe-40-3c868718e4c120cb9a72ab7318c75be3 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-41-1f7d8737c3e2d74d5ad865535d729811 b/sql/hive/src/test/resources/golden/join_nullsafe-41-1f7d8737c3e2d74d5ad865535d729811 new file mode 100644 index 0000000000000..421049d6e509e --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-41-1f7d8737c3e2d74d5ad865535d729811 @@ -0,0 +1,9 @@ +NULL NULL NULL NULL +NULL NULL 10 NULL +NULL NULL 48 NULL +NULL 10 NULL NULL +NULL 10 10 NULL +NULL 10 48 NULL +NULL 35 NULL NULL +NULL 35 10 NULL +NULL 35 48 NULL diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-5-1e393de94850e92b3b00536aacc9371f b/sql/hive/src/test/resources/golden/join_nullsafe-5-1e393de94850e92b3b00536aacc9371f new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-6-d66451815212e7d17744184e74c6b0a0 b/sql/hive/src/test/resources/golden/join_nullsafe-6-d66451815212e7d17744184e74c6b0a0 new file mode 100644 index 0000000000000..aec3122cae5f9 --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-6-d66451815212e7d17744184e74c6b0a0 @@ -0,0 +1,2 @@ +10 NULL NULL 10 10 NULL +100 100 100 100 100 100 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-7-a3ad3cc301d9884898d3e6ab6c792d4c b/sql/hive/src/test/resources/golden/join_nullsafe-7-a3ad3cc301d9884898d3e6ab6c792d4c new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-8-cc7527bcf746ab7e2cd9f28db0ead0ac b/sql/hive/src/test/resources/golden/join_nullsafe-8-cc7527bcf746ab7e2cd9f28db0ead0ac new file mode 100644 index 0000000000000..30db79efa79b4 --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-8-cc7527bcf746ab7e2cd9f28db0ead0ac @@ -0,0 +1,29 @@ +NULL NULL NULL NULL NULL NULL +NULL NULL NULL NULL NULL 10 +NULL NULL NULL NULL NULL 35 +NULL NULL 10 NULL NULL NULL +NULL NULL 10 NULL NULL 10 +NULL NULL 10 NULL NULL 35 +NULL NULL 48 NULL NULL NULL +NULL NULL 48 NULL NULL 10 +NULL NULL 48 NULL NULL 35 +NULL 10 NULL NULL NULL NULL +NULL 10 NULL NULL NULL 10 +NULL 10 NULL NULL NULL 35 +NULL 10 10 NULL NULL NULL +NULL 10 10 NULL NULL 10 +NULL 10 10 NULL NULL 35 +NULL 10 48 NULL NULL NULL +NULL 10 48 NULL NULL 10 +NULL 10 48 NULL NULL 35 +NULL 35 NULL NULL NULL NULL +NULL 35 NULL NULL NULL 10 +NULL 35 NULL NULL NULL 35 +NULL 35 10 NULL NULL NULL +NULL 35 10 NULL NULL 10 +NULL 35 10 NULL NULL 35 +NULL 35 48 NULL NULL NULL +NULL 35 48 NULL NULL 10 +NULL 35 48 NULL NULL 35 +10 NULL NULL 10 10 NULL +100 100 100 100 100 100 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-9-88f6f40959b0d2faabd9d4b3cd853809 b/sql/hive/src/test/resources/golden/join_nullsafe-9-88f6f40959b0d2faabd9d4b3cd853809 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/limit_pushdown_negative-0-79b294d0081c3dfd36c5b8b5e78dc7fb b/sql/hive/src/test/resources/golden/limit_pushdown_negative-0-79b294d0081c3dfd36c5b8b5e78dc7fb new file mode 
100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/limit_pushdown_negative-0-79b294d0081c3dfd36c5b8b5e78dc7fb @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/limit_pushdown_negative-1-f87339637a48bd1533493ebbed5432a7 b/sql/hive/src/test/resources/golden/limit_pushdown_negative-1-f87339637a48bd1533493ebbed5432a7 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/limit_pushdown_negative-2-de7e5ac581b870fff10dc82c75c1c79e b/sql/hive/src/test/resources/golden/limit_pushdown_negative-2-de7e5ac581b870fff10dc82c75c1c79e new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/limit_pushdown_negative-3-be440c3f959ca53b758481aa90551984 b/sql/hive/src/test/resources/golden/limit_pushdown_negative-3-be440c3f959ca53b758481aa90551984 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/limit_pushdown_negative-4-4dedc8057d76af264c198beaacd7f000 b/sql/hive/src/test/resources/golden/limit_pushdown_negative-4-4dedc8057d76af264c198beaacd7f000 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/limit_pushdown_negative-5-543a20e69bd8987bc37a22c1c7ef33f1 b/sql/hive/src/test/resources/golden/limit_pushdown_negative-5-543a20e69bd8987bc37a22c1c7ef33f1 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/limit_pushdown_negative-6-3f8274466914ad200b33a2c83fa6dab5 b/sql/hive/src/test/resources/golden/limit_pushdown_negative-6-3f8274466914ad200b33a2c83fa6dab5 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/limit_pushdown_negative-7-fb7bf3783d4fb43673a202c4111d9092 b/sql/hive/src/test/resources/golden/limit_pushdown_negative-7-fb7bf3783d4fb43673a202c4111d9092 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/timestamp_1-10-343c75daac6695917608c17db8bf473e b/sql/hive/src/test/resources/golden/timestamp_1-10-343c75daac6695917608c17db8bf473e new file mode 100644 index 0000000000000..b77c91ccee412 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-10-343c75daac6695917608c17db8bf473e @@ -0,0 +1 @@ +1.293872461E9 diff --git a/sql/hive/src/test/resources/golden/timestamp_1-11-cf19f7359a6d3456c4526b2c69f92d6a b/sql/hive/src/test/resources/golden/timestamp_1-11-cf19f7359a6d3456c4526b2c69f92d6a new file mode 100644 index 0000000000000..bc2423dc08a1b --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-11-cf19f7359a6d3456c4526b2c69f92d6a @@ -0,0 +1 @@ +2011-01-01 01:01:01 diff --git a/sql/hive/src/test/resources/golden/timestamp_1-12-6328d3b3dfd295dd5ec453ffb47ff4d0 b/sql/hive/src/test/resources/golden/timestamp_1-12-6328d3b3dfd295dd5ec453ffb47ff4d0 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/timestamp_1-13-90269c1e50c7ae8e75ca9cc297982135 b/sql/hive/src/test/resources/golden/timestamp_1-13-90269c1e50c7ae8e75ca9cc297982135 new file mode 100644 index 0000000000000..27ba77ddaf615 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-13-90269c1e50c7ae8e75ca9cc297982135 @@ -0,0 +1 @@ +true diff --git a/sql/hive/src/test/resources/golden/timestamp_1-14-e6bfca320c4ee3aff39cf2f179d57da6 b/sql/hive/src/test/resources/golden/timestamp_1-14-e6bfca320c4ee3aff39cf2f179d57da6 new file mode 100644 index 
0000000000000..987e7ca9a76df --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-14-e6bfca320c4ee3aff39cf2f179d57da6 @@ -0,0 +1 @@ +77 diff --git a/sql/hive/src/test/resources/golden/timestamp_1-15-d0291a9bd42054b2732cb4f54cf39ae7 b/sql/hive/src/test/resources/golden/timestamp_1-15-d0291a9bd42054b2732cb4f54cf39ae7 new file mode 100644 index 0000000000000..8f3a49478b26b --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-15-d0291a9bd42054b2732cb4f54cf39ae7 @@ -0,0 +1 @@ +-4787 diff --git a/sql/hive/src/test/resources/golden/timestamp_1-16-e7b398d2a8107a42419c83771bda41e6 b/sql/hive/src/test/resources/golden/timestamp_1-16-e7b398d2a8107a42419c83771bda41e6 new file mode 100644 index 0000000000000..211a320b9ca72 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-16-e7b398d2a8107a42419c83771bda41e6 @@ -0,0 +1 @@ +1293872461 diff --git a/sql/hive/src/test/resources/golden/timestamp_1-17-a3eeec08bccae78d0d94ad2cb923e1cf b/sql/hive/src/test/resources/golden/timestamp_1-17-a3eeec08bccae78d0d94ad2cb923e1cf new file mode 100644 index 0000000000000..211a320b9ca72 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-17-a3eeec08bccae78d0d94ad2cb923e1cf @@ -0,0 +1 @@ +1293872461 diff --git a/sql/hive/src/test/resources/golden/timestamp_1-18-67f274bf16de625cf4e85af0c6185cac b/sql/hive/src/test/resources/golden/timestamp_1-18-67f274bf16de625cf4e85af0c6185cac new file mode 100644 index 0000000000000..502f94a71edbd --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-18-67f274bf16de625cf4e85af0c6185cac @@ -0,0 +1 @@ +1.29387251E9 diff --git a/sql/hive/src/test/resources/golden/timestamp_1-19-343c75daac6695917608c17db8bf473e b/sql/hive/src/test/resources/golden/timestamp_1-19-343c75daac6695917608c17db8bf473e new file mode 100644 index 0000000000000..b77c91ccee412 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-19-343c75daac6695917608c17db8bf473e @@ -0,0 +1 @@ +1.293872461E9 diff --git a/sql/hive/src/test/resources/golden/timestamp_1-20-cf19f7359a6d3456c4526b2c69f92d6a b/sql/hive/src/test/resources/golden/timestamp_1-20-cf19f7359a6d3456c4526b2c69f92d6a new file mode 100644 index 0000000000000..bc2423dc08a1b --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-20-cf19f7359a6d3456c4526b2c69f92d6a @@ -0,0 +1 @@ +2011-01-01 01:01:01 diff --git a/sql/hive/src/test/resources/golden/timestamp_1-21-d8fff1a6c464e50eb955babfafb0b98e b/sql/hive/src/test/resources/golden/timestamp_1-21-d8fff1a6c464e50eb955babfafb0b98e new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/timestamp_1-22-90269c1e50c7ae8e75ca9cc297982135 b/sql/hive/src/test/resources/golden/timestamp_1-22-90269c1e50c7ae8e75ca9cc297982135 new file mode 100644 index 0000000000000..27ba77ddaf615 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-22-90269c1e50c7ae8e75ca9cc297982135 @@ -0,0 +1 @@ +true diff --git a/sql/hive/src/test/resources/golden/timestamp_1-23-e6bfca320c4ee3aff39cf2f179d57da6 b/sql/hive/src/test/resources/golden/timestamp_1-23-e6bfca320c4ee3aff39cf2f179d57da6 new file mode 100644 index 0000000000000..987e7ca9a76df --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-23-e6bfca320c4ee3aff39cf2f179d57da6 @@ -0,0 +1 @@ +77 diff --git a/sql/hive/src/test/resources/golden/timestamp_1-24-d0291a9bd42054b2732cb4f54cf39ae7 b/sql/hive/src/test/resources/golden/timestamp_1-24-d0291a9bd42054b2732cb4f54cf39ae7 new file mode 100644 index 0000000000000..8f3a49478b26b --- 
/dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-24-d0291a9bd42054b2732cb4f54cf39ae7 @@ -0,0 +1 @@ +-4787 diff --git a/sql/hive/src/test/resources/golden/timestamp_1-25-e7b398d2a8107a42419c83771bda41e6 b/sql/hive/src/test/resources/golden/timestamp_1-25-e7b398d2a8107a42419c83771bda41e6 new file mode 100644 index 0000000000000..211a320b9ca72 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-25-e7b398d2a8107a42419c83771bda41e6 @@ -0,0 +1 @@ +1293872461 diff --git a/sql/hive/src/test/resources/golden/timestamp_1-26-a3eeec08bccae78d0d94ad2cb923e1cf b/sql/hive/src/test/resources/golden/timestamp_1-26-a3eeec08bccae78d0d94ad2cb923e1cf new file mode 100644 index 0000000000000..211a320b9ca72 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-26-a3eeec08bccae78d0d94ad2cb923e1cf @@ -0,0 +1 @@ +1293872461 diff --git a/sql/hive/src/test/resources/golden/timestamp_1-27-67f274bf16de625cf4e85af0c6185cac b/sql/hive/src/test/resources/golden/timestamp_1-27-67f274bf16de625cf4e85af0c6185cac new file mode 100644 index 0000000000000..502f94a71edbd --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-27-67f274bf16de625cf4e85af0c6185cac @@ -0,0 +1 @@ +1.29387251E9 diff --git a/sql/hive/src/test/resources/golden/timestamp_1-28-343c75daac6695917608c17db8bf473e b/sql/hive/src/test/resources/golden/timestamp_1-28-343c75daac6695917608c17db8bf473e new file mode 100644 index 0000000000000..2158e77f78768 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-28-343c75daac6695917608c17db8bf473e @@ -0,0 +1 @@ +1.2938724611E9 diff --git a/sql/hive/src/test/resources/golden/timestamp_1-29-cf19f7359a6d3456c4526b2c69f92d6a b/sql/hive/src/test/resources/golden/timestamp_1-29-cf19f7359a6d3456c4526b2c69f92d6a new file mode 100644 index 0000000000000..4bdd802b9cda9 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-29-cf19f7359a6d3456c4526b2c69f92d6a @@ -0,0 +1 @@ +2011-01-01 01:01:01.1 diff --git a/sql/hive/src/test/resources/golden/timestamp_1-3-819633b45e3e1779bca6bcb7b77fe5a1 b/sql/hive/src/test/resources/golden/timestamp_1-3-819633b45e3e1779bca6bcb7b77fe5a1 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/timestamp_1-30-273256141c33eb88194cad22eb940d21 b/sql/hive/src/test/resources/golden/timestamp_1-30-273256141c33eb88194cad22eb940d21 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/timestamp_1-31-90269c1e50c7ae8e75ca9cc297982135 b/sql/hive/src/test/resources/golden/timestamp_1-31-90269c1e50c7ae8e75ca9cc297982135 new file mode 100644 index 0000000000000..27ba77ddaf615 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-31-90269c1e50c7ae8e75ca9cc297982135 @@ -0,0 +1 @@ +true diff --git a/sql/hive/src/test/resources/golden/timestamp_1-32-e6bfca320c4ee3aff39cf2f179d57da6 b/sql/hive/src/test/resources/golden/timestamp_1-32-e6bfca320c4ee3aff39cf2f179d57da6 new file mode 100644 index 0000000000000..987e7ca9a76df --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-32-e6bfca320c4ee3aff39cf2f179d57da6 @@ -0,0 +1 @@ +77 diff --git a/sql/hive/src/test/resources/golden/timestamp_1-33-d0291a9bd42054b2732cb4f54cf39ae7 b/sql/hive/src/test/resources/golden/timestamp_1-33-d0291a9bd42054b2732cb4f54cf39ae7 new file mode 100644 index 0000000000000..8f3a49478b26b --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-33-d0291a9bd42054b2732cb4f54cf39ae7 @@ -0,0 +1 @@ +-4787 diff --git 
a/sql/hive/src/test/resources/golden/timestamp_1-34-e7b398d2a8107a42419c83771bda41e6 b/sql/hive/src/test/resources/golden/timestamp_1-34-e7b398d2a8107a42419c83771bda41e6 new file mode 100644 index 0000000000000..211a320b9ca72 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-34-e7b398d2a8107a42419c83771bda41e6 @@ -0,0 +1 @@ +1293872461 diff --git a/sql/hive/src/test/resources/golden/timestamp_1-35-a3eeec08bccae78d0d94ad2cb923e1cf b/sql/hive/src/test/resources/golden/timestamp_1-35-a3eeec08bccae78d0d94ad2cb923e1cf new file mode 100644 index 0000000000000..211a320b9ca72 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-35-a3eeec08bccae78d0d94ad2cb923e1cf @@ -0,0 +1 @@ +1293872461 diff --git a/sql/hive/src/test/resources/golden/timestamp_1-36-67f274bf16de625cf4e85af0c6185cac b/sql/hive/src/test/resources/golden/timestamp_1-36-67f274bf16de625cf4e85af0c6185cac new file mode 100644 index 0000000000000..502f94a71edbd --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-36-67f274bf16de625cf4e85af0c6185cac @@ -0,0 +1 @@ +1.29387251E9 diff --git a/sql/hive/src/test/resources/golden/timestamp_1-37-343c75daac6695917608c17db8bf473e b/sql/hive/src/test/resources/golden/timestamp_1-37-343c75daac6695917608c17db8bf473e new file mode 100644 index 0000000000000..b71ff60863360 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-37-343c75daac6695917608c17db8bf473e @@ -0,0 +1 @@ +1.2938724610001E9 diff --git a/sql/hive/src/test/resources/golden/timestamp_1-38-cf19f7359a6d3456c4526b2c69f92d6a b/sql/hive/src/test/resources/golden/timestamp_1-38-cf19f7359a6d3456c4526b2c69f92d6a new file mode 100644 index 0000000000000..8b014c4cd8d6e --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-38-cf19f7359a6d3456c4526b2c69f92d6a @@ -0,0 +1 @@ +2011-01-01 01:01:01.0001 diff --git a/sql/hive/src/test/resources/golden/timestamp_1-39-b2fe5cc7c8ee62d3bb0c120c9a6c305d b/sql/hive/src/test/resources/golden/timestamp_1-39-b2fe5cc7c8ee62d3bb0c120c9a6c305d new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/timestamp_1-4-90269c1e50c7ae8e75ca9cc297982135 b/sql/hive/src/test/resources/golden/timestamp_1-4-90269c1e50c7ae8e75ca9cc297982135 new file mode 100644 index 0000000000000..27ba77ddaf615 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-4-90269c1e50c7ae8e75ca9cc297982135 @@ -0,0 +1 @@ +true diff --git a/sql/hive/src/test/resources/golden/timestamp_1-40-90269c1e50c7ae8e75ca9cc297982135 b/sql/hive/src/test/resources/golden/timestamp_1-40-90269c1e50c7ae8e75ca9cc297982135 new file mode 100644 index 0000000000000..27ba77ddaf615 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-40-90269c1e50c7ae8e75ca9cc297982135 @@ -0,0 +1 @@ +true diff --git a/sql/hive/src/test/resources/golden/timestamp_1-41-e6bfca320c4ee3aff39cf2f179d57da6 b/sql/hive/src/test/resources/golden/timestamp_1-41-e6bfca320c4ee3aff39cf2f179d57da6 new file mode 100644 index 0000000000000..987e7ca9a76df --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-41-e6bfca320c4ee3aff39cf2f179d57da6 @@ -0,0 +1 @@ +77 diff --git a/sql/hive/src/test/resources/golden/timestamp_1-42-d0291a9bd42054b2732cb4f54cf39ae7 b/sql/hive/src/test/resources/golden/timestamp_1-42-d0291a9bd42054b2732cb4f54cf39ae7 new file mode 100644 index 0000000000000..8f3a49478b26b --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-42-d0291a9bd42054b2732cb4f54cf39ae7 @@ -0,0 +1 @@ +-4787 diff --git 
a/sql/hive/src/test/resources/golden/timestamp_1-43-e7b398d2a8107a42419c83771bda41e6 b/sql/hive/src/test/resources/golden/timestamp_1-43-e7b398d2a8107a42419c83771bda41e6 new file mode 100644 index 0000000000000..211a320b9ca72 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-43-e7b398d2a8107a42419c83771bda41e6 @@ -0,0 +1 @@ +1293872461 diff --git a/sql/hive/src/test/resources/golden/timestamp_1-44-a3eeec08bccae78d0d94ad2cb923e1cf b/sql/hive/src/test/resources/golden/timestamp_1-44-a3eeec08bccae78d0d94ad2cb923e1cf new file mode 100644 index 0000000000000..211a320b9ca72 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-44-a3eeec08bccae78d0d94ad2cb923e1cf @@ -0,0 +1 @@ +1293872461 diff --git a/sql/hive/src/test/resources/golden/timestamp_1-45-67f274bf16de625cf4e85af0c6185cac b/sql/hive/src/test/resources/golden/timestamp_1-45-67f274bf16de625cf4e85af0c6185cac new file mode 100644 index 0000000000000..502f94a71edbd --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-45-67f274bf16de625cf4e85af0c6185cac @@ -0,0 +1 @@ +1.29387251E9 diff --git a/sql/hive/src/test/resources/golden/timestamp_1-46-343c75daac6695917608c17db8bf473e b/sql/hive/src/test/resources/golden/timestamp_1-46-343c75daac6695917608c17db8bf473e new file mode 100644 index 0000000000000..b71ff60863360 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-46-343c75daac6695917608c17db8bf473e @@ -0,0 +1 @@ +1.2938724610001E9 diff --git a/sql/hive/src/test/resources/golden/timestamp_1-47-cf19f7359a6d3456c4526b2c69f92d6a b/sql/hive/src/test/resources/golden/timestamp_1-47-cf19f7359a6d3456c4526b2c69f92d6a new file mode 100644 index 0000000000000..8b014c4cd8d6e --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-47-cf19f7359a6d3456c4526b2c69f92d6a @@ -0,0 +1 @@ +2011-01-01 01:01:01.0001 diff --git a/sql/hive/src/test/resources/golden/timestamp_1-48-7029255241de8e8b9710801319990044 b/sql/hive/src/test/resources/golden/timestamp_1-48-7029255241de8e8b9710801319990044 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/timestamp_1-49-90269c1e50c7ae8e75ca9cc297982135 b/sql/hive/src/test/resources/golden/timestamp_1-49-90269c1e50c7ae8e75ca9cc297982135 new file mode 100644 index 0000000000000..27ba77ddaf615 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-49-90269c1e50c7ae8e75ca9cc297982135 @@ -0,0 +1 @@ +true diff --git a/sql/hive/src/test/resources/golden/timestamp_1-5-e6bfca320c4ee3aff39cf2f179d57da6 b/sql/hive/src/test/resources/golden/timestamp_1-5-e6bfca320c4ee3aff39cf2f179d57da6 new file mode 100644 index 0000000000000..987e7ca9a76df --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-5-e6bfca320c4ee3aff39cf2f179d57da6 @@ -0,0 +1 @@ +77 diff --git a/sql/hive/src/test/resources/golden/timestamp_1-50-e6bfca320c4ee3aff39cf2f179d57da6 b/sql/hive/src/test/resources/golden/timestamp_1-50-e6bfca320c4ee3aff39cf2f179d57da6 new file mode 100644 index 0000000000000..987e7ca9a76df --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-50-e6bfca320c4ee3aff39cf2f179d57da6 @@ -0,0 +1 @@ +77 diff --git a/sql/hive/src/test/resources/golden/timestamp_1-51-d0291a9bd42054b2732cb4f54cf39ae7 b/sql/hive/src/test/resources/golden/timestamp_1-51-d0291a9bd42054b2732cb4f54cf39ae7 new file mode 100644 index 0000000000000..8f3a49478b26b --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-51-d0291a9bd42054b2732cb4f54cf39ae7 @@ -0,0 +1 @@ +-4787 diff --git 
a/sql/hive/src/test/resources/golden/timestamp_1-52-e7b398d2a8107a42419c83771bda41e6 b/sql/hive/src/test/resources/golden/timestamp_1-52-e7b398d2a8107a42419c83771bda41e6 new file mode 100644 index 0000000000000..211a320b9ca72 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-52-e7b398d2a8107a42419c83771bda41e6 @@ -0,0 +1 @@ +1293872461 diff --git a/sql/hive/src/test/resources/golden/timestamp_1-53-a3eeec08bccae78d0d94ad2cb923e1cf b/sql/hive/src/test/resources/golden/timestamp_1-53-a3eeec08bccae78d0d94ad2cb923e1cf new file mode 100644 index 0000000000000..211a320b9ca72 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-53-a3eeec08bccae78d0d94ad2cb923e1cf @@ -0,0 +1 @@ +1293872461 diff --git a/sql/hive/src/test/resources/golden/timestamp_1-54-67f274bf16de625cf4e85af0c6185cac b/sql/hive/src/test/resources/golden/timestamp_1-54-67f274bf16de625cf4e85af0c6185cac new file mode 100644 index 0000000000000..502f94a71edbd --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-54-67f274bf16de625cf4e85af0c6185cac @@ -0,0 +1 @@ +1.29387251E9 diff --git a/sql/hive/src/test/resources/golden/timestamp_1-55-343c75daac6695917608c17db8bf473e b/sql/hive/src/test/resources/golden/timestamp_1-55-343c75daac6695917608c17db8bf473e new file mode 100644 index 0000000000000..3eefb349894ac --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-55-343c75daac6695917608c17db8bf473e @@ -0,0 +1 @@ +1.293872461001E9 diff --git a/sql/hive/src/test/resources/golden/timestamp_1-56-cf19f7359a6d3456c4526b2c69f92d6a b/sql/hive/src/test/resources/golden/timestamp_1-56-cf19f7359a6d3456c4526b2c69f92d6a new file mode 100644 index 0000000000000..acce9d97b5eba --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-56-cf19f7359a6d3456c4526b2c69f92d6a @@ -0,0 +1 @@ +2011-01-01 01:01:01.001000011 diff --git a/sql/hive/src/test/resources/golden/timestamp_1-57-d362501d0176855077e65f8faf067fa8 b/sql/hive/src/test/resources/golden/timestamp_1-57-d362501d0176855077e65f8faf067fa8 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/timestamp_1-6-d0291a9bd42054b2732cb4f54cf39ae7 b/sql/hive/src/test/resources/golden/timestamp_1-6-d0291a9bd42054b2732cb4f54cf39ae7 new file mode 100644 index 0000000000000..8f3a49478b26b --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-6-d0291a9bd42054b2732cb4f54cf39ae7 @@ -0,0 +1 @@ +-4787 diff --git a/sql/hive/src/test/resources/golden/timestamp_1-7-e7b398d2a8107a42419c83771bda41e6 b/sql/hive/src/test/resources/golden/timestamp_1-7-e7b398d2a8107a42419c83771bda41e6 new file mode 100644 index 0000000000000..211a320b9ca72 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-7-e7b398d2a8107a42419c83771bda41e6 @@ -0,0 +1 @@ +1293872461 diff --git a/sql/hive/src/test/resources/golden/timestamp_1-8-a3eeec08bccae78d0d94ad2cb923e1cf b/sql/hive/src/test/resources/golden/timestamp_1-8-a3eeec08bccae78d0d94ad2cb923e1cf new file mode 100644 index 0000000000000..211a320b9ca72 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-8-a3eeec08bccae78d0d94ad2cb923e1cf @@ -0,0 +1 @@ +1293872461 diff --git a/sql/hive/src/test/resources/golden/timestamp_1-9-67f274bf16de625cf4e85af0c6185cac b/sql/hive/src/test/resources/golden/timestamp_1-9-67f274bf16de625cf4e85af0c6185cac new file mode 100644 index 0000000000000..502f94a71edbd --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_1-9-67f274bf16de625cf4e85af0c6185cac @@ -0,0 +1 @@ +1.29387251E9 diff --git 
a/sql/hive/src/test/resources/golden/timestamp_2-10-5181279a0bf8939fe46ddacae015dad8 b/sql/hive/src/test/resources/golden/timestamp_2-10-5181279a0bf8939fe46ddacae015dad8 new file mode 100644 index 0000000000000..b77c91ccee412 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-10-5181279a0bf8939fe46ddacae015dad8 @@ -0,0 +1 @@ +1.293872461E9 diff --git a/sql/hive/src/test/resources/golden/timestamp_2-11-240fce5f58794fa051824e8732c00c03 b/sql/hive/src/test/resources/golden/timestamp_2-11-240fce5f58794fa051824e8732c00c03 new file mode 100644 index 0000000000000..bc2423dc08a1b --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-11-240fce5f58794fa051824e8732c00c03 @@ -0,0 +1 @@ +2011-01-01 01:01:01 diff --git a/sql/hive/src/test/resources/golden/timestamp_2-12-7350308cbf49d6ebd6599d3802750acd b/sql/hive/src/test/resources/golden/timestamp_2-12-7350308cbf49d6ebd6599d3802750acd new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/timestamp_2-13-25f6ec69328af6cba76899194e0dd84e b/sql/hive/src/test/resources/golden/timestamp_2-13-25f6ec69328af6cba76899194e0dd84e new file mode 100644 index 0000000000000..27ba77ddaf615 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-13-25f6ec69328af6cba76899194e0dd84e @@ -0,0 +1 @@ +true diff --git a/sql/hive/src/test/resources/golden/timestamp_2-14-93c769be4cff93bea6e62bfe4e2a8742 b/sql/hive/src/test/resources/golden/timestamp_2-14-93c769be4cff93bea6e62bfe4e2a8742 new file mode 100644 index 0000000000000..987e7ca9a76df --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-14-93c769be4cff93bea6e62bfe4e2a8742 @@ -0,0 +1 @@ +77 diff --git a/sql/hive/src/test/resources/golden/timestamp_2-15-5bdbf67419cc060b82d091d80ce59bf9 b/sql/hive/src/test/resources/golden/timestamp_2-15-5bdbf67419cc060b82d091d80ce59bf9 new file mode 100644 index 0000000000000..8f3a49478b26b --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-15-5bdbf67419cc060b82d091d80ce59bf9 @@ -0,0 +1 @@ +-4787 diff --git a/sql/hive/src/test/resources/golden/timestamp_2-16-de3c42ab06c17ae895fd7deaf7bd9571 b/sql/hive/src/test/resources/golden/timestamp_2-16-de3c42ab06c17ae895fd7deaf7bd9571 new file mode 100644 index 0000000000000..211a320b9ca72 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-16-de3c42ab06c17ae895fd7deaf7bd9571 @@ -0,0 +1 @@ +1293872461 diff --git a/sql/hive/src/test/resources/golden/timestamp_2-17-da3937d21b7c2cfe1e624e812ae1d3ef b/sql/hive/src/test/resources/golden/timestamp_2-17-da3937d21b7c2cfe1e624e812ae1d3ef new file mode 100644 index 0000000000000..211a320b9ca72 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-17-da3937d21b7c2cfe1e624e812ae1d3ef @@ -0,0 +1 @@ +1293872461 diff --git a/sql/hive/src/test/resources/golden/timestamp_2-18-252aebfe7882335d31bfc53a8705b7a b/sql/hive/src/test/resources/golden/timestamp_2-18-252aebfe7882335d31bfc53a8705b7a new file mode 100644 index 0000000000000..502f94a71edbd --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-18-252aebfe7882335d31bfc53a8705b7a @@ -0,0 +1 @@ +1.29387251E9 diff --git a/sql/hive/src/test/resources/golden/timestamp_2-19-5181279a0bf8939fe46ddacae015dad8 b/sql/hive/src/test/resources/golden/timestamp_2-19-5181279a0bf8939fe46ddacae015dad8 new file mode 100644 index 0000000000000..b77c91ccee412 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-19-5181279a0bf8939fe46ddacae015dad8 @@ -0,0 +1 @@ +1.293872461E9 diff --git 
a/sql/hive/src/test/resources/golden/timestamp_2-20-240fce5f58794fa051824e8732c00c03 b/sql/hive/src/test/resources/golden/timestamp_2-20-240fce5f58794fa051824e8732c00c03 new file mode 100644 index 0000000000000..bc2423dc08a1b --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-20-240fce5f58794fa051824e8732c00c03 @@ -0,0 +1 @@ +2011-01-01 01:01:01 diff --git a/sql/hive/src/test/resources/golden/timestamp_2-21-5eb58e5d3c5b9f766f0b497bf59c47b b/sql/hive/src/test/resources/golden/timestamp_2-21-5eb58e5d3c5b9f766f0b497bf59c47b new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/timestamp_2-22-25f6ec69328af6cba76899194e0dd84e b/sql/hive/src/test/resources/golden/timestamp_2-22-25f6ec69328af6cba76899194e0dd84e new file mode 100644 index 0000000000000..27ba77ddaf615 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-22-25f6ec69328af6cba76899194e0dd84e @@ -0,0 +1 @@ +true diff --git a/sql/hive/src/test/resources/golden/timestamp_2-23-93c769be4cff93bea6e62bfe4e2a8742 b/sql/hive/src/test/resources/golden/timestamp_2-23-93c769be4cff93bea6e62bfe4e2a8742 new file mode 100644 index 0000000000000..987e7ca9a76df --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-23-93c769be4cff93bea6e62bfe4e2a8742 @@ -0,0 +1 @@ +77 diff --git a/sql/hive/src/test/resources/golden/timestamp_2-24-5bdbf67419cc060b82d091d80ce59bf9 b/sql/hive/src/test/resources/golden/timestamp_2-24-5bdbf67419cc060b82d091d80ce59bf9 new file mode 100644 index 0000000000000..8f3a49478b26b --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-24-5bdbf67419cc060b82d091d80ce59bf9 @@ -0,0 +1 @@ +-4787 diff --git a/sql/hive/src/test/resources/golden/timestamp_2-25-de3c42ab06c17ae895fd7deaf7bd9571 b/sql/hive/src/test/resources/golden/timestamp_2-25-de3c42ab06c17ae895fd7deaf7bd9571 new file mode 100644 index 0000000000000..211a320b9ca72 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-25-de3c42ab06c17ae895fd7deaf7bd9571 @@ -0,0 +1 @@ +1293872461 diff --git a/sql/hive/src/test/resources/golden/timestamp_2-26-da3937d21b7c2cfe1e624e812ae1d3ef b/sql/hive/src/test/resources/golden/timestamp_2-26-da3937d21b7c2cfe1e624e812ae1d3ef new file mode 100644 index 0000000000000..211a320b9ca72 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-26-da3937d21b7c2cfe1e624e812ae1d3ef @@ -0,0 +1 @@ +1293872461 diff --git a/sql/hive/src/test/resources/golden/timestamp_2-27-252aebfe7882335d31bfc53a8705b7a b/sql/hive/src/test/resources/golden/timestamp_2-27-252aebfe7882335d31bfc53a8705b7a new file mode 100644 index 0000000000000..502f94a71edbd --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-27-252aebfe7882335d31bfc53a8705b7a @@ -0,0 +1 @@ +1.29387251E9 diff --git a/sql/hive/src/test/resources/golden/timestamp_2-28-5181279a0bf8939fe46ddacae015dad8 b/sql/hive/src/test/resources/golden/timestamp_2-28-5181279a0bf8939fe46ddacae015dad8 new file mode 100644 index 0000000000000..2158e77f78768 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-28-5181279a0bf8939fe46ddacae015dad8 @@ -0,0 +1 @@ +1.2938724611E9 diff --git a/sql/hive/src/test/resources/golden/timestamp_2-29-240fce5f58794fa051824e8732c00c03 b/sql/hive/src/test/resources/golden/timestamp_2-29-240fce5f58794fa051824e8732c00c03 new file mode 100644 index 0000000000000..4bdd802b9cda9 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-29-240fce5f58794fa051824e8732c00c03 @@ -0,0 +1 @@ +2011-01-01 01:01:01.1 diff --git 
a/sql/hive/src/test/resources/golden/timestamp_2-3-a95a52c3a66e1f211ea04a0a10bd3b74 b/sql/hive/src/test/resources/golden/timestamp_2-3-a95a52c3a66e1f211ea04a0a10bd3b74 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/timestamp_2-30-ffe6b6ddaaba84152074f7781fba2243 b/sql/hive/src/test/resources/golden/timestamp_2-30-ffe6b6ddaaba84152074f7781fba2243 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/timestamp_2-31-25f6ec69328af6cba76899194e0dd84e b/sql/hive/src/test/resources/golden/timestamp_2-31-25f6ec69328af6cba76899194e0dd84e new file mode 100644 index 0000000000000..27ba77ddaf615 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-31-25f6ec69328af6cba76899194e0dd84e @@ -0,0 +1 @@ +true diff --git a/sql/hive/src/test/resources/golden/timestamp_2-32-93c769be4cff93bea6e62bfe4e2a8742 b/sql/hive/src/test/resources/golden/timestamp_2-32-93c769be4cff93bea6e62bfe4e2a8742 new file mode 100644 index 0000000000000..987e7ca9a76df --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-32-93c769be4cff93bea6e62bfe4e2a8742 @@ -0,0 +1 @@ +77 diff --git a/sql/hive/src/test/resources/golden/timestamp_2-33-5bdbf67419cc060b82d091d80ce59bf9 b/sql/hive/src/test/resources/golden/timestamp_2-33-5bdbf67419cc060b82d091d80ce59bf9 new file mode 100644 index 0000000000000..8f3a49478b26b --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-33-5bdbf67419cc060b82d091d80ce59bf9 @@ -0,0 +1 @@ +-4787 diff --git a/sql/hive/src/test/resources/golden/timestamp_2-34-de3c42ab06c17ae895fd7deaf7bd9571 b/sql/hive/src/test/resources/golden/timestamp_2-34-de3c42ab06c17ae895fd7deaf7bd9571 new file mode 100644 index 0000000000000..211a320b9ca72 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-34-de3c42ab06c17ae895fd7deaf7bd9571 @@ -0,0 +1 @@ +1293872461 diff --git a/sql/hive/src/test/resources/golden/timestamp_2-35-da3937d21b7c2cfe1e624e812ae1d3ef b/sql/hive/src/test/resources/golden/timestamp_2-35-da3937d21b7c2cfe1e624e812ae1d3ef new file mode 100644 index 0000000000000..211a320b9ca72 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-35-da3937d21b7c2cfe1e624e812ae1d3ef @@ -0,0 +1 @@ +1293872461 diff --git a/sql/hive/src/test/resources/golden/timestamp_2-36-252aebfe7882335d31bfc53a8705b7a b/sql/hive/src/test/resources/golden/timestamp_2-36-252aebfe7882335d31bfc53a8705b7a new file mode 100644 index 0000000000000..502f94a71edbd --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-36-252aebfe7882335d31bfc53a8705b7a @@ -0,0 +1 @@ +1.29387251E9 diff --git a/sql/hive/src/test/resources/golden/timestamp_2-37-5181279a0bf8939fe46ddacae015dad8 b/sql/hive/src/test/resources/golden/timestamp_2-37-5181279a0bf8939fe46ddacae015dad8 new file mode 100644 index 0000000000000..b71ff60863360 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-37-5181279a0bf8939fe46ddacae015dad8 @@ -0,0 +1 @@ +1.2938724610001E9 diff --git a/sql/hive/src/test/resources/golden/timestamp_2-38-240fce5f58794fa051824e8732c00c03 b/sql/hive/src/test/resources/golden/timestamp_2-38-240fce5f58794fa051824e8732c00c03 new file mode 100644 index 0000000000000..8b014c4cd8d6e --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-38-240fce5f58794fa051824e8732c00c03 @@ -0,0 +1 @@ +2011-01-01 01:01:01.0001 diff --git a/sql/hive/src/test/resources/golden/timestamp_2-39-8236608f28681eac5503195096a34181 
b/sql/hive/src/test/resources/golden/timestamp_2-39-8236608f28681eac5503195096a34181 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/timestamp_2-4-25f6ec69328af6cba76899194e0dd84e b/sql/hive/src/test/resources/golden/timestamp_2-4-25f6ec69328af6cba76899194e0dd84e new file mode 100644 index 0000000000000..27ba77ddaf615 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-4-25f6ec69328af6cba76899194e0dd84e @@ -0,0 +1 @@ +true diff --git a/sql/hive/src/test/resources/golden/timestamp_2-40-25f6ec69328af6cba76899194e0dd84e b/sql/hive/src/test/resources/golden/timestamp_2-40-25f6ec69328af6cba76899194e0dd84e new file mode 100644 index 0000000000000..27ba77ddaf615 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-40-25f6ec69328af6cba76899194e0dd84e @@ -0,0 +1 @@ +true diff --git a/sql/hive/src/test/resources/golden/timestamp_2-41-93c769be4cff93bea6e62bfe4e2a8742 b/sql/hive/src/test/resources/golden/timestamp_2-41-93c769be4cff93bea6e62bfe4e2a8742 new file mode 100644 index 0000000000000..987e7ca9a76df --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-41-93c769be4cff93bea6e62bfe4e2a8742 @@ -0,0 +1 @@ +77 diff --git a/sql/hive/src/test/resources/golden/timestamp_2-42-5bdbf67419cc060b82d091d80ce59bf9 b/sql/hive/src/test/resources/golden/timestamp_2-42-5bdbf67419cc060b82d091d80ce59bf9 new file mode 100644 index 0000000000000..8f3a49478b26b --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-42-5bdbf67419cc060b82d091d80ce59bf9 @@ -0,0 +1 @@ +-4787 diff --git a/sql/hive/src/test/resources/golden/timestamp_2-43-de3c42ab06c17ae895fd7deaf7bd9571 b/sql/hive/src/test/resources/golden/timestamp_2-43-de3c42ab06c17ae895fd7deaf7bd9571 new file mode 100644 index 0000000000000..211a320b9ca72 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-43-de3c42ab06c17ae895fd7deaf7bd9571 @@ -0,0 +1 @@ +1293872461 diff --git a/sql/hive/src/test/resources/golden/timestamp_2-44-da3937d21b7c2cfe1e624e812ae1d3ef b/sql/hive/src/test/resources/golden/timestamp_2-44-da3937d21b7c2cfe1e624e812ae1d3ef new file mode 100644 index 0000000000000..211a320b9ca72 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-44-da3937d21b7c2cfe1e624e812ae1d3ef @@ -0,0 +1 @@ +1293872461 diff --git a/sql/hive/src/test/resources/golden/timestamp_2-45-252aebfe7882335d31bfc53a8705b7a b/sql/hive/src/test/resources/golden/timestamp_2-45-252aebfe7882335d31bfc53a8705b7a new file mode 100644 index 0000000000000..502f94a71edbd --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-45-252aebfe7882335d31bfc53a8705b7a @@ -0,0 +1 @@ +1.29387251E9 diff --git a/sql/hive/src/test/resources/golden/timestamp_2-46-5181279a0bf8939fe46ddacae015dad8 b/sql/hive/src/test/resources/golden/timestamp_2-46-5181279a0bf8939fe46ddacae015dad8 new file mode 100644 index 0000000000000..b71ff60863360 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-46-5181279a0bf8939fe46ddacae015dad8 @@ -0,0 +1 @@ +1.2938724610001E9 diff --git a/sql/hive/src/test/resources/golden/timestamp_2-47-240fce5f58794fa051824e8732c00c03 b/sql/hive/src/test/resources/golden/timestamp_2-47-240fce5f58794fa051824e8732c00c03 new file mode 100644 index 0000000000000..8b014c4cd8d6e --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-47-240fce5f58794fa051824e8732c00c03 @@ -0,0 +1 @@ +2011-01-01 01:01:01.0001 diff --git a/sql/hive/src/test/resources/golden/timestamp_2-48-654e5533ec6dc911996abc7e47af8ccb 
b/sql/hive/src/test/resources/golden/timestamp_2-48-654e5533ec6dc911996abc7e47af8ccb new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/timestamp_2-49-25f6ec69328af6cba76899194e0dd84e b/sql/hive/src/test/resources/golden/timestamp_2-49-25f6ec69328af6cba76899194e0dd84e new file mode 100644 index 0000000000000..27ba77ddaf615 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-49-25f6ec69328af6cba76899194e0dd84e @@ -0,0 +1 @@ +true diff --git a/sql/hive/src/test/resources/golden/timestamp_2-5-93c769be4cff93bea6e62bfe4e2a8742 b/sql/hive/src/test/resources/golden/timestamp_2-5-93c769be4cff93bea6e62bfe4e2a8742 new file mode 100644 index 0000000000000..987e7ca9a76df --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-5-93c769be4cff93bea6e62bfe4e2a8742 @@ -0,0 +1 @@ +77 diff --git a/sql/hive/src/test/resources/golden/timestamp_2-50-93c769be4cff93bea6e62bfe4e2a8742 b/sql/hive/src/test/resources/golden/timestamp_2-50-93c769be4cff93bea6e62bfe4e2a8742 new file mode 100644 index 0000000000000..987e7ca9a76df --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-50-93c769be4cff93bea6e62bfe4e2a8742 @@ -0,0 +1 @@ +77 diff --git a/sql/hive/src/test/resources/golden/timestamp_2-51-5bdbf67419cc060b82d091d80ce59bf9 b/sql/hive/src/test/resources/golden/timestamp_2-51-5bdbf67419cc060b82d091d80ce59bf9 new file mode 100644 index 0000000000000..8f3a49478b26b --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-51-5bdbf67419cc060b82d091d80ce59bf9 @@ -0,0 +1 @@ +-4787 diff --git a/sql/hive/src/test/resources/golden/timestamp_2-52-de3c42ab06c17ae895fd7deaf7bd9571 b/sql/hive/src/test/resources/golden/timestamp_2-52-de3c42ab06c17ae895fd7deaf7bd9571 new file mode 100644 index 0000000000000..211a320b9ca72 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-52-de3c42ab06c17ae895fd7deaf7bd9571 @@ -0,0 +1 @@ +1293872461 diff --git a/sql/hive/src/test/resources/golden/timestamp_2-53-da3937d21b7c2cfe1e624e812ae1d3ef b/sql/hive/src/test/resources/golden/timestamp_2-53-da3937d21b7c2cfe1e624e812ae1d3ef new file mode 100644 index 0000000000000..211a320b9ca72 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-53-da3937d21b7c2cfe1e624e812ae1d3ef @@ -0,0 +1 @@ +1293872461 diff --git a/sql/hive/src/test/resources/golden/timestamp_2-54-252aebfe7882335d31bfc53a8705b7a b/sql/hive/src/test/resources/golden/timestamp_2-54-252aebfe7882335d31bfc53a8705b7a new file mode 100644 index 0000000000000..502f94a71edbd --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-54-252aebfe7882335d31bfc53a8705b7a @@ -0,0 +1 @@ +1.29387251E9 diff --git a/sql/hive/src/test/resources/golden/timestamp_2-55-5181279a0bf8939fe46ddacae015dad8 b/sql/hive/src/test/resources/golden/timestamp_2-55-5181279a0bf8939fe46ddacae015dad8 new file mode 100644 index 0000000000000..3eefb349894ac --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-55-5181279a0bf8939fe46ddacae015dad8 @@ -0,0 +1 @@ +1.293872461001E9 diff --git a/sql/hive/src/test/resources/golden/timestamp_2-56-240fce5f58794fa051824e8732c00c03 b/sql/hive/src/test/resources/golden/timestamp_2-56-240fce5f58794fa051824e8732c00c03 new file mode 100644 index 0000000000000..acce9d97b5eba --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-56-240fce5f58794fa051824e8732c00c03 @@ -0,0 +1 @@ +2011-01-01 01:01:01.001000011 diff --git a/sql/hive/src/test/resources/golden/timestamp_2-57-ea7192a4a5a985bcc8aab9aa79d9f028 
b/sql/hive/src/test/resources/golden/timestamp_2-57-ea7192a4a5a985bcc8aab9aa79d9f028 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/timestamp_2-6-5bdbf67419cc060b82d091d80ce59bf9 b/sql/hive/src/test/resources/golden/timestamp_2-6-5bdbf67419cc060b82d091d80ce59bf9 new file mode 100644 index 0000000000000..8f3a49478b26b --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-6-5bdbf67419cc060b82d091d80ce59bf9 @@ -0,0 +1 @@ +-4787 diff --git a/sql/hive/src/test/resources/golden/timestamp_2-7-de3c42ab06c17ae895fd7deaf7bd9571 b/sql/hive/src/test/resources/golden/timestamp_2-7-de3c42ab06c17ae895fd7deaf7bd9571 new file mode 100644 index 0000000000000..211a320b9ca72 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-7-de3c42ab06c17ae895fd7deaf7bd9571 @@ -0,0 +1 @@ +1293872461 diff --git a/sql/hive/src/test/resources/golden/timestamp_2-8-da3937d21b7c2cfe1e624e812ae1d3ef b/sql/hive/src/test/resources/golden/timestamp_2-8-da3937d21b7c2cfe1e624e812ae1d3ef new file mode 100644 index 0000000000000..211a320b9ca72 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-8-da3937d21b7c2cfe1e624e812ae1d3ef @@ -0,0 +1 @@ +1293872461 diff --git a/sql/hive/src/test/resources/golden/timestamp_2-9-252aebfe7882335d31bfc53a8705b7a b/sql/hive/src/test/resources/golden/timestamp_2-9-252aebfe7882335d31bfc53a8705b7a new file mode 100644 index 0000000000000..502f94a71edbd --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_2-9-252aebfe7882335d31bfc53a8705b7a @@ -0,0 +1 @@ +1.29387251E9 diff --git a/sql/hive/src/test/resources/golden/timestamp_3-10-7b1ec929239ee305ea9da46ebb990c67 b/sql/hive/src/test/resources/golden/timestamp_3-10-7b1ec929239ee305ea9da46ebb990c67 new file mode 100644 index 0000000000000..1b0a140b5a384 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_3-10-7b1ec929239ee305ea9da46ebb990c67 @@ -0,0 +1 @@ +1.3041352164485E9 diff --git a/sql/hive/src/test/resources/golden/timestamp_3-11-a63f40f6c4a022c16f8cf810e3b7ed2a b/sql/hive/src/test/resources/golden/timestamp_3-11-a63f40f6c4a022c16f8cf810e3b7ed2a new file mode 100644 index 0000000000000..d7ff6cd63d9f0 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_3-11-a63f40f6c4a022c16f8cf810e3b7ed2a @@ -0,0 +1 @@ +2011-04-29 20:46:56.4485 diff --git a/sql/hive/src/test/resources/golden/timestamp_3-12-165256158e3db1ce19c3c9db3c8011d2 b/sql/hive/src/test/resources/golden/timestamp_3-12-165256158e3db1ce19c3c9db3c8011d2 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/timestamp_3-3-6143888a940bfcac1133330764f5a31a b/sql/hive/src/test/resources/golden/timestamp_3-3-6143888a940bfcac1133330764f5a31a new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/timestamp_3-4-935d0d2492beab99bbbba26ba62a1db4 b/sql/hive/src/test/resources/golden/timestamp_3-4-935d0d2492beab99bbbba26ba62a1db4 new file mode 100644 index 0000000000000..27ba77ddaf615 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_3-4-935d0d2492beab99bbbba26ba62a1db4 @@ -0,0 +1 @@ +true diff --git a/sql/hive/src/test/resources/golden/timestamp_3-5-8fe348d5d9b9903a26eda32d308b8e41 b/sql/hive/src/test/resources/golden/timestamp_3-5-8fe348d5d9b9903a26eda32d308b8e41 new file mode 100644 index 0000000000000..21e72e8ac3d7e --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_3-5-8fe348d5d9b9903a26eda32d308b8e41 @@ -0,0 +1 @@ +48 diff --git 
a/sql/hive/src/test/resources/golden/timestamp_3-6-6be5fe01c502cd24db32a3781c97a703 b/sql/hive/src/test/resources/golden/timestamp_3-6-6be5fe01c502cd24db32a3781c97a703 new file mode 100644 index 0000000000000..ee3be2941da6e --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_3-6-6be5fe01c502cd24db32a3781c97a703 @@ -0,0 +1 @@ +-31184 diff --git a/sql/hive/src/test/resources/golden/timestamp_3-7-6066ba0451cd0fcfac4bea6376e72add b/sql/hive/src/test/resources/golden/timestamp_3-7-6066ba0451cd0fcfac4bea6376e72add new file mode 100644 index 0000000000000..1cf1952ac0372 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_3-7-6066ba0451cd0fcfac4bea6376e72add @@ -0,0 +1 @@ +1304135216 diff --git a/sql/hive/src/test/resources/golden/timestamp_3-8-22e03daa775eab145d39ec0730953f7e b/sql/hive/src/test/resources/golden/timestamp_3-8-22e03daa775eab145d39ec0730953f7e new file mode 100644 index 0000000000000..1cf1952ac0372 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_3-8-22e03daa775eab145d39ec0730953f7e @@ -0,0 +1 @@ +1304135216 diff --git a/sql/hive/src/test/resources/golden/timestamp_3-9-ffc79abb874323e165963aa39f460a9b b/sql/hive/src/test/resources/golden/timestamp_3-9-ffc79abb874323e165963aa39f460a9b new file mode 100644 index 0000000000000..d21deca762237 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_3-9-ffc79abb874323e165963aa39f460a9b @@ -0,0 +1 @@ +1.30413517E9 diff --git a/sql/hive/src/test/resources/golden/timestamp_lazy-2-cdb72e0c24fd9277a41fe0c7b1392e34 b/sql/hive/src/test/resources/golden/timestamp_lazy-2-cdb72e0c24fd9277a41fe0c7b1392e34 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/timestamp_lazy-3-79e0c72c4fb3b259dfbffd245ccaa636 b/sql/hive/src/test/resources/golden/timestamp_lazy-3-79e0c72c4fb3b259dfbffd245ccaa636 new file mode 100644 index 0000000000000..e6bfe0b1667ae --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_lazy-3-79e0c72c4fb3b259dfbffd245ccaa636 @@ -0,0 +1,5 @@ +2011-01-01 01:01:01 165 val_165 +2011-01-01 01:01:01 238 val_238 +2011-01-01 01:01:01 27 val_27 +2011-01-01 01:01:01 311 val_311 +2011-01-01 01:01:01 86 val_86 diff --git a/sql/hive/src/test/resources/golden/timestamp_lazy-4-b4c4417ce9f08baeb82ffde6ef1baa25 b/sql/hive/src/test/resources/golden/timestamp_lazy-4-b4c4417ce9f08baeb82ffde6ef1baa25 new file mode 100644 index 0000000000000..e6bfe0b1667ae --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_lazy-4-b4c4417ce9f08baeb82ffde6ef1baa25 @@ -0,0 +1,5 @@ +2011-01-01 01:01:01 165 val_165 +2011-01-01 01:01:01 238 val_238 +2011-01-01 01:01:01 27 val_27 +2011-01-01 01:01:01 311 val_311 +2011-01-01 01:01:01 86 val_86 diff --git a/sql/hive/src/test/resources/golden/timestamp_null-3-222c5ea127c747c71738b5dc5b80459c b/sql/hive/src/test/resources/golden/timestamp_null-3-222c5ea127c747c71738b5dc5b80459c new file mode 100644 index 0000000000000..7951defec192a --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_null-3-222c5ea127c747c71738b5dc5b80459c @@ -0,0 +1 @@ +NULL diff --git a/sql/hive/src/test/resources/golden/timestamp_null-4-ffc86f5c714eceabc36e92931b96beb0 b/sql/hive/src/test/resources/golden/timestamp_null-4-ffc86f5c714eceabc36e92931b96beb0 new file mode 100644 index 0000000000000..7951defec192a --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_null-4-ffc86f5c714eceabc36e92931b96beb0 @@ -0,0 +1 @@ +NULL diff --git a/sql/hive/src/test/resources/golden/udf_equal-0-36b6cdf7c5f68c91155569b1622f5876 
b/sql/hive/src/test/resources/golden/udf_equal-0-36b6cdf7c5f68c91155569b1622f5876 new file mode 100644 index 0000000000000..9b9b6312a269a --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_equal-0-36b6cdf7c5f68c91155569b1622f5876 @@ -0,0 +1 @@ +a = b - Returns TRUE if a equals b and false otherwise diff --git a/sql/hive/src/test/resources/golden/udf_equal-1-2422b50b96502dde8b661acdfebd8892 b/sql/hive/src/test/resources/golden/udf_equal-1-2422b50b96502dde8b661acdfebd8892 new file mode 100644 index 0000000000000..30fdf50f62e4e --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_equal-1-2422b50b96502dde8b661acdfebd8892 @@ -0,0 +1,2 @@ +a = b - Returns TRUE if a equals b and false otherwise +Synonyms: == diff --git a/sql/hive/src/test/resources/golden/udf_equal-2-e0faab0f5e736c24bcc5503aeac55053 b/sql/hive/src/test/resources/golden/udf_equal-2-e0faab0f5e736c24bcc5503aeac55053 new file mode 100644 index 0000000000000..d6b4c860778b7 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_equal-2-e0faab0f5e736c24bcc5503aeac55053 @@ -0,0 +1 @@ +a == b - Returns TRUE if a equals b and false otherwise diff --git a/sql/hive/src/test/resources/golden/udf_equal-3-39d8d6f197803de927f0af5409ec2f33 b/sql/hive/src/test/resources/golden/udf_equal-3-39d8d6f197803de927f0af5409ec2f33 new file mode 100644 index 0000000000000..71e55d6d638a6 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_equal-3-39d8d6f197803de927f0af5409ec2f33 @@ -0,0 +1,2 @@ +a == b - Returns TRUE if a equals b and false otherwise +Synonyms: = diff --git a/sql/hive/src/test/resources/golden/udf_equal-4-94ac2476006425e1b3bcddf29ad07b16 b/sql/hive/src/test/resources/golden/udf_equal-4-94ac2476006425e1b3bcddf29ad07b16 new file mode 100644 index 0000000000000..015c417bc68f0 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_equal-4-94ac2476006425e1b3bcddf29ad07b16 @@ -0,0 +1 @@ +false false true true NULL NULL NULL NULL NULL diff --git a/sql/hive/src/test/resources/golden/udf_equal-5-878650cf21e9360a07d204c8ffb0cde7 b/sql/hive/src/test/resources/golden/udf_equal-5-878650cf21e9360a07d204c8ffb0cde7 new file mode 100644 index 0000000000000..aa7b4b51edea7 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_equal-5-878650cf21e9360a07d204c8ffb0cde7 @@ -0,0 +1 @@ +a <=> b - Returns same result with EQUAL(=) operator for non-null operands, but returns TRUE if both are NULL, FALSE if one of the them is NULL diff --git a/sql/hive/src/test/resources/golden/udf_equal-6-1635ef051fecdfc7891d9f5a9a3a545e b/sql/hive/src/test/resources/golden/udf_equal-6-1635ef051fecdfc7891d9f5a9a3a545e new file mode 100644 index 0000000000000..aa7b4b51edea7 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_equal-6-1635ef051fecdfc7891d9f5a9a3a545e @@ -0,0 +1 @@ +a <=> b - Returns same result with EQUAL(=) operator for non-null operands, but returns TRUE if both are NULL, FALSE if one of the them is NULL diff --git a/sql/hive/src/test/resources/golden/udf_equal-7-78f1b96c199e307714fa1b804e5bae27 b/sql/hive/src/test/resources/golden/udf_equal-7-78f1b96c199e307714fa1b804e5bae27 new file mode 100644 index 0000000000000..05292fb23192d --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_equal-7-78f1b96c199e307714fa1b804e5bae27 @@ -0,0 +1 @@ +false false true true true false false false false diff --git a/sql/hive/src/test/resources/golden/udf_unix_timestamp-0-d555c8cd733572bfa8cd3362da9480cb b/sql/hive/src/test/resources/golden/udf_unix_timestamp-0-d555c8cd733572bfa8cd3362da9480cb new file mode 100644 index 
0000000000000..9913d42ffc1aa --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_unix_timestamp-0-d555c8cd733572bfa8cd3362da9480cb @@ -0,0 +1 @@ +unix_timestamp([date[, pattern]]) - Returns the UNIX timestamp diff --git a/sql/hive/src/test/resources/golden/udf_unix_timestamp-1-8a9dbadae706047715cf5f903ff4a724 b/sql/hive/src/test/resources/golden/udf_unix_timestamp-1-8a9dbadae706047715cf5f903ff4a724 new file mode 100644 index 0000000000000..ef4aa8e9595d1 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_unix_timestamp-1-8a9dbadae706047715cf5f903ff4a724 @@ -0,0 +1,2 @@ +unix_timestamp([date[, pattern]]) - Returns the UNIX timestamp +Converts the current or specified time to number of seconds since 1970-01-01. diff --git a/sql/hive/src/test/resources/golden/udf_unix_timestamp-2-28c40e51e55bed62693e626efda5d9c5 b/sql/hive/src/test/resources/golden/udf_unix_timestamp-2-28c40e51e55bed62693e626efda5d9c5 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/udf_unix_timestamp-3-732b21d386f2002b87eaf02d0b9951ed b/sql/hive/src/test/resources/golden/udf_unix_timestamp-3-732b21d386f2002b87eaf02d0b9951ed new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/udf_unix_timestamp-4-b2e42ebb75cecf09961d36587797f6d0 b/sql/hive/src/test/resources/golden/udf_unix_timestamp-4-b2e42ebb75cecf09961d36587797f6d0 new file mode 100644 index 0000000000000..31aaa952b4cc5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_unix_timestamp-4-b2e42ebb75cecf09961d36587797f6d0 @@ -0,0 +1 @@ +2009-03-20 11:30:01 1237573801 diff --git a/sql/hive/src/test/resources/golden/udf_unix_timestamp-5-31243f5cb64356425b9f95ba011ac9d6 b/sql/hive/src/test/resources/golden/udf_unix_timestamp-5-31243f5cb64356425b9f95ba011ac9d6 new file mode 100644 index 0000000000000..6d9ee690cea4f --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_unix_timestamp-5-31243f5cb64356425b9f95ba011ac9d6 @@ -0,0 +1 @@ +2009-03-20 1237532400 diff --git a/sql/hive/src/test/resources/golden/udf_unix_timestamp-6-9b0f20bde1aaf9102b67a5498b167f31 b/sql/hive/src/test/resources/golden/udf_unix_timestamp-6-9b0f20bde1aaf9102b67a5498b167f31 new file mode 100644 index 0000000000000..e1d0cbb6c8495 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_unix_timestamp-6-9b0f20bde1aaf9102b67a5498b167f31 @@ -0,0 +1 @@ +2009 Mar 20 11:30:01 am 1237573801 diff --git a/sql/hive/src/test/resources/golden/udf_unix_timestamp-7-47f433ff6ccce4c666440cc1a228a96d b/sql/hive/src/test/resources/golden/udf_unix_timestamp-7-47f433ff6ccce4c666440cc1a228a96d new file mode 100644 index 0000000000000..6b40e687afd5d --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_unix_timestamp-7-47f433ff6ccce4c666440cc1a228a96d @@ -0,0 +1 @@ +random_string NULL diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 08ef4d9b6bb93..b4dbf2b115799 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -350,12 +350,6 @@ abstract class HiveComparisonTest val resultComparison = sideBySide(hivePrintOut, catalystPrintOut).mkString("\n") - println("hive output") - hive.foreach(println) - - println("catalyst printout") - catalyst.foreach(println) - if (recomputeCache) { logger.warn(s"Clearing cache files for failed 
test $testCaseName") hiveCacheFiles.foreach(_.delete()) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index a623d29b53973..a8623b64c656f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -30,6 +30,18 @@ case class TestData(a: Int, b: String) */ class HiveQuerySuite extends HiveComparisonTest { + createQueryTest("boolean = number", + """ + |SELECT + | 1 = true, 1L = true, 1Y = true, true = 1, true = 1L, true = 1Y, + | 0 = true, 0L = true, 0Y = true, true = 0, true = 0L, true = 0Y, + | 1 = false, 1L = false, 1Y = false, false = 1, false = 1L, false = 1Y, + | 0 = false, 0L = false, 0Y = false, false = 0, false = 0L, false = 0Y, + | 2 = true, 2L = true, 2Y = true, true = 2, true = 2L, true = 2Y, + | 2 = false, 2L = false, 2Y = false, false = 2, false = 2L, false = 2Y + |FROM src LIMIT 1 + """.stripMargin) + test("CREATE TABLE AS runs once") { hql("CREATE TABLE foo AS SELECT 1 FROM src LIMIT 1").collect() assert(hql("SELECT COUNT(*) FROM foo").collect().head.getLong(0) === 1, @@ -40,7 +52,10 @@ class HiveQuerySuite extends HiveComparisonTest { "SELECT * FROM src WHERE key Between 1 and 2") createQueryTest("div", - "SELECT 1 DIV 2, 1 div 2, 1 dIv 2 FROM src LIMIT 1") + "SELECT 1 DIV 2, 1 div 2, 1 dIv 2, 100 DIV 51, 100 DIV 49 FROM src LIMIT 1") + + createQueryTest("division", + "SELECT 2 / 1, 1 / 2, 1 / 3, 1 / COUNT(*) FROM src LIMIT 1") test("Query expressed in SQL") { assert(sql("SELECT 1").collect() === Array(Seq(1))) @@ -390,7 +405,7 @@ class HiveQuerySuite extends HiveComparisonTest { hql("CREATE TABLE m(value MAP)") hql("INSERT OVERWRITE TABLE m SELECT MAP(key, value) FROM src LIMIT 10") hql("SELECT * FROM m").collect().zip(hql("SELECT * FROM src LIMIT 10").collect()).map { - case (Row(map: Map[Int, String]), Row(key: Int, value: String)) => + case (Row(map: Map[_, _]), Row(key: Int, value: String)) => assert(map.size === 1) assert(map.head === (key, value)) } @@ -421,64 +436,63 @@ class HiveQuerySuite extends HiveComparisonTest { val testKey = "spark.sql.key.usedfortestonly" val testVal = "test.val.0" val nonexistentKey = "nonexistent" - def rowsToPairs(rows: Array[Row]) = rows.map { case Row(key: String, value: String) => - key -> value - } + def collectResults(rdd: SchemaRDD): Set[(String, String)] = + rdd.collect().map { case Row(key: String, value: String) => key -> value }.toSet clear() // "set" itself returns all config variables currently specified in SQLConf. 
assert(hql("SET").collect().size == 0) - assertResult(Array(testKey -> testVal)) { - rowsToPairs(hql(s"SET $testKey=$testVal").collect()) + assertResult(Set(testKey -> testVal)) { + collectResults(hql(s"SET $testKey=$testVal")) } assert(hiveconf.get(testKey, "") == testVal) - assertResult(Array(testKey -> testVal)) { - rowsToPairs(hql("SET").collect()) + assertResult(Set(testKey -> testVal)) { + collectResults(hql("SET")) } hql(s"SET ${testKey + testKey}=${testVal + testVal}") assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) - assertResult(Array(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { - rowsToPairs(hql("SET").collect()) + assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { + collectResults(hql("SET")) } // "set key" - assertResult(Array(testKey -> testVal)) { - rowsToPairs(hql(s"SET $testKey").collect()) + assertResult(Set(testKey -> testVal)) { + collectResults(hql(s"SET $testKey")) } - assertResult(Array(nonexistentKey -> "")) { - rowsToPairs(hql(s"SET $nonexistentKey").collect()) + assertResult(Set(nonexistentKey -> "")) { + collectResults(hql(s"SET $nonexistentKey")) } // Assert that sql() should have the same effects as hql() by repeating the above using sql(). clear() assert(sql("SET").collect().size == 0) - assertResult(Array(testKey -> testVal)) { - rowsToPairs(sql(s"SET $testKey=$testVal").collect()) + assertResult(Set(testKey -> testVal)) { + collectResults(sql(s"SET $testKey=$testVal")) } assert(hiveconf.get(testKey, "") == testVal) - assertResult(Array(testKey -> testVal)) { - rowsToPairs(sql("SET").collect()) + assertResult(Set(testKey -> testVal)) { + collectResults(sql("SET")) } sql(s"SET ${testKey + testKey}=${testVal + testVal}") assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) - assertResult(Array(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { - rowsToPairs(sql("SET").collect()) + assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { + collectResults(sql("SET")) } - assertResult(Array(testKey -> testVal)) { - rowsToPairs(sql(s"SET $testKey").collect()) + assertResult(Set(testKey -> testVal)) { + collectResults(sql(s"SET $testKey")) } - assertResult(Array(nonexistentKey -> "")) { - rowsToPairs(sql(s"SET $nonexistentKey").collect()) + assertResult(Set(nonexistentKey -> "")) { + collectResults(sql(s"SET $nonexistentKey")) } clear() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala index 67594b57d3dfa..fb03db12a0b01 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ -case class Data(a: Int, B: Int, n: Nested) +case class Data(a: Int, B: Int, n: Nested, nestedArray: Seq[Nested]) case class Nested(a: Int, B: Int) /** @@ -53,12 +53,18 @@ class HiveResolutionSuite extends HiveComparisonTest { test("case insensitivity with scala reflection") { // Test resolution with Scala Reflection - TestHive.sparkContext.parallelize(Data(1, 2, Nested(1,2)) :: Nil) + TestHive.sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) .registerAsTable("caseSensitivityTest") hql("SELECT a, b, A, B, n.a, n.b, n.A, n.B 
FROM caseSensitivityTest") } + test("nested repeated resolution") { + TestHive.sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) + .registerAsTable("nestedRepeatedTest") + assert(hql("SELECT nestedArray[0].a FROM nestedRepeatedTest").collect().head(0) === 1) + } + /** * Negative examples. Currently only left here for documentation purposes. * TODO(marmbrus): Test that catalyst fails on these queries. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala index c4bdf01fa3744..c00e11d11910f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala @@ -737,7 +737,8 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } object JavaPairDStream { - implicit def fromPairDStream[K: ClassTag, V: ClassTag](dstream: DStream[(K, V)]) = { + implicit def fromPairDStream[K: ClassTag, V: ClassTag](dstream: DStream[(K, V)]) + : JavaPairDStream[K, V] = { new JavaPairDStream[K, V](dstream) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala index 40da31318942e..1a47089e513c4 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala @@ -133,17 +133,17 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag]( val numOldValues = oldRDDs.size val numNewValues = newRDDs.size - val mergeValues = (seqOfValues: Seq[Seq[V]]) => { - if (seqOfValues.size != 1 + numOldValues + numNewValues) { + val mergeValues = (arrayOfValues: Array[Iterable[V]]) => { + if (arrayOfValues.size != 1 + numOldValues + numNewValues) { throw new Exception("Unexpected number of sequences of reduced values") } // Getting reduced values "old time steps" that will be removed from current window - val oldValues = (1 to numOldValues).map(i => seqOfValues(i)).filter(!_.isEmpty).map(_.head) + val oldValues = (1 to numOldValues).map(i => arrayOfValues(i)).filter(!_.isEmpty).map(_.head) // Getting reduced values "new time steps" val newValues = - (1 to numNewValues).map(i => seqOfValues(numOldValues + i)).filter(!_.isEmpty).map(_.head) + (1 to numNewValues).map(i => arrayOfValues(numOldValues + i)).filter(!_.isEmpty).map(_.head) - if (seqOfValues(0).isEmpty) { + if (arrayOfValues(0).isEmpty) { // If previous window's reduce value does not exist, then at least new values should exist if (newValues.isEmpty) { throw new Exception("Neither previous window has value for key, nor new values found. 
" + @@ -153,7 +153,7 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag]( newValues.reduce(reduceF) // return } else { // Get the previous window's reduced value - var tempValue = seqOfValues(0).head + var tempValue = arrayOfValues(0).head // If old values exists, then inverse reduce then from previous value if (!oldValues.isEmpty) { tempValue = invReduceF(tempValue, oldValues.reduce(reduceF)) @@ -166,7 +166,8 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag]( } } - val mergedValuesRDD = cogroupedRDD.asInstanceOf[RDD[(K,Seq[Seq[V]])]].mapValues(mergeValues) + val mergedValuesRDD = cogroupedRDD.asInstanceOf[RDD[(K, Array[Iterable[V]])]] + .mapValues(mergeValues) if (filterFunc.isDefined) { Some(mergedValuesRDD.filter(filterFunc.get)) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala index 743be58950c09..1868a1ebc7b4a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala @@ -68,13 +68,13 @@ object ActorSupervisorStrategy { * should be same. */ @DeveloperApi -trait ActorHelper { +trait ActorHelper extends Logging{ self: Actor => // to ensure that this can be added to Actor classes only /** Store an iterator of received data as a data block into Spark's memory. */ def store[T](iter: Iterator[T]) { - println("Storing iterator") + logDebug("Storing iterator") context.parent ! IteratorData(iter) } @@ -84,6 +84,7 @@ trait ActorHelper { * that Spark is configured to use. */ def store(bytes: ByteBuffer) { + logDebug("Storing Bytes") context.parent ! ByteBufferData(bytes) } @@ -93,7 +94,7 @@ trait ActorHelper { * being pushed into Spark's memory. */ def store[T](item: T) { - println("Storing item") + logDebug("Storing item") context.parent ! 
SingleItemData(item) } } @@ -157,15 +158,16 @@ private[streaming] class ActorReceiver[T: ClassTag]( def receive = { case IteratorData(iterator) => - println("received iterator") + logDebug("received iterator") store(iterator.asInstanceOf[Iterator[T]]) case SingleItemData(msg) => - println("received single") + logDebug("received single") store(msg.asInstanceOf[T]) n.incrementAndGet case ByteBufferData(bytes) => + logDebug("received bytes") store(bytes) case props: Props => diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index ce8316bb14891..d934b9cbfc3e8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -110,8 +110,7 @@ private[streaming] class ReceiverSupervisorImpl( ) { val blockId = optionalBlockId.getOrElse(nextBlockId) val time = System.currentTimeMillis - blockManager.put(blockId, arrayBuffer.asInstanceOf[ArrayBuffer[Any]], - storageLevel, tellMaster = true) + blockManager.putArray(blockId, arrayBuffer.toArray[Any], storageLevel, tellMaster = true) logDebug("Pushed block " + blockId + " in " + (System.currentTimeMillis - time) + " ms") reportPushedBlock(blockId, arrayBuffer.size, optionalMetadata) } @@ -124,7 +123,7 @@ private[streaming] class ReceiverSupervisorImpl( ) { val blockId = optionalBlockId.getOrElse(nextBlockId) val time = System.currentTimeMillis - blockManager.put(blockId, iterator, storageLevel, tellMaster = true) + blockManager.putIterator(blockId, iterator, storageLevel, tellMaster = true) logDebug("Pushed block " + blockId + " in " + (System.currentTimeMillis - time) + " ms") reportPushedBlock(blockId, -1, optionalMetadata) } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index cc4a65011dd72..952a74fd5f6de 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -383,7 +383,10 @@ class TestActor(port: Int) extends Actor with ActorHelper { def bytesToString(byteString: ByteString) = byteString.utf8String - override def preStart = IOManager(context.system).connect(new InetSocketAddress(port)) + override def preStart(): Unit = { + @deprecated("suppress compile time deprecation warning", "1.0.0") + val unit = IOManager(context.system).connect(new InetSocketAddress(port)) + } def receive = { case IO.Read(socket, bytes) => diff --git a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala index 03a73f92b275e..566983675bff5 100644 --- a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala +++ b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala @@ -99,9 +99,25 @@ object GenerateMIMAIgnore { (ignoredClasses.flatMap(c => Seq(c, c.replace("$", "#"))).toSet, ignoredMembers.toSet) } + /** Scala reflection does not let us see inner function even if they are upgraded + * to public for some reason. So had to resort to java reflection to get all inner + * functions with $$ in there name. 
+ */ + def getInnerFunctions(classSymbol: unv.ClassSymbol): Seq[String] = { + try { + Class.forName(classSymbol.fullName, false, classLoader).getMethods.map(_.getName) + .filter(_.contains("$$")).map(classSymbol.fullName + "." + _) + } catch { + case t: Throwable => + println("[WARN] Unable to detect inner functions for class:" + classSymbol.fullName) + Seq.empty[String] + } + } + private def getAnnotatedOrPackagePrivateMembers(classSymbol: unv.ClassSymbol) = { classSymbol.typeSignature.members - .filter(x => isPackagePrivate(x) || isDeveloperApi(x) || isExperimental(x)).map(_.fullName) + .filter(x => isPackagePrivate(x) || isDeveloperApi(x) || isExperimental(x)).map(_.fullName) ++ + getInnerFunctions(classSymbol) } def main(args: Array[String]) { @@ -121,7 +137,8 @@ object GenerateMIMAIgnore { name.endsWith("$class") || name.contains("$sp") || name.contains("hive") || - name.contains("Hive") + name.contains("Hive") || + name.contains("repl") } /** diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 062f946a9fe93..62b5c3bc5f0f3 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -60,6 +60,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, private var yarnAllocator: YarnAllocationHandler = _ private var isFinished: Boolean = false private var uiAddress: String = _ + private var uiHistoryAddress: String = _ private val maxAppAttempts: Int = conf.getInt(YarnConfiguration.RM_AM_MAX_RETRIES, YarnConfiguration.DEFAULT_RM_AM_MAX_RETRIES) private var isLastAMRetry: Boolean = true @@ -237,6 +238,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, if (null != sparkContext) { uiAddress = sparkContext.ui.appUIHostPort + uiHistoryAddress = YarnSparkHadoopUtil.getUIHistoryAddress(sparkContext, sparkConf) this.yarnAllocator = YarnAllocationHandler.newAllocator( yarnConf, resourceManager, @@ -255,10 +257,6 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, sparkContext.getConf) } } - } finally { - // in case of exceptions, etc - ensure that count is atleast ALLOCATOR_LOOP_WAIT_COUNT : - // so that the loop (in ApplicationMaster.sparkContextInitialized) breaks - ApplicationMaster.incrementAllocatorLoop(ApplicationMaster.ALLOCATOR_LOOP_WAIT_COUNT) } } @@ -277,13 +275,8 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, } yarnAllocator.allocateContainers( math.max(args.numExecutors - yarnAllocator.getNumExecutorsRunning, 0)) - ApplicationMaster.incrementAllocatorLoop(1) Thread.sleep(ApplicationMaster.ALLOCATE_HEARTBEAT_INTERVAL) } - } finally { - // In case of exceptions, etc - ensure that count is at least ALLOCATOR_LOOP_WAIT_COUNT, - // so that the loop in ApplicationMaster#sparkContextInitialized() breaks. 
- ApplicationMaster.incrementAllocatorLoop(ApplicationMaster.ALLOCATOR_LOOP_WAIT_COUNT) } logInfo("All executors have launched.") @@ -369,7 +362,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, finishReq.setAppAttemptId(appAttemptId) finishReq.setFinishApplicationStatus(status) finishReq.setDiagnostics(diagnostics) - finishReq.setTrackingUrl(sparkConf.get("spark.yarn.historyServer.address", "")) + finishReq.setTrackingUrl(uiHistoryAddress) resourceManager.finishApplicationMaster(finishReq) } } @@ -411,24 +404,10 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, } object ApplicationMaster extends Logging { - // Number of times to wait for the allocator loop to complete. - // Each loop iteration waits for 100ms, so maximum of 3 seconds. - // This is to ensure that we have reasonable number of containers before we start // TODO: Currently, task to container is computed once (TaskSetManager) - which need not be // optimal as more containers are available. Might need to handle this better. - private val ALLOCATOR_LOOP_WAIT_COUNT = 30 private val ALLOCATE_HEARTBEAT_INTERVAL = 100 - def incrementAllocatorLoop(by: Int) { - val count = yarnAllocatorLoop.getAndAdd(by) - if (count >= ALLOCATOR_LOOP_WAIT_COUNT) { - yarnAllocatorLoop.synchronized { - // to wake threads off wait ... - yarnAllocatorLoop.notifyAll() - } - } - } - private val applicationMasters = new CopyOnWriteArrayList[ApplicationMaster]() def register(master: ApplicationMaster) { @@ -437,7 +416,6 @@ object ApplicationMaster extends Logging { val sparkContextRef: AtomicReference[SparkContext] = new AtomicReference[SparkContext](null /* initialValue */) - val yarnAllocatorLoop: AtomicInteger = new AtomicInteger(0) def sparkContextInitialized(sc: SparkContext): Boolean = { var modified = false @@ -472,21 +450,6 @@ object ApplicationMaster extends Logging { modified } - - /** - * Returns when we've either - * 1) received all the requested executors, - * 2) waited ALLOCATOR_LOOP_WAIT_COUNT * ALLOCATE_HEARTBEAT_INTERVAL ms, - * 3) hit an error that causes us to terminate trying to get containers. 
- */ - def waitForInitialAllocations() { - yarnAllocatorLoop.synchronized { - while (yarnAllocatorLoop.get() <= ALLOCATOR_LOOP_WAIT_COUNT) { - yarnAllocatorLoop.wait(1000L) - } - } - } - def main(argStrings: Array[String]) { SignalLogger.register(log) val args = new ApplicationMasterArguments(argStrings) diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala index bfdb6232f5113..184e2ad6c82cd 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala @@ -28,10 +28,10 @@ import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.{ConverterUtils, Records} import akka.actor._ import akka.remote._ -import akka.actor.Terminated import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.util.{Utils, AkkaUtils} import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.AddWebUIFilter import org.apache.spark.scheduler.SplitInfo import org.apache.spark.deploy.SparkHadoopUtil @@ -56,10 +56,17 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp private val yarnConf: YarnConfiguration = new YarnConfiguration(conf) private var yarnAllocator: YarnAllocationHandler = _ - private var driverClosed:Boolean = false + + private var driverClosed: Boolean = false + private var isFinished: Boolean = false + private var registered: Boolean = false + + // Default to numExecutors * 2, with minimum of 3 + private val maxNumExecutorFailures = sparkConf.getInt("spark.yarn.max.executor.failures", + sparkConf.getInt("spark.yarn.max.worker.failures", math.max(args.numExecutors * 2, 3))) val securityManager = new SecurityManager(sparkConf) - val actorSystem : ActorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0, + val actorSystem: ActorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0, conf = sparkConf, securityManager = securityManager)._1 var actor: ActorRef = _ @@ -81,6 +88,9 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp case x: DisassociatedEvent => logInfo(s"Driver terminated or disconnected! Shutting down. $x") driverClosed = true + case x: AddWebUIFilter => + logInfo(s"Add WebUI Filter. $x") + driver ! x } } @@ -93,25 +103,28 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp appAttemptId = getApplicationAttemptId() resourceManager = registerWithResourceManager() - val appMasterResponse: RegisterApplicationMasterResponse = registerApplicationMaster() - - // Compute number of threads for akka - val minimumMemory = appMasterResponse.getMinimumResourceCapability().getMemory() - - if (minimumMemory > 0) { - val mem = args.executorMemory + sparkConf.getInt("spark.yarn.executor.memoryOverhead", - YarnAllocationHandler.MEMORY_OVERHEAD) - val numCore = (mem / minimumMemory) + (if (0 != (mem % minimumMemory)) 1 else 0) - - if (numCore > 0) { - // do not override - hits https://issues.apache.org/jira/browse/HADOOP-8406 - // TODO: Uncomment when hadoop is on a version which has this fixed. 
- // args.workerCores = numCore + synchronized { + if (!isFinished) { + val appMasterResponse: RegisterApplicationMasterResponse = registerApplicationMaster() + // Compute number of threads for akka + val minimumMemory = appMasterResponse.getMinimumResourceCapability().getMemory() + + if (minimumMemory > 0) { + val mem = args.executorMemory + sparkConf.getInt("spark.yarn.executor.memoryOverhead", + YarnAllocationHandler.MEMORY_OVERHEAD) + val numCore = (mem / minimumMemory) + (if (0 != (mem % minimumMemory)) 1 else 0) + + if (numCore > 0) { + // do not override - hits https://issues.apache.org/jira/browse/HADOOP-8406 + // TODO: Uncomment when hadoop is on a version which has this fixed. + // args.workerCores = numCore + } + } + registered = true } } - waitForSparkMaster() - + addAmIpFilter() // Allocate all containers allocateExecutors() @@ -171,7 +184,8 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp } private def registerApplicationMaster(): RegisterApplicationMasterResponse = { - logInfo("Registering the ApplicationMaster") + val appUIAddress = sparkConf.get("spark.driver.appUIAddress", "") + logInfo(s"Registering the ApplicationMaster with appUIAddress: $appUIAddress") val appMasterRequest = Records.newRecord(classOf[RegisterApplicationMasterRequest]) .asInstanceOf[RegisterApplicationMasterRequest] appMasterRequest.setApplicationAttemptId(appAttemptId) @@ -180,10 +194,21 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp appMasterRequest.setHost(Utils.localHostName()) appMasterRequest.setRpcPort(0) // What do we provide here ? Might make sense to expose something sensible later ? - appMasterRequest.setTrackingUrl("") + appMasterRequest.setTrackingUrl(appUIAddress) resourceManager.registerApplicationMaster(appMasterRequest) } + // add the yarn amIpFilter that Yarn requires for properly securing the UI + private def addAmIpFilter() { + val proxy = YarnConfiguration.getProxyHostAndPort(conf) + val parts = proxy.split(":") + val proxyBase = System.getenv(ApplicationConstants.APPLICATION_WEB_PROXY_BASE_ENV) + val uriBase = "http://" + proxy + proxyBase + val amFilter = "PROXY_HOST=" + parts(0) + "," + "PROXY_URI_BASE=" + uriBase + val amFilterName = "org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter" + actor ! AddWebUIFilter(amFilterName, amFilter, proxyBase) + } + private def waitForSparkMaster() { logInfo("Waiting for spark driver to be reachable.") var driverUp = false @@ -227,11 +252,17 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp while ((yarnAllocator.getNumExecutorsRunning < args.numExecutors) && (!driverClosed)) { yarnAllocator.allocateContainers( math.max(args.numExecutors - yarnAllocator.getNumExecutorsRunning, 0)) + checkNumExecutorsFailed() Thread.sleep(100) } logInfo("All executors have launched.") - + } + private def checkNumExecutorsFailed() { + if (yarnAllocator.getNumExecutorsFailed >= maxNumExecutorFailures) { + finishApplicationMaster(FinalApplicationStatus.FAILED, + "max number of executor failures reached") + } } // TODO: We might want to extend this to allocate more containers in case they die ! 
@@ -241,6 +272,7 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp val t = new Thread { override def run() { while (!driverClosed) { + checkNumExecutorsFailed() val missingExecutorCount = args.numExecutors - yarnAllocator.getNumExecutorsRunning if (missingExecutorCount > 0) { logInfo("Allocating " + missingExecutorCount + @@ -266,15 +298,23 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp yarnAllocator.allocateContainers(0) } - def finishApplicationMaster(status: FinalApplicationStatus) { - - logInfo("finish ApplicationMaster with " + status) - val finishReq = Records.newRecord(classOf[FinishApplicationMasterRequest]) - .asInstanceOf[FinishApplicationMasterRequest] - finishReq.setAppAttemptId(appAttemptId) - finishReq.setFinishApplicationStatus(status) - finishReq.setTrackingUrl(sparkConf.get("spark.yarn.historyServer.address", "")) - resourceManager.finishApplicationMaster(finishReq) + def finishApplicationMaster(status: FinalApplicationStatus, appMessage: String = "") { + synchronized { + if (isFinished) { + return + } + logInfo("Unregistering ApplicationMaster with " + status) + if (registered) { + val finishReq = Records.newRecord(classOf[FinishApplicationMasterRequest]) + .asInstanceOf[FinishApplicationMasterRequest] + finishReq.setAppAttemptId(appAttemptId) + finishReq.setFinishApplicationStatus(status) + finishReq.setTrackingUrl(sparkConf.get("spark.yarn.historyServer.address", "")) + finishReq.setDiagnostics(appMessage) + resourceManager.finishApplicationMaster(finishReq) + } + isFinished = true + } } } diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala index 556f49342977a..a1298e8f30b5c 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala @@ -232,7 +232,8 @@ trait ClientBase extends Logging { if (!ClientBase.LOCAL_SCHEME.equals(localURI.getScheme())) { val setPermissions = if (destName.equals(ClientBase.APP_JAR)) true else false val destPath = copyRemoteFile(dst, qualifyForLocal(localURI), replication, setPermissions) - distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, + val destFs = FileSystem.get(destPath.toUri(), conf) + distCacheMgr.addResource(destFs, conf, destPath, localResources, LocalResourceType.FILE, destName, statCache) } else if (confKey != null) { sparkConf.set(confKey, localPath) diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 718cb19f57261..e98308cdbd74e 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -30,6 +30,9 @@ import org.apache.hadoop.util.StringInterner import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.conf.Configuration + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.deploy.history.HistoryServer import org.apache.spark.deploy.SparkHadoopUtil /** @@ -132,4 +135,17 @@ object YarnSparkHadoopUtil { } } + def getUIHistoryAddress(sc: SparkContext, conf: SparkConf) : String = { + val eventLogDir = sc.eventLogger match { + case Some(logger) => 
logger.getApplicationLogDir() + case None => "" + } + val historyServerAddress = conf.get("spark.yarn.historyServer.address", "") + if (historyServerAddress != "" && eventLogDir != "") { + historyServerAddress + HistoryServer.UI_PATH_PREFIX + s"/$eventLogDir" + } else { + "" + } + } + } diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala index 15e8c21aa5906..3474112ded5d7 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala @@ -37,14 +37,4 @@ private[spark] class YarnClientClusterScheduler(sc: SparkContext, conf: Configur val retval = YarnAllocationHandler.lookupRack(conf, host) if (retval != null) Some(retval) else None } - - override def postStartHook() { - - super.postStartHook() - // The yarn application is running, but the executor might not yet ready - // Wait for a few seconds for the slaves to bootstrap and register with master - best case attempt - // TODO It needn't after waitBackendReady - Thread.sleep(2000L) - logInfo("YarnClientClusterScheduler.postStartHook done") - } } diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 0f9fdcfcb6510..f8fb96b312f23 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -19,7 +19,7 @@ package org.apache.spark.scheduler.cluster import org.apache.hadoop.yarn.api.records.{ApplicationId, YarnApplicationState} import org.apache.spark.{SparkException, Logging, SparkContext} -import org.apache.spark.deploy.yarn.{Client, ClientArguments, ExecutorLauncher} +import org.apache.spark.deploy.yarn.{Client, ClientArguments, ExecutorLauncher, YarnSparkHadoopUtil} import org.apache.spark.scheduler.TaskSchedulerImpl import scala.collection.mutable.ArrayBuffer @@ -30,8 +30,15 @@ private[spark] class YarnClientSchedulerBackend( extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) with Logging { + if (conf.getOption("spark.scheduler.minRegisteredExecutorsRatio").isEmpty) { + minRegisteredRatio = 0.8 + ready = false + } + var client: Client = null var appId: ApplicationId = null + var checkerThread: Thread = null + var stopping: Boolean = false private[spark] def addArg(optionName: String, envVar: String, sysProp: String, arrayBuf: ArrayBuffer[String]) { @@ -48,6 +55,8 @@ private[spark] class YarnClientSchedulerBackend( val driverHost = conf.get("spark.driver.host") val driverPort = conf.get("spark.driver.port") val hostport = driverHost + ":" + driverPort + conf.set("spark.driver.appUIAddress", sc.ui.appUIHostPort) + conf.set("spark.driver.appUIHistoryAddress", YarnSparkHadoopUtil.getUIHistoryAddress(sc, conf)) val argsArrayBuf = new ArrayBuffer[String]() argsArrayBuf += ( @@ -79,6 +88,7 @@ private[spark] class YarnClientSchedulerBackend( client = new Client(args, conf) appId = client.runApp() waitForApp() + checkerThread = yarnApplicationStateCheckerThread() } def waitForApp() { @@ -109,7 +119,32 @@ private[spark] class YarnClientSchedulerBackend( } } + private def yarnApplicationStateCheckerThread(): Thread = { + val t = new Thread 
{ + override def run() { + while (!stopping) { + val report = client.getApplicationReport(appId) + val state = report.getYarnApplicationState() + if (state == YarnApplicationState.FINISHED || state == YarnApplicationState.KILLED + || state == YarnApplicationState.FAILED) { + logError(s"Yarn application already ended: $state") + sc.stop() + stopping = true + } + Thread.sleep(1000L) + } + checkerThread = null + Thread.currentThread().interrupt() + } + } + t.setName("Yarn Application State Checker") + t.setDaemon(true) + t.start() + t + } + override def stop() { + stopping = true super.stop() client.stop logInfo("Stopped") diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala index 9ee53d797c8ea..9aeca4a637d38 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala @@ -47,14 +47,8 @@ private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration) } override def postStartHook() { - val sparkContextInitialized = ApplicationMaster.sparkContextInitialized(sc) + ApplicationMaster.sparkContextInitialized(sc) super.postStartHook() - if (sparkContextInitialized){ - ApplicationMaster.waitForInitialAllocations() - // Wait for a few seconds for the slaves to bootstrap and register with master - best case attempt - // TODO It needn't after waitBackendReady - Thread.sleep(3000L) - } logInfo("YarnClusterScheduler.postStartHook done") } } diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala index a04b08f43cc5a..0ad1794d19538 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala @@ -27,6 +27,11 @@ private[spark] class YarnClusterSchedulerBackend( sc: SparkContext) extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) { + if (conf.getOption("spark.scheduler.minRegisteredExecutorsRatio").isEmpty) { + minRegisteredRatio = 0.8 + ready = false + } + override def start() { super.start() var numExecutors = ApplicationMasterArguments.DEFAULT_NUMBER_EXECUTORS diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 1a24ec759b546..035356d390c80 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -59,6 +59,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, private var yarnAllocator: YarnAllocationHandler = _ private var isFinished: Boolean = false private var uiAddress: String = _ + private var uiHistoryAddress: String = _ private val maxAppAttempts: Int = conf.getInt( YarnConfiguration.RM_AM_MAX_ATTEMPTS, YarnConfiguration.DEFAULT_RM_AM_MAX_ATTEMPTS) private var isLastAMRetry: Boolean = true @@ -216,6 +217,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, if (sparkContext != null) { uiAddress = sparkContext.ui.appUIHostPort + uiHistoryAddress = 
YarnSparkHadoopUtil.getUIHistoryAddress(sparkContext, sparkConf) this.yarnAllocator = YarnAllocationHandler.newAllocator( yarnConf, amClient, @@ -234,10 +236,6 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, sparkContext.getConf) } } - } finally { - // In case of exceptions, etc - ensure that the loop in - // ApplicationMaster#sparkContextInitialized() breaks. - ApplicationMaster.doneWithInitialAllocations() } } @@ -254,16 +252,9 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, checkNumExecutorsFailed() allocateMissingExecutor() yarnAllocator.allocateResources() - if (iters == ApplicationMaster.ALLOCATOR_LOOP_WAIT_COUNT) { - ApplicationMaster.doneWithInitialAllocations() - } Thread.sleep(ApplicationMaster.ALLOCATE_HEARTBEAT_INTERVAL) iters += 1 } - } finally { - // In case of exceptions, etc - ensure that the loop in - // ApplicationMaster#sparkContextInitialized() breaks. - ApplicationMaster.doneWithInitialAllocations() } logInfo("All executors have launched.") } @@ -323,8 +314,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, logInfo("Unregistering ApplicationMaster with " + status) if (registered) { - val trackingUrl = sparkConf.get("spark.yarn.historyServer.address", "") - amClient.unregisterApplicationMaster(status, diagnostics, trackingUrl) + amClient.unregisterApplicationMaster(status, diagnostics, uiHistoryAddress) } } } @@ -365,12 +355,8 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, } object ApplicationMaster extends Logging { - // Number of times to wait for the allocator loop to complete. - // Each loop iteration waits for 100ms, so maximum of 3 seconds. - // This is to ensure that we have reasonable number of containers before we start // TODO: Currently, task to container is computed once (TaskSetManager) - which need not be // optimal as more containers are available. Might need to handle this better. - private val ALLOCATOR_LOOP_WAIT_COUNT = 30 private val ALLOCATE_HEARTBEAT_INTERVAL = 100 private val applicationMasters = new CopyOnWriteArrayList[ApplicationMaster]() @@ -378,20 +364,6 @@ object ApplicationMaster extends Logging { val sparkContextRef: AtomicReference[SparkContext] = new AtomicReference[SparkContext](null) - // Variable used to notify the YarnClusterScheduler that it should stop waiting - // for the initial set of executors to be started and get on with its business. - val doneWithInitialAllocationsMonitor = new Object() - - @volatile var isDoneWithInitialAllocations = false - - def doneWithInitialAllocations() { - isDoneWithInitialAllocations = true - doneWithInitialAllocationsMonitor.synchronized { - // to wake threads off wait ... - doneWithInitialAllocationsMonitor.notifyAll() - } - } - def register(master: ApplicationMaster) { applicationMasters.add(master) } @@ -434,20 +406,6 @@ object ApplicationMaster extends Logging { modified } - /** - * Returns when we've either - * 1) received all the requested executors, - * 2) waited ALLOCATOR_LOOP_WAIT_COUNT * ALLOCATE_HEARTBEAT_INTERVAL ms, - * 3) hit an error that causes us to terminate trying to get containers. 
- */ - def waitForInitialAllocations() { - doneWithInitialAllocationsMonitor.synchronized { - while (!isDoneWithInitialAllocations) { - doneWithInitialAllocationsMonitor.wait(1000L) - } - } - } - def getApplicationAttemptId(): ApplicationAttemptId = { val containerIdString = System.getenv(ApplicationConstants.Environment.CONTAINER_ID.name()) val containerId = ConverterUtils.toContainerId(containerIdString) diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala index f71ad036ce0f2..fc7b8320d734d 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala @@ -19,22 +19,21 @@ package org.apache.spark.deploy.yarn import java.net.Socket import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.net.NetUtils -import org.apache.hadoop.yarn.api._ +import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.api.protocolrecords._ import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.apache.hadoop.yarn.util.{ConverterUtils, Records} import akka.actor._ import akka.remote._ -import akka.actor.Terminated import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.util.{Utils, AkkaUtils} import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.AddWebUIFilter import org.apache.spark.scheduler.SplitInfo import org.apache.hadoop.yarn.client.api.AMRMClient import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.hadoop.yarn.webapp.util.WebAppUtils /** * An application master that allocates executors on behalf of a driver that is running outside @@ -55,10 +54,16 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp private val yarnConf: YarnConfiguration = new YarnConfiguration(conf) private var yarnAllocator: YarnAllocationHandler = _ - private var driverClosed:Boolean = false + private var driverClosed: Boolean = false + private var isFinished: Boolean = false + private var registered: Boolean = false private var amClient: AMRMClient[ContainerRequest] = _ + // Default to numExecutors * 2, with minimum of 3 + private val maxNumExecutorFailures = sparkConf.getInt("spark.yarn.max.executor.failures", + sparkConf.getInt("spark.yarn.max.worker.failures", math.max(args.numExecutors * 2, 3))) + val securityManager = new SecurityManager(sparkConf) val actorSystem: ActorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0, conf = sparkConf, securityManager = securityManager)._1 @@ -82,6 +87,9 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp case x: DisassociatedEvent => logInfo(s"Driver terminated or disconnected! Shutting down. $x") driverClosed = true + case x: AddWebUIFilter => + logInfo(s"Add WebUI Filter. $x") + driver ! 
x } } @@ -96,9 +104,15 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp amClient.start() appAttemptId = ApplicationMaster.getApplicationAttemptId() - registerApplicationMaster() + synchronized { + if (!isFinished) { + registerApplicationMaster() + registered = true + } + } waitForSparkMaster() + addAmIpFilter() // Allocate all containers allocateExecutors() @@ -142,9 +156,20 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp } private def registerApplicationMaster(): RegisterApplicationMasterResponse = { - logInfo("Registering the ApplicationMaster") - // TODO: Find out client's Spark UI address and fill in here? - amClient.registerApplicationMaster(Utils.localHostName(), 0, "") + val appUIAddress = sparkConf.get("spark.driver.appUIAddress", "") + logInfo(s"Registering the ApplicationMaster with appUIAddress: $appUIAddress") + amClient.registerApplicationMaster(Utils.localHostName(), 0, appUIAddress) + } + + // add the yarn amIpFilter that Yarn requires for properly securing the UI + private def addAmIpFilter() { + val proxy = WebAppUtils.getProxyHostAndPort(conf) + val parts = proxy.split(":") + val proxyBase = System.getenv(ApplicationConstants.APPLICATION_WEB_PROXY_BASE_ENV) + val uriBase = "http://" + proxy + proxyBase + val amFilter = "PROXY_HOST=" + parts(0) + "," + "PROXY_URI_BASE=" + uriBase + val amFilterName = "org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter" + actor ! AddWebUIFilter(amFilterName, amFilter, proxyBase) } private def waitForSparkMaster() { @@ -193,6 +218,7 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp yarnAllocator.addResourceRequests(args.numExecutors) yarnAllocator.allocateResources() while ((yarnAllocator.getNumExecutorsRunning < args.numExecutors) && (!driverClosed)) { + checkNumExecutorsFailed() allocateMissingExecutor() yarnAllocator.allocateResources() Thread.sleep(100) @@ -211,12 +237,20 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp } } + private def checkNumExecutorsFailed() { + if (yarnAllocator.getNumExecutorsFailed >= maxNumExecutorFailures) { + finishApplicationMaster(FinalApplicationStatus.FAILED, + "max number of executor failures reached") + } + } + private def launchReporterThread(_sleepTime: Long): Thread = { val sleepTime = if (_sleepTime <= 0) 0 else _sleepTime val t = new Thread { override def run() { while (!driverClosed) { + checkNumExecutorsFailed() allocateMissingExecutor() logDebug("Sending progress") yarnAllocator.allocateResources() @@ -231,10 +265,18 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp t } - def finishApplicationMaster(status: FinalApplicationStatus) { - logInfo("Unregistering ApplicationMaster with " + status) - val trackingUrl = sparkConf.get("spark.yarn.historyServer.address", "") - amClient.unregisterApplicationMaster(status, "" /* appMessage */ , trackingUrl) + def finishApplicationMaster(status: FinalApplicationStatus, appMessage: String = "") { + synchronized { + if (isFinished) { + return + } + logInfo("Unregistering ApplicationMaster with " + status) + if (registered) { + val trackingUrl = sparkConf.get("spark.yarn.historyServer.address", "") + amClient.unregisterApplicationMaster(status, appMessage, trackingUrl) + } + isFinished = true + } } }
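Both ExecutorLauncher variants above now guard registration and unregistration with the same lock and a pair of flags, so the failure path added by checkNumExecutorsFailed() and the normal shutdown path cannot unregister twice and never unregister an ApplicationMaster that was never registered. A minimal Scala sketch of that guard pattern follows; the register/unregister closures are hypothetical stand-ins for the real AMRMClient/ResourceManager calls, not the patch's actual code.

    // Illustrative only: a hypothetical FinishGuard capturing the registered/isFinished
    // pattern used above; register/unregister stand in for the AMRMClient calls.
    class FinishGuard {
      private var registered = false
      private var finished = false

      def registerOnce(register: () => Unit): Unit = synchronized {
        if (!finished) {          // never register after a finish has already been reported
          register()
          registered = true
        }
      }

      def finishOnce(unregister: String => Unit, message: String = ""): Unit = synchronized {
        if (!finished) {
          if (registered) {
            unregister(message)   // only unregister if registration actually happened
          }
          finished = true         // later callers become no-ops
        }
      }
    }

With this shape, the reporter thread's executor-failure check and the normal shutdown path can both call finishOnce() without coordinating with each other, which is the behavior the synchronized blocks in the patch are after.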
<table class="table">
  <tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
  <tr>
    <td><code>spark.deploy.retainedApplications</code></td>
    <td>200</td>
    <td>
      The maximum number of completed applications to display. Older applications will be dropped from the UI to maintain this limit.
    </td>
  </tr>
  <tr>
    <td><code>spark.deploy.retainedDrivers</code></td>
    <td>200</td>
    <td>
      The maximum number of completed drivers to display. Older drivers will be dropped from the UI to maintain this limit.
    </td>
  </tr>
  <tr>
    <td><code>spark.deploy.spreadOut</code></td>
    <td>true</td>
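The two retention limits documented above behave the same way: a bounded list of completed entries is kept, and the oldest entries are evicted once the cap is exceeded. A rough, illustrative sketch of that eviction behavior is below; it assumes a generic CompletedLog helper and is not the Master's actual bookkeeping code.

    import scala.collection.mutable.ArrayBuffer

    // Illustrative eviction: keep at most `retained` completed entries, oldest dropped first.
    class CompletedLog[T](retained: Int) {
      private val entries = ArrayBuffer.empty[T]

      def add(entry: T): Unit = {
        entries += entry
        if (entries.size > retained) {
          entries.remove(0, entries.size - retained)   // drop the oldest entries
        }
      }

      def current: Seq[T] = entries.toSeq
    }

    // usage: a cap read from a setting such as spark.deploy.retainedApplications (default 200)
    // val completedApps = new CompletedLog[String](retained = 200)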