From 59e84dc5435dea27ad7efb82ff3f1c5a28fa8ec5 Mon Sep 17 00:00:00 2001 From: Dmitry Erastov Date: Fri, 13 Nov 2015 11:53:16 -0500 Subject: [PATCH 001/149] [SPARK-6990] Remove whitespace after '>' Closing '>' and method name shouldn't have whitespace between them, according to http://checkstyle.sourceforge.net/config_whitespace.html#GenericWhitespace --- .../org/apache/spark/streaming/JavaTrackStateByKeySuite.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java index eac4cdd14a683..89d0bb7b617e4 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java @@ -95,7 +95,7 @@ public Double call(Optional one, State state) { JavaTrackStateDStream stateDstream2 = wordsDstream.trackStateByKey( - StateSpec. function(trackStateFunc2) + StateSpec.function(trackStateFunc2) .initialState(initialRDD) .numPartitions(10) .partitioner(new HashPartitioner(10)) From c6bab96d3f958d413373789e6e790b98d4dd195a Mon Sep 17 00:00:00 2001 From: Dmitry Erastov Date: Tue, 10 Nov 2015 11:56:46 -0500 Subject: [PATCH 002/149] [SPARK-6990] Add Java linting script Invoke Checkstyle and print any errors to the console, failing the step. Use Google's style rules modified according to https://cwiki.apache.org/confluence/display/SPARK/Spark+Code+Style+Guide Some important checks are disabled (see TODOs in checkstyle.xml) due to multiple violations being present in the codebase. --- checkstyle.xml | 159 +++++++++++++++++++++++++++++++ dev/lint-java | 34 +++++++ dev/run-tests-jenkins.py | 1 + dev/run-tests.py | 7 ++ dev/sparktestsupport/__init__.py | 1 + pom.xml | 24 +++++ 6 files changed, 226 insertions(+) create mode 100644 checkstyle.xml create mode 100755 dev/lint-java diff --git a/checkstyle.xml b/checkstyle.xml new file mode 100644 index 0000000000000..403c5356f7dd8 --- /dev/null +++ b/checkstyle.xml @@ -0,0 +1,159 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/dev/lint-java b/dev/lint-java new file mode 100755 index 0000000000000..55b6b0c348416 --- /dev/null +++ b/dev/lint-java @@ -0,0 +1,34 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" +SPARK_ROOT_DIR="$(dirname $SCRIPT_DIR)" + +$SCRIPT_DIR/../build/mvn -Pkinesis-asl -Phive -Phive-thriftserver checkstyle:check > checkstyle.txt +$SCRIPT_DIR/../build/mvn -Pkinesis-asl -Pyarn -Phadoop-2.2 checkstyle:check >> checkstyle.txt + +ERRORS=$(grep ERROR checkstyle.txt) +rm checkstyle.txt + +if test ! -z "$ERRORS"; then + echo -e "Checkstyle checks failed at following occurrences:\n$ERRORS" + exit 1 +else + echo -e "Checkstyle checks passed." +fi diff --git a/dev/run-tests-jenkins.py b/dev/run-tests-jenkins.py index 623004310e189..622e5b6a17655 100755 --- a/dev/run-tests-jenkins.py +++ b/dev/run-tests-jenkins.py @@ -119,6 +119,7 @@ def run_tests(tests_timeout): ERROR_CODES["BLOCK_GENERAL"]: 'some tests', ERROR_CODES["BLOCK_RAT"]: 'RAT tests', ERROR_CODES["BLOCK_SCALA_STYLE"]: 'Scala style tests', + ERROR_CODES["BLOCK_JAVA_STYLE"]: 'Java style tests', ERROR_CODES["BLOCK_PYTHON_STYLE"]: 'Python style tests', ERROR_CODES["BLOCK_R_STYLE"]: 'R style tests', ERROR_CODES["BLOCK_DOCUMENTATION"]: 'to generate documentation', diff --git a/dev/run-tests.py b/dev/run-tests.py index 9e1abb0697192..e7e10f1d8c725 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -198,6 +198,11 @@ def run_scala_style_checks(): run_cmd([os.path.join(SPARK_HOME, "dev", "lint-scala")]) +def run_java_style_checks(): + set_title_and_block("Running Java style checks", "BLOCK_JAVA_STYLE") + run_cmd([os.path.join(SPARK_HOME, "dev", "lint-java")]) + + def run_python_style_checks(): set_title_and_block("Running Python style checks", "BLOCK_PYTHON_STYLE") run_cmd([os.path.join(SPARK_HOME, "dev", "lint-python")]) @@ -522,6 +527,8 @@ def main(): # style checks if not changed_files or any(f.endswith(".scala") for f in changed_files): run_scala_style_checks() + if not changed_files or any(f.endswith(".java") for f in changed_files): + run_java_style_checks() if not changed_files or any(f.endswith(".py") for f in changed_files): run_python_style_checks() if not changed_files or any(f.endswith(".R") for f in changed_files): diff --git a/dev/sparktestsupport/__init__.py b/dev/sparktestsupport/__init__.py index 8ab6d9e37ca2f..0e8032d13341e 100644 --- a/dev/sparktestsupport/__init__.py +++ b/dev/sparktestsupport/__init__.py @@ -31,5 +31,6 @@ "BLOCK_SPARK_UNIT_TESTS": 18, "BLOCK_PYSPARK_UNIT_TESTS": 19, "BLOCK_SPARKR_UNIT_TESTS": 20, + "BLOCK_JAVA_STYLE": 21, "BLOCK_TIMEOUT": 124 } diff --git a/pom.xml b/pom.xml index 01afa80617891..950421e5d956d 100644 --- a/pom.xml +++ b/pom.xml @@ -2246,6 +2246,30 @@ + + org.apache.maven.plugins + maven-checkstyle-plugin + 2.17 + + false + false + false + false + ${basedir}/src/main/java + ${basedir}/src/test/java + checkstyle.xml + ${basedir}/target/checkstyle-output.xml + ${project.build.sourceEncoding} + ${project.reporting.outputEncoding} + + + + + check + + + + org.apache.maven.plugins From 20de73e85b4c8fe4cb4edebd0d564964c8525a0b Mon Sep 17 00:00:00 2001 From: Dmitry Erastov Date: Tue, 10 Nov 2015 11:57:20 -0500 Subject: [PATCH 003/149] [SPARK-6990] Fix some Checkstyle warnings --- .../apache/spark/util/collection/TimSort.java | 118 ++++++++++++------ .../unsafe/sort/UnsafeInMemorySorter.java | 2 +- .../examples/ml/JavaSimpleParamsExample.java | 2 +- .../spark/examples/mllib/JavaLDAExample.java | 3 +- .../mllib/JavaRecommendationExample.java | 2 +- .../streaming/JavaSqlNetworkWordCount.java | 4 +- .../shuffle/ExternalShuffleBlockResolver.java | 2 +- .../execution/UnsafeExternalRowSorter.java | 2 +- 
.../spark/sql/types/SQLUserDefinedType.java | 2 +- .../spark/streaming/util/WriteAheadLog.java | 10 +- .../apache/spark/tags/ExtendedHiveTest.java | 1 + .../apache/spark/tags/ExtendedYarnTest.java | 1 + .../apache/spark/unsafe/types/UTF8String.java | 8 +- 13 files changed, 98 insertions(+), 59 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/TimSort.java b/core/src/main/java/org/apache/spark/util/collection/TimSort.java index 40b5fb7fe4b49..0d476dc0748fd 100644 --- a/core/src/main/java/org/apache/spark/util/collection/TimSort.java +++ b/core/src/main/java/org/apache/spark/util/collection/TimSort.java @@ -120,8 +120,9 @@ public void sort(Buffer a, int lo, int hi, Comparator c) { assert c != null; int nRemaining = hi - lo; - if (nRemaining < 2) + if (nRemaining < 2) { return; // Arrays of size 0 and 1 are always sorted + } // If array is small, do a "mini-TimSort" with no merges if (nRemaining < MIN_MERGE) { @@ -184,8 +185,9 @@ public void sort(Buffer a, int lo, int hi, Comparator c) { @SuppressWarnings("fallthrough") private void binarySort(Buffer a, int lo, int hi, int start, Comparator c) { assert lo <= start && start <= hi; - if (start == lo) + if (start == lo) { start++; + } K key0 = s.newKey(); K key1 = s.newKey(); @@ -206,10 +208,11 @@ private void binarySort(Buffer a, int lo, int hi, int start, Comparator>> 1; - if (c.compare(pivot, s.getKey(a, mid, key1)) < 0) + if (c.compare(pivot, s.getKey(a, mid, key1)) < 0) { right = mid; - else + } else { left = mid + 1; + } } assert left == right; @@ -260,20 +263,23 @@ private void binarySort(Buffer a, int lo, int hi, int start, Comparator c) { assert lo < hi; int runHi = lo + 1; - if (runHi == hi) + if (runHi == hi) { return 1; + } K key0 = s.newKey(); K key1 = s.newKey(); // Find end of run, and reverse range if descending if (c.compare(s.getKey(a, runHi++, key0), s.getKey(a, lo, key1)) < 0) { // Descending - while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) < 0) + while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) < 0) { runHi++; + } reverseRange(a, lo, runHi); } else { // Ascending - while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) >= 0) + while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) >= 0) { runHi++; + } } return runHi - lo; @@ -445,8 +451,9 @@ private void mergeCollapse() { int n = stackSize - 2; if ( (n >= 1 && runLen[n-1] <= runLen[n] + runLen[n+1]) || (n >= 2 && runLen[n-2] <= runLen[n] + runLen[n-1])) { - if (runLen[n - 1] < runLen[n + 1]) + if (runLen[n - 1] < runLen[n + 1]) { n--; + } } else if (runLen[n] > runLen[n + 1]) { break; // Invariant is established } @@ -461,8 +468,9 @@ private void mergeCollapse() { private void mergeForceCollapse() { while (stackSize > 1) { int n = stackSize - 2; - if (n > 0 && runLen[n - 1] < runLen[n + 1]) + if (n > 0 && runLen[n - 1] < runLen[n + 1]) { n--; + } mergeAt(n); } } @@ -508,8 +516,9 @@ private void mergeAt(int i) { assert k >= 0; base1 += k; len1 -= k; - if (len1 == 0) + if (len1 == 0) { return; + } /* * Find where the last element of run1 goes in run2. 
Subsequent elements @@ -517,14 +526,16 @@ private void mergeAt(int i) { */ len2 = gallopLeft(s.getKey(a, base1 + len1 - 1, key0), a, base2, len2, len2 - 1, c); assert len2 >= 0; - if (len2 == 0) + if (len2 == 0) { return; + } // Merge remaining runs, using tmp array with min(len1, len2) elements - if (len1 <= len2) + if (len1 <= len2) { mergeLo(base1, len1, base2, len2); - else + } else { mergeHi(base1, len1, base2, len2); + } } /** @@ -557,11 +568,13 @@ private int gallopLeft(K key, Buffer a, int base, int len, int hint, Comparator< while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs, key0)) > 0) { lastOfs = ofs; ofs = (ofs << 1) + 1; - if (ofs <= 0) // int overflow + if (ofs <= 0) { // int overflow ofs = maxOfs; + } } - if (ofs > maxOfs) + if (ofs > maxOfs) { ofs = maxOfs; + } // Make offsets relative to base lastOfs += hint; @@ -572,11 +585,13 @@ private int gallopLeft(K key, Buffer a, int base, int len, int hint, Comparator< while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs, key0)) <= 0) { lastOfs = ofs; ofs = (ofs << 1) + 1; - if (ofs <= 0) // int overflow + if (ofs <= 0) { // int overflow ofs = maxOfs; + } } - if (ofs > maxOfs) + if (ofs > maxOfs) { ofs = maxOfs; + } // Make offsets relative to base int tmp = lastOfs; @@ -594,10 +609,11 @@ private int gallopLeft(K key, Buffer a, int base, int len, int hint, Comparator< while (lastOfs < ofs) { int m = lastOfs + ((ofs - lastOfs) >>> 1); - if (c.compare(key, s.getKey(a, base + m, key0)) > 0) + if (c.compare(key, s.getKey(a, base + m, key0)) > 0) { lastOfs = m + 1; // a[base + m] < key - else + } else { ofs = m; // key <= a[base + m] + } } assert lastOfs == ofs; // so a[base + ofs - 1] < key <= a[base + ofs] return ofs; @@ -629,11 +645,13 @@ private int gallopRight(K key, Buffer a, int base, int len, int hint, Comparator while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs, key1)) < 0) { lastOfs = ofs; ofs = (ofs << 1) + 1; - if (ofs <= 0) // int overflow + if (ofs <= 0) { // int overflow ofs = maxOfs; + } } - if (ofs > maxOfs) + if (ofs > maxOfs) { ofs = maxOfs; + } // Make offsets relative to b int tmp = lastOfs; @@ -645,11 +663,13 @@ private int gallopRight(K key, Buffer a, int base, int len, int hint, Comparator while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs, key1)) >= 0) { lastOfs = ofs; ofs = (ofs << 1) + 1; - if (ofs <= 0) // int overflow + if (ofs <= 0) { // int overflow ofs = maxOfs; + } } - if (ofs > maxOfs) + if (ofs > maxOfs) { ofs = maxOfs; + } // Make offsets relative to b lastOfs += hint; @@ -666,10 +686,11 @@ private int gallopRight(K key, Buffer a, int base, int len, int hint, Comparator while (lastOfs < ofs) { int m = lastOfs + ((ofs - lastOfs) >>> 1); - if (c.compare(key, s.getKey(a, base + m, key1)) < 0) + if (c.compare(key, s.getKey(a, base + m, key1)) < 0) { ofs = m; // key < a[b + m] - else + } else { lastOfs = m + 1; // a[b + m] <= key + } } assert lastOfs == ofs; // so a[b + ofs - 1] <= key < a[b + ofs] return ofs; @@ -735,14 +756,16 @@ private void mergeLo(int base1, int len1, int base2, int len2) { s.copyElement(a, cursor2++, a, dest++); count2++; count1 = 0; - if (--len2 == 0) + if (--len2 == 0) { break outer; + } } else { s.copyElement(tmp, cursor1++, a, dest++); count1++; count2 = 0; - if (--len1 == 1) + if (--len1 == 1) { break outer; + } } } while ((count1 | count2) < minGallop); @@ -759,12 +782,14 @@ private void mergeLo(int base1, int len1, int base2, int len2) { dest += count1; cursor1 += count1; len1 -= count1; - if (len1 <= 1) // 
len1 == 1 || len1 == 0 + if (len1 <= 1) { // len1 == 1 || len1 == 0 break outer; + } } s.copyElement(a, cursor2++, a, dest++); - if (--len2 == 0) + if (--len2 == 0) { break outer; + } count2 = gallopLeft(s.getKey(tmp, cursor1, key0), a, cursor2, len2, 0, c); if (count2 != 0) { @@ -772,16 +797,19 @@ private void mergeLo(int base1, int len1, int base2, int len2) { dest += count2; cursor2 += count2; len2 -= count2; - if (len2 == 0) + if (len2 == 0) { break outer; + } } s.copyElement(tmp, cursor1++, a, dest++); - if (--len1 == 1) + if (--len1 == 1) { break outer; + } minGallop--; } while (count1 >= MIN_GALLOP | count2 >= MIN_GALLOP); - if (minGallop < 0) + if (minGallop < 0) { minGallop = 0; + } minGallop += 2; // Penalize for leaving gallop mode } // End of "outer" loop this.minGallop = minGallop < 1 ? 1 : minGallop; // Write back to field @@ -857,14 +885,16 @@ private void mergeHi(int base1, int len1, int base2, int len2) { s.copyElement(a, cursor1--, a, dest--); count1++; count2 = 0; - if (--len1 == 0) + if (--len1 == 0) { break outer; + } } else { s.copyElement(tmp, cursor2--, a, dest--); count2++; count1 = 0; - if (--len2 == 1) + if (--len2 == 1) { break outer; + } } } while ((count1 | count2) < minGallop); @@ -881,12 +911,14 @@ private void mergeHi(int base1, int len1, int base2, int len2) { cursor1 -= count1; len1 -= count1; s.copyRange(a, cursor1 + 1, a, dest + 1, count1); - if (len1 == 0) + if (len1 == 0) { break outer; + } } s.copyElement(tmp, cursor2--, a, dest--); - if (--len2 == 1) + if (--len2 == 1) { break outer; + } count2 = len2 - gallopLeft(s.getKey(a, cursor1, key0), tmp, 0, len2, len2 - 1, c); if (count2 != 0) { @@ -894,16 +926,19 @@ private void mergeHi(int base1, int len1, int base2, int len2) { cursor2 -= count2; len2 -= count2; s.copyRange(tmp, cursor2 + 1, a, dest + 1, count2); - if (len2 <= 1) // len2 == 1 || len2 == 0 + if (len2 <= 1) { // len2 == 1 || len2 == 0 break outer; + } } s.copyElement(a, cursor1--, a, dest--); - if (--len1 == 0) + if (--len1 == 0) { break outer; + } minGallop--; } while (count1 >= MIN_GALLOP | count2 >= MIN_GALLOP); - if (minGallop < 0) + if (minGallop < 0) { minGallop = 0; + } minGallop += 2; // Penalize for leaving gallop mode } // End of "outer" loop this.minGallop = minGallop < 1 ? 1 : minGallop; // Write back to field @@ -943,10 +978,11 @@ private Buffer ensureCapacity(int minCapacity) { newSize |= newSize >> 16; newSize++; - if (newSize < 0) // Not bloody likely! + if (newSize < 0) { // Not bloody likely! 
newSize = minCapacity; - else + } else { newSize = Math.min(newSize, aLength >>> 1); + } tmp = s.allocate(newSize); tmpLength = newSize; diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index a218ad4623f46..9ce5fd2a1c8ef 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -180,7 +180,7 @@ private SortedIterator( this.sortBuffer = sortBuffer; } - public SortedIterator clone () { + public SortedIterator clone() { SortedIterator iter = new SortedIterator(memoryManager, sortBufferInsertPosition, sortBuffer); iter.position = position; iter.baseObject = baseObject; diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java index 94beeced3d479..ea83e8fef9eb9 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java @@ -77,7 +77,7 @@ public static void main(String[] args) { ParamMap paramMap = new ParamMap(); paramMap.put(lr.maxIter().w(20)); // Specify 1 Param. paramMap.put(lr.maxIter(), 30); // This overwrites the original maxIter. - double thresholds[] = {0.45, 0.55}; + double[] thresholds = {0.45, 0.55}; paramMap.put(lr.regParam().w(0.1), lr.thresholds().w(thresholds)); // Specify multiple Params. // One can also combine ParamMaps. diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java index fd53c81cc4974..de8e739ac9256 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java @@ -41,8 +41,9 @@ public static void main(String[] args) { public Vector call(String s) { String[] sarray = s.trim().split(" "); double[] values = new double[sarray.length]; - for (int i = 0; i < sarray.length; i++) + for (int i = 0; i < sarray.length; i++) { values[i] = Double.parseDouble(sarray[i]); + } return Vectors.dense(values); } } diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRecommendationExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRecommendationExample.java index 1065fde953b96..c179e7578cdfa 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRecommendationExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRecommendationExample.java @@ -29,7 +29,7 @@ // $example off$ public class JavaRecommendationExample { - public static void main(String args[]) { + public static void main(String[] args) { // $example on$ SparkConf conf = new SparkConf().setAppName("Java Collaborative Filtering Example"); JavaSparkContext jsc = new JavaSparkContext(conf); diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java index 46562ddbbcb57..3515d7be45d37 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java @@ -112,8 
+112,8 @@ public JavaRecord call(String word) { /** Lazily instantiated singleton instance of SQLContext */ class JavaSQLContextSingleton { - static private transient SQLContext instance = null; - static public SQLContext getInstance(SparkContext sparkContext) { + private static transient SQLContext instance = null; + public static SQLContext getInstance(SparkContext sparkContext) { if (instance == null) { instance = new SQLContext(sparkContext); } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index 0d4dd6afac769..e5cb68c8a4dbb 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -419,7 +419,7 @@ private static void storeVersion(DB db) throws IOException { public static class StoreVersion { - final static byte[] KEY = "StoreVersion".getBytes(Charsets.UTF_8); + static final byte[] KEY = "StoreVersion".getBytes(Charsets.UTF_8); public final int major; public final int minor; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 3986d6e18f770..352002b3499a2 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -51,7 +51,7 @@ final class UnsafeExternalRowSorter { private final PrefixComputer prefixComputer; private final UnsafeExternalSorter sorter; - public static abstract class PrefixComputer { + public abstract static class PrefixComputer { abstract long computePrefix(InternalRow row); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java b/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java index df64a878b6b36..1e4e5ede8cc11 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java @@ -41,5 +41,5 @@ * Returns an instance of the UserDefinedType which can serialize and deserialize the user * class to and from Catalyst built-in types. */ - Class > udt(); + Class> udt(); } diff --git a/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java index 3738fc1a235c2..2803cad8095dd 100644 --- a/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java +++ b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java @@ -37,26 +37,26 @@ public abstract class WriteAheadLog { * ensure that the written data is durable and readable (using the record handle) by the * time this function returns. */ - abstract public WriteAheadLogRecordHandle write(ByteBuffer record, long time); + public abstract WriteAheadLogRecordHandle write(ByteBuffer record, long time); /** * Read a written record based on the given record handle. */ - abstract public ByteBuffer read(WriteAheadLogRecordHandle handle); + public abstract ByteBuffer read(WriteAheadLogRecordHandle handle); /** * Read and return an iterator of all the records that have been written but not yet cleaned up. 
*/ - abstract public Iterator readAll(); + public abstract Iterator readAll(); /** * Clean all the records that are older than the threshold time. It can wait for * the completion of the deletion. */ - abstract public void clean(long threshTime, boolean waitForCompletion); + public abstract void clean(long threshTime, boolean waitForCompletion); /** * Close this log and release any resources. */ - abstract public void close(); + public abstract void close(); } diff --git a/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java b/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java index 1b0c416b0fe4e..83279e5e93c0e 100644 --- a/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java +++ b/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java @@ -18,6 +18,7 @@ package org.apache.spark.tags; import java.lang.annotation.*; + import org.scalatest.TagAnnotation; @TagAnnotation diff --git a/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java b/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java index 2a631bfc88cf0..108300168e173 100644 --- a/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java +++ b/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java @@ -18,6 +18,7 @@ package org.apache.spark.tags; import java.lang.annotation.*; + import org.scalatest.TagAnnotation; @TagAnnotation diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index b7aecb5102ba6..1d3382892e7e1 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -895,9 +895,9 @@ public int levenshteinDistance(UTF8String other) { m = swap; } - int p[] = new int[n + 1]; - int d[] = new int[n + 1]; - int swap[]; + int[] p = new int[n + 1]; + int[] d = new int[n + 1]; + int[] swap; int i, i_bytes, j, j_bytes, num_bytes_j, cost; @@ -960,7 +960,7 @@ public UTF8String soundex() { // first character must be a letter return this; } - byte sx[] = {'0', '0', '0', '0'}; + byte[] sx = {'0', '0', '0', '0'}; sx[0] = b; int sxi = 1; int idx = b - 'A'; From 6a2c1439eb2ed56de25fc3f45f448664a9162329 Mon Sep 17 00:00:00 2001 From: Dmitry Erastov Date: Wed, 11 Nov 2015 09:52:19 -0500 Subject: [PATCH 004/149] [SPARK-6990] Suppress Checkstyle for TimSort The code was copied from a third-party source and needs to be in sync with that, so we shouldn't make our own modifications to it. The file contains some style violations, so suppress the checks. 
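For reference, a minimal sketch of what such a Checkstyle suppressions file typically looks like (the DTD version, file pattern, and comment below are assumptions based on Checkstyle's SuppressionFilter format, not the exact contents of Spark's checkstyle-suppressions.xml; checkstyle.xml also needs a SuppressionFilter module pointing at this file):

<?xml version="1.0"?>
<!DOCTYPE suppressions PUBLIC
    "-//Puppy Crawl//DTD Suppressions 1.1//EN"
    "http://www.puppycrawl.com/dtds/suppressions_1_1.dtd">
<suppressions>
    <!-- TimSort.java is copied from a third-party source and must stay in sync with it,
         so all style checks are suppressed for this file. -->
    <suppress checks=".*" files="TimSort.java"/>
</suppressions>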
--- checkstyle-suppressions.xml | 33 +++++ checkstyle.xml | 5 + .../apache/spark/util/collection/TimSort.java | 118 ++++++------------ 3 files changed, 79 insertions(+), 77 deletions(-) create mode 100644 checkstyle-suppressions.xml diff --git a/checkstyle-suppressions.xml b/checkstyle-suppressions.xml new file mode 100644 index 0000000000000..9242be3d0357a --- /dev/null +++ b/checkstyle-suppressions.xml @@ -0,0 +1,33 @@ + + + + + + + + + diff --git a/checkstyle.xml b/checkstyle.xml index 403c5356f7dd8..b273e7be27a11 100644 --- a/checkstyle.xml +++ b/checkstyle.xml @@ -47,6 +47,11 @@ + + + + + diff --git a/core/src/main/java/org/apache/spark/util/collection/TimSort.java b/core/src/main/java/org/apache/spark/util/collection/TimSort.java index 0d476dc0748fd..40b5fb7fe4b49 100644 --- a/core/src/main/java/org/apache/spark/util/collection/TimSort.java +++ b/core/src/main/java/org/apache/spark/util/collection/TimSort.java @@ -120,9 +120,8 @@ public void sort(Buffer a, int lo, int hi, Comparator c) { assert c != null; int nRemaining = hi - lo; - if (nRemaining < 2) { + if (nRemaining < 2) return; // Arrays of size 0 and 1 are always sorted - } // If array is small, do a "mini-TimSort" with no merges if (nRemaining < MIN_MERGE) { @@ -185,9 +184,8 @@ public void sort(Buffer a, int lo, int hi, Comparator c) { @SuppressWarnings("fallthrough") private void binarySort(Buffer a, int lo, int hi, int start, Comparator c) { assert lo <= start && start <= hi; - if (start == lo) { + if (start == lo) start++; - } K key0 = s.newKey(); K key1 = s.newKey(); @@ -208,11 +206,10 @@ private void binarySort(Buffer a, int lo, int hi, int start, Comparator>> 1; - if (c.compare(pivot, s.getKey(a, mid, key1)) < 0) { + if (c.compare(pivot, s.getKey(a, mid, key1)) < 0) right = mid; - } else { + else left = mid + 1; - } } assert left == right; @@ -263,23 +260,20 @@ private void binarySort(Buffer a, int lo, int hi, int start, Comparator c) { assert lo < hi; int runHi = lo + 1; - if (runHi == hi) { + if (runHi == hi) return 1; - } K key0 = s.newKey(); K key1 = s.newKey(); // Find end of run, and reverse range if descending if (c.compare(s.getKey(a, runHi++, key0), s.getKey(a, lo, key1)) < 0) { // Descending - while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) < 0) { + while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) < 0) runHi++; - } reverseRange(a, lo, runHi); } else { // Ascending - while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) >= 0) { + while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) >= 0) runHi++; - } } return runHi - lo; @@ -451,9 +445,8 @@ private void mergeCollapse() { int n = stackSize - 2; if ( (n >= 1 && runLen[n-1] <= runLen[n] + runLen[n+1]) || (n >= 2 && runLen[n-2] <= runLen[n] + runLen[n-1])) { - if (runLen[n - 1] < runLen[n + 1]) { + if (runLen[n - 1] < runLen[n + 1]) n--; - } } else if (runLen[n] > runLen[n + 1]) { break; // Invariant is established } @@ -468,9 +461,8 @@ private void mergeCollapse() { private void mergeForceCollapse() { while (stackSize > 1) { int n = stackSize - 2; - if (n > 0 && runLen[n - 1] < runLen[n + 1]) { + if (n > 0 && runLen[n - 1] < runLen[n + 1]) n--; - } mergeAt(n); } } @@ -516,9 +508,8 @@ private void mergeAt(int i) { assert k >= 0; base1 += k; len1 -= k; - if (len1 == 0) { + if (len1 == 0) return; - } /* * Find where the last element of run1 goes in run2. 
Subsequent elements @@ -526,16 +517,14 @@ private void mergeAt(int i) { */ len2 = gallopLeft(s.getKey(a, base1 + len1 - 1, key0), a, base2, len2, len2 - 1, c); assert len2 >= 0; - if (len2 == 0) { + if (len2 == 0) return; - } // Merge remaining runs, using tmp array with min(len1, len2) elements - if (len1 <= len2) { + if (len1 <= len2) mergeLo(base1, len1, base2, len2); - } else { + else mergeHi(base1, len1, base2, len2); - } } /** @@ -568,13 +557,11 @@ private int gallopLeft(K key, Buffer a, int base, int len, int hint, Comparator< while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs, key0)) > 0) { lastOfs = ofs; ofs = (ofs << 1) + 1; - if (ofs <= 0) { // int overflow + if (ofs <= 0) // int overflow ofs = maxOfs; - } } - if (ofs > maxOfs) { + if (ofs > maxOfs) ofs = maxOfs; - } // Make offsets relative to base lastOfs += hint; @@ -585,13 +572,11 @@ private int gallopLeft(K key, Buffer a, int base, int len, int hint, Comparator< while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs, key0)) <= 0) { lastOfs = ofs; ofs = (ofs << 1) + 1; - if (ofs <= 0) { // int overflow + if (ofs <= 0) // int overflow ofs = maxOfs; - } } - if (ofs > maxOfs) { + if (ofs > maxOfs) ofs = maxOfs; - } // Make offsets relative to base int tmp = lastOfs; @@ -609,11 +594,10 @@ private int gallopLeft(K key, Buffer a, int base, int len, int hint, Comparator< while (lastOfs < ofs) { int m = lastOfs + ((ofs - lastOfs) >>> 1); - if (c.compare(key, s.getKey(a, base + m, key0)) > 0) { + if (c.compare(key, s.getKey(a, base + m, key0)) > 0) lastOfs = m + 1; // a[base + m] < key - } else { + else ofs = m; // key <= a[base + m] - } } assert lastOfs == ofs; // so a[base + ofs - 1] < key <= a[base + ofs] return ofs; @@ -645,13 +629,11 @@ private int gallopRight(K key, Buffer a, int base, int len, int hint, Comparator while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs, key1)) < 0) { lastOfs = ofs; ofs = (ofs << 1) + 1; - if (ofs <= 0) { // int overflow + if (ofs <= 0) // int overflow ofs = maxOfs; - } } - if (ofs > maxOfs) { + if (ofs > maxOfs) ofs = maxOfs; - } // Make offsets relative to b int tmp = lastOfs; @@ -663,13 +645,11 @@ private int gallopRight(K key, Buffer a, int base, int len, int hint, Comparator while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs, key1)) >= 0) { lastOfs = ofs; ofs = (ofs << 1) + 1; - if (ofs <= 0) { // int overflow + if (ofs <= 0) // int overflow ofs = maxOfs; - } } - if (ofs > maxOfs) { + if (ofs > maxOfs) ofs = maxOfs; - } // Make offsets relative to b lastOfs += hint; @@ -686,11 +666,10 @@ private int gallopRight(K key, Buffer a, int base, int len, int hint, Comparator while (lastOfs < ofs) { int m = lastOfs + ((ofs - lastOfs) >>> 1); - if (c.compare(key, s.getKey(a, base + m, key1)) < 0) { + if (c.compare(key, s.getKey(a, base + m, key1)) < 0) ofs = m; // key < a[b + m] - } else { + else lastOfs = m + 1; // a[b + m] <= key - } } assert lastOfs == ofs; // so a[b + ofs - 1] <= key < a[b + ofs] return ofs; @@ -756,16 +735,14 @@ private void mergeLo(int base1, int len1, int base2, int len2) { s.copyElement(a, cursor2++, a, dest++); count2++; count1 = 0; - if (--len2 == 0) { + if (--len2 == 0) break outer; - } } else { s.copyElement(tmp, cursor1++, a, dest++); count1++; count2 = 0; - if (--len1 == 1) { + if (--len1 == 1) break outer; - } } } while ((count1 | count2) < minGallop); @@ -782,14 +759,12 @@ private void mergeLo(int base1, int len1, int base2, int len2) { dest += count1; cursor1 += count1; len1 -= count1; - if (len1 <= 1) { // 
len1 == 1 || len1 == 0 + if (len1 <= 1) // len1 == 1 || len1 == 0 break outer; - } } s.copyElement(a, cursor2++, a, dest++); - if (--len2 == 0) { + if (--len2 == 0) break outer; - } count2 = gallopLeft(s.getKey(tmp, cursor1, key0), a, cursor2, len2, 0, c); if (count2 != 0) { @@ -797,19 +772,16 @@ private void mergeLo(int base1, int len1, int base2, int len2) { dest += count2; cursor2 += count2; len2 -= count2; - if (len2 == 0) { + if (len2 == 0) break outer; - } } s.copyElement(tmp, cursor1++, a, dest++); - if (--len1 == 1) { + if (--len1 == 1) break outer; - } minGallop--; } while (count1 >= MIN_GALLOP | count2 >= MIN_GALLOP); - if (minGallop < 0) { + if (minGallop < 0) minGallop = 0; - } minGallop += 2; // Penalize for leaving gallop mode } // End of "outer" loop this.minGallop = minGallop < 1 ? 1 : minGallop; // Write back to field @@ -885,16 +857,14 @@ private void mergeHi(int base1, int len1, int base2, int len2) { s.copyElement(a, cursor1--, a, dest--); count1++; count2 = 0; - if (--len1 == 0) { + if (--len1 == 0) break outer; - } } else { s.copyElement(tmp, cursor2--, a, dest--); count2++; count1 = 0; - if (--len2 == 1) { + if (--len2 == 1) break outer; - } } } while ((count1 | count2) < minGallop); @@ -911,14 +881,12 @@ private void mergeHi(int base1, int len1, int base2, int len2) { cursor1 -= count1; len1 -= count1; s.copyRange(a, cursor1 + 1, a, dest + 1, count1); - if (len1 == 0) { + if (len1 == 0) break outer; - } } s.copyElement(tmp, cursor2--, a, dest--); - if (--len2 == 1) { + if (--len2 == 1) break outer; - } count2 = len2 - gallopLeft(s.getKey(a, cursor1, key0), tmp, 0, len2, len2 - 1, c); if (count2 != 0) { @@ -926,19 +894,16 @@ private void mergeHi(int base1, int len1, int base2, int len2) { cursor2 -= count2; len2 -= count2; s.copyRange(tmp, cursor2 + 1, a, dest + 1, count2); - if (len2 <= 1) { // len2 == 1 || len2 == 0 + if (len2 <= 1) // len2 == 1 || len2 == 0 break outer; - } } s.copyElement(a, cursor1--, a, dest--); - if (--len1 == 0) { + if (--len1 == 0) break outer; - } minGallop--; } while (count1 >= MIN_GALLOP | count2 >= MIN_GALLOP); - if (minGallop < 0) { + if (minGallop < 0) minGallop = 0; - } minGallop += 2; // Penalize for leaving gallop mode } // End of "outer" loop this.minGallop = minGallop < 1 ? 1 : minGallop; // Write back to field @@ -978,11 +943,10 @@ private Buffer ensureCapacity(int minCapacity) { newSize |= newSize >> 16; newSize++; - if (newSize < 0) { // Not bloody likely! + if (newSize < 0) // Not bloody likely! 
newSize = minCapacity; - } else { + else newSize = Math.min(newSize, aLength >>> 1); - } tmp = s.allocate(newSize); tmpLength = newSize; From 67ffb934b259ab5dae4b7fb8a03fd71e77cfbda2 Mon Sep 17 00:00:00 2001 From: Dmitry Erastov Date: Wed, 11 Nov 2015 09:58:41 -0500 Subject: [PATCH 005/149] [SPARK-6990] Fix some Checkstyle issues in tests --- .../map/AbstractBytesToBytesMapSuite.java | 4 +- .../ml/feature/JavaStringIndexerSuite.java | 6 +- .../spark/mllib/clustering/JavaLDASuite.java | 2 +- .../network/sasl/SaslIntegrationSuite.java | 2 +- .../apache/spark/sql/hive/test/Complex.java | 86 +++++++++++++------ .../apache/spark/streaming/JavaAPISuite.java | 4 +- 6 files changed, 69 insertions(+), 35 deletions(-) diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index d87a1d2a56d99..a5c583f9f2844 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -356,8 +356,8 @@ public void iteratingOverDataPagesWithWastedSpace() throws Exception { final java.util.BitSet valuesSeen = new java.util.BitSet(NUM_ENTRIES); final Iterator iter = map.iterator(); - final long key[] = new long[KEY_LENGTH / 8]; - final long value[] = new long[VALUE_LENGTH / 8]; + final long[] key = new long[KEY_LENGTH / 8]; + final long[] value = new long[VALUE_LENGTH / 8]; while (iter.hasNext()) { final BytesToBytesMap.Location loc = iter.next(); Assert.assertTrue(loc.isDefined()); diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java index 6b2c48ef1c342..b2df79ba74feb 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java @@ -58,7 +58,7 @@ public void testStringIndexer() { createStructField("label", StringType, false) }); List data = Arrays.asList( - c(0, "a"), c(1, "b"), c(2, "c"), c(3, "a"), c(4, "a"), c(5, "c")); + cr(0, "a"), cr(1, "b"), cr(2, "c"), cr(3, "a"), cr(4, "a"), cr(5, "c")); DataFrame dataset = sqlContext.createDataFrame(data, schema); StringIndexer indexer = new StringIndexer() @@ -67,12 +67,12 @@ public void testStringIndexer() { DataFrame output = indexer.fit(dataset).transform(dataset); Assert.assertArrayEquals( - new Row[] { c(0, 0.0), c(1, 2.0), c(2, 1.0), c(3, 0.0), c(4, 0.0), c(5, 1.0) }, + new Row[] { cr(0, 0.0), cr(1, 2.0), cr(2, 1.0), cr(3, 0.0), cr(4, 0.0), cr(5, 1.0) }, output.orderBy("id").select("id", "labelIndex").collect()); } /** An alias for RowFactory.create. */ - private Row c(Object... values) { + private Row cr(Object... 
values) { return RowFactory.create(values); } } diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java index 3fea359a3b46c..225a216270b3b 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java @@ -144,7 +144,7 @@ public Boolean call(Tuple2 tuple2) { } @Test - public void OnlineOptimizerCompatibility() { + public void onlineOptimizerCompatibility() { int k = 3; double topicSmoothing = 1.2; double termSmoothing = 1.2; diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index c393a5e1e6810..89e8b8c01c200 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -59,7 +59,7 @@ public class SaslIntegrationSuite { // Use a long timeout to account for slow / overloaded build machines. In the normal case, // tests should finish way before the timeout expires. - private final static long TIMEOUT_MS = 10_000; + private static final long TIMEOUT_MS = 10_000; static TransportServer server; static TransportConf conf; diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java index e010112bb9327..4ef1f276d1bbb 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java @@ -489,6 +489,7 @@ public void setFieldValue(_Fields field, Object value) { } break; + default: } } @@ -512,6 +513,7 @@ public Object getFieldValue(_Fields field) { case M_STRING_STRING: return getMStringString(); + default: } throw new IllegalStateException(); } @@ -535,75 +537,91 @@ public boolean isSet(_Fields field) { return isSetLintString(); case M_STRING_STRING: return isSetMStringString(); + default: } throw new IllegalStateException(); } @Override public boolean equals(Object that) { - if (that == null) + if (that == null) { return false; - if (that instanceof Complex) + } + if (that instanceof Complex) { return this.equals((Complex)that); + } return false; } public boolean equals(Complex that) { - if (that == null) + if (that == null) { return false; + } boolean this_present_aint = true; boolean that_present_aint = true; if (this_present_aint || that_present_aint) { - if (!(this_present_aint && that_present_aint)) + if (!(this_present_aint && that_present_aint)) { return false; - if (this.aint != that.aint) + } + if (this.aint != that.aint) { return false; + } } boolean this_present_aString = true && this.isSetAString(); boolean that_present_aString = true && that.isSetAString(); if (this_present_aString || that_present_aString) { - if (!(this_present_aString && that_present_aString)) + if (!(this_present_aString && that_present_aString)) { return false; - if (!this.aString.equals(that.aString)) + } + if (!this.aString.equals(that.aString)) { return false; + } } boolean this_present_lint = true && this.isSetLint(); boolean that_present_lint = true && that.isSetLint(); if (this_present_lint || that_present_lint) { - if (!(this_present_lint && that_present_lint)) + if (!(this_present_lint && that_present_lint)) { return false; - if (!this.lint.equals(that.lint)) + } + if 
(!this.lint.equals(that.lint)) { return false; + } } boolean this_present_lString = true && this.isSetLString(); boolean that_present_lString = true && that.isSetLString(); if (this_present_lString || that_present_lString) { - if (!(this_present_lString && that_present_lString)) + if (!(this_present_lString && that_present_lString)) { return false; - if (!this.lString.equals(that.lString)) + } + if (!this.lString.equals(that.lString)) { return false; + } } boolean this_present_lintString = true && this.isSetLintString(); boolean that_present_lintString = true && that.isSetLintString(); if (this_present_lintString || that_present_lintString) { - if (!(this_present_lintString && that_present_lintString)) + if (!(this_present_lintString && that_present_lintString)) { return false; - if (!this.lintString.equals(that.lintString)) + } + if (!this.lintString.equals(that.lintString)) { return false; + } } boolean this_present_mStringString = true && this.isSetMStringString(); boolean that_present_mStringString = true && that.isSetMStringString(); if (this_present_mStringString || that_present_mStringString) { - if (!(this_present_mStringString && that_present_mStringString)) + if (!(this_present_mStringString && that_present_mStringString)) { return false; - if (!this.mStringString.equals(that.mStringString)) + } + if (!this.mStringString.equals(that.mStringString)) { return false; + } } return true; @@ -615,33 +633,39 @@ public int hashCode() { boolean present_aint = true; builder.append(present_aint); - if (present_aint) + if (present_aint) { builder.append(aint); + } boolean present_aString = true && (isSetAString()); builder.append(present_aString); - if (present_aString) + if (present_aString) { builder.append(aString); + } boolean present_lint = true && (isSetLint()); builder.append(present_lint); - if (present_lint) + if (present_lint) { builder.append(lint); + } boolean present_lString = true && (isSetLString()); builder.append(present_lString); - if (present_lString) + if (present_lString) { builder.append(lString); + } boolean present_lintString = true && (isSetLintString()); builder.append(present_lintString); - if (present_lintString) + if (present_lintString) { builder.append(lintString); + } boolean present_mStringString = true && (isSetMStringString()); builder.append(present_mStringString); - if (present_mStringString) + if (present_mStringString) { builder.append(mStringString); + } return builder.toHashCode(); } @@ -737,7 +761,9 @@ public String toString() { sb.append("aint:"); sb.append(this.aint); first = false; - if (!first) sb.append(", "); + if (!first) { + sb.append(", "); + } sb.append("aString:"); if (this.aString == null) { sb.append("null"); @@ -745,7 +771,9 @@ public String toString() { sb.append(this.aString); } first = false; - if (!first) sb.append(", "); + if (!first) { + sb.append(", "); + } sb.append("lint:"); if (this.lint == null) { sb.append("null"); @@ -753,7 +781,9 @@ public String toString() { sb.append(this.lint); } first = false; - if (!first) sb.append(", "); + if (!first) { + sb.append(", "); + } sb.append("lString:"); if (this.lString == null) { sb.append("null"); @@ -761,7 +791,9 @@ public String toString() { sb.append(this.lString); } first = false; - if (!first) sb.append(", "); + if (!first) { + sb.append(", "); + } sb.append("lintString:"); if (this.lintString == null) { sb.append("null"); @@ -769,7 +801,9 @@ public String toString() { sb.append(this.lintString); } first = false; - if (!first) sb.append(", "); + if (!first) { + sb.append(", "); + 
} sb.append("mStringString:"); if (this.mStringString == null) { sb.append("null"); diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index c5217149224e4..7d21cdb6d9888 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -1293,12 +1293,12 @@ public Optional call(List values, Optional state) { public void testUpdateStateByKeyWithInitial() { List>> inputData = stringIntKVStream; - List> initial = Arrays.asList ( + List> initial = Arrays.asList( new Tuple2<>("california", 1), new Tuple2<>("new york", 2)); JavaRDD> tmpRDD = ssc.sparkContext().parallelize(initial); - JavaPairRDD initialRDD = JavaPairRDD.fromJavaRDD (tmpRDD); + JavaPairRDD initialRDD = JavaPairRDD.fromJavaRDD(tmpRDD); List>> expected = Arrays.asList( Arrays.asList(new Tuple2<>("california", 5), From 6f42a36e90b2b82577621a6e54a6a4fdd5b55e4f Mon Sep 17 00:00:00 2001 From: Dmitry Erastov Date: Wed, 11 Nov 2015 10:01:20 -0500 Subject: [PATCH 006/149] [SPARK-6990] Enable Checkstyle for tests --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index 950421e5d956d..af48f7c26798f 100644 --- a/pom.xml +++ b/pom.xml @@ -2253,7 +2253,7 @@ false false - false + true false ${basedir}/src/main/java ${basedir}/src/test/java From 677e228f429029072cdc87770cb584de34cea63e Mon Sep 17 00:00:00 2001 From: Dmitry Erastov Date: Wed, 11 Nov 2015 10:01:48 -0500 Subject: [PATCH 007/149] [SPARK-6990] Enable FallThrough check in Checkstyle This makes sure all case statements end with a break. See http://checkstyle.sourceforge.net/config_coding.html#FallThrough --- checkstyle.xml | 1 + 1 file changed, 1 insertion(+) diff --git a/checkstyle.xml b/checkstyle.xml index b273e7be27a11..a63b1b7797201 100644 --- a/checkstyle.xml +++ b/checkstyle.xml @@ -88,6 +88,7 @@ + From a24477996e936b0861819ffb420f763f80f0b1da Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Fri, 13 Nov 2015 10:31:17 -0800 Subject: [PATCH 008/149] [SPARK-11690][PYSPARK] Add pivot to python api This PR adds pivot to the python api of GroupedData with the same syntax as Scala/Java. Author: Andrew Ray Closes #9653 from aray/sql-pivot-python. --- python/pyspark/sql/group.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 71c0bccc5eeff..227f40bc3cf53 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -17,7 +17,7 @@ from pyspark import since from pyspark.rdd import ignore_unicode_prefix -from pyspark.sql.column import Column, _to_seq +from pyspark.sql.column import Column, _to_seq, _to_java_column, _create_column_from_literal from pyspark.sql.dataframe import DataFrame from pyspark.sql.types import * @@ -167,6 +167,23 @@ def sum(self, *cols): [Row(sum(age)=7, sum(height)=165)] """ + @since(1.6) + def pivot(self, pivot_col, *values): + """Pivots a column of the current DataFrame and preform the specified aggregation. + + :param pivot_col: Column to pivot + :param values: Optional list of values of pivotColumn that will be translated to columns in + the output data frame. If values are not provided the method with do an immediate call + to .distinct() on the pivot column. 
+ >>> df4.groupBy("year").pivot("course", "dotNET", "Java").sum("earnings").collect() + [Row(year=2012, dotNET=15000, Java=20000), Row(year=2013, dotNET=48000, Java=30000)] + >>> df4.groupBy("year").pivot("course").sum("earnings").collect() + [Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)] + """ + jgd = self._jdf.pivot(_to_java_column(pivot_col), + _to_seq(self.sql_ctx._sc, values, _create_column_from_literal)) + return GroupedData(jgd, self.sql_ctx) + def _test(): import doctest @@ -182,6 +199,11 @@ def _test(): StructField('name', StringType())])) globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80), Row(name='Bob', age=5, height=85)]).toDF() + globs['df4'] = sc.parallelize([Row(course="dotNET", year=2012, earnings=10000), + Row(course="Java", year=2012, earnings=20000), + Row(course="dotNET", year=2012, earnings=5000), + Row(course="dotNET", year=2013, earnings=48000), + Row(course="Java", year=2013, earnings=30000)]).toDF() (failure_count, test_count) = doctest.testmod( pyspark.sql.group, globs=globs, From 23b8188f75d945ef70fbb1c4dc9720c2c5f8cbc3 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 13 Nov 2015 11:13:09 -0800 Subject: [PATCH 009/149] [SPARK-11654][SQL][FOLLOW-UP] fix some mistakes and clean up * rename `AppendColumn` to `AppendColumns` to be consistent with the physical plan name. * clean up stale comments. * always pass in resolved encoder to `TypedColumn.withInputType`(test added) * enable a mistakenly disabled java test. Author: Wenchen Fan Closes #9688 from cloud-fan/follow. --- .../sql/catalyst/plans/logical/basicOperators.scala | 8 ++++---- .../src/main/scala/org/apache/spark/sql/Column.scala | 3 ++- .../src/main/scala/org/apache/spark/sql/Dataset.scala | 10 +++------- .../scala/org/apache/spark/sql/GroupedDataset.scala | 4 ++-- .../apache/spark/sql/execution/SparkStrategies.scala | 2 +- .../test/org/apache/spark/sql/JavaDatasetSuite.java | 1 + .../org/apache/spark/sql/DatasetAggregatorSuite.scala | 4 ++++ 7 files changed, 17 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index d9f046efce0bf..e2b97b27a6c2c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -482,13 +482,13 @@ case class MapPartitions[T, U]( } /** Factory for constructing new `AppendColumn` nodes. */ -object AppendColumn { +object AppendColumns { def apply[T, U : Encoder]( func: T => U, tEncoder: ExpressionEncoder[T], - child: LogicalPlan): AppendColumn[T, U] = { + child: LogicalPlan): AppendColumns[T, U] = { val attrs = encoderFor[U].schema.toAttributes - new AppendColumn[T, U](func, tEncoder, encoderFor[U], attrs, child) + new AppendColumns[T, U](func, tEncoder, encoderFor[U], attrs, child) } } @@ -497,7 +497,7 @@ object AppendColumn { * resulting columns at the end of the input row. 
tEncoder/uEncoder are used respectively to * decode/encode from the JVM object representation expected by `func.` */ -case class AppendColumn[T, U]( +case class AppendColumns[T, U]( func: T => U, tEncoder: ExpressionEncoder[T], uEncoder: ExpressionEncoder[U], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 929224460dc09..82e9cd7f50a31 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -58,10 +58,11 @@ class TypedColumn[-T, U]( private[sql] def withInputType( inputEncoder: ExpressionEncoder[_], schema: Seq[Attribute]): TypedColumn[T, U] = { + val boundEncoder = inputEncoder.bind(schema).asInstanceOf[ExpressionEncoder[Any]] new TypedColumn[T, U] (expr transform { case ta: TypedAggregateExpression if ta.aEncoder.isEmpty => ta.copy( - aEncoder = Some(inputEncoder.asInstanceOf[ExpressionEncoder[Any]]), + aEncoder = Some(boundEncoder), children = schema) }, encoder) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index b930e4661c1a6..4cc3aa2465f2e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -299,7 +299,7 @@ class Dataset[T] private[sql]( */ def groupBy[K : Encoder](func: T => K): GroupedDataset[K, T] = { val inputPlan = queryExecution.analyzed - val withGroupingKey = AppendColumn(func, resolvedTEncoder, inputPlan) + val withGroupingKey = AppendColumns(func, resolvedTEncoder, inputPlan) val executed = sqlContext.executePlan(withGroupingKey) new GroupedDataset( @@ -364,13 +364,11 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = { - // We use an unbound encoder since the expression will make up its own schema. - // TODO: This probably doesn't work if we are relying on reordering of the input class fields. new Dataset[U1]( sqlContext, Project( c1.withInputType( - resolvedTEncoder.bind(queryExecution.analyzed.output), + resolvedTEncoder, queryExecution.analyzed.output).named :: Nil, logicalPlan)) } @@ -382,10 +380,8 @@ class Dataset[T] private[sql]( */ protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { val encoders = columns.map(_.encoder) - // We use an unbound encoder since the expression will make up its own schema. - // TODO: This probably doesn't work if we are relying on reordering of the input class fields. val namedColumns = - columns.map(_.withInputType(unresolvedTEncoder, queryExecution.analyzed.output).named) + columns.map(_.withInputType(resolvedTEncoder, queryExecution.analyzed.output).named) val execution = new QueryExecution(sqlContext, Project(namedColumns, logicalPlan)) new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index ae1272ae531fb..9c16940707de9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -89,7 +89,7 @@ class GroupedDataset[K, T] private[sql]( } /** - * Applies the given function to each group of data. For each unique group, the function will + * Applies the given function to each group of data. 
For each unique group, the function will * be passed the group key and an iterator that contains all of the elements in the group. The * function can return an iterator containing elements of an arbitrary type which will be returned * as a new [[Dataset]]. @@ -162,7 +162,7 @@ class GroupedDataset[K, T] private[sql]( val encoders = columns.map(_.encoder) val namedColumns = columns.map( - _.withInputType(resolvedTEncoder.bind(dataAttributes), dataAttributes).named) + _.withInputType(resolvedTEncoder, dataAttributes).named) val aggregate = Aggregate(groupingAttributes, groupingAttributes ++ namedColumns, logicalPlan) val execution = new QueryExecution(sqlContext, aggregate) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index a99ae4674bb12..67201a2c191cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -321,7 +321,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.MapPartitions(f, tEnc, uEnc, output, child) => execution.MapPartitions(f, tEnc, uEnc, output, planLater(child)) :: Nil - case logical.AppendColumn(f, tEnc, uEnc, newCol, child) => + case logical.AppendColumns(f, tEnc, uEnc, newCol, child) => execution.AppendColumns(f, tEnc, uEnc, newCol, planLater(child)) :: Nil case logical.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, child) => execution.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, planLater(child)) :: Nil diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 46169ca07d715..eb6fa1e72e27b 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -157,6 +157,7 @@ public Integer call(Integer v1, Integer v2) throws Exception { Assert.assertEquals(6, reduced); } + @Test public void testGroupBy() { List data = Arrays.asList("a", "foo", "bar"); Dataset ds = context.createDataset(data, Encoders.STRING()); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 20896efdfec16..46f9f077fe7f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -161,6 +161,10 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { ds.select(ClassInputAgg.toColumn), 1) + checkAnswer( + ds.select(expr("avg(a)").as[Double], ClassInputAgg.toColumn), + (1.0, 1)) + checkAnswer( ds.groupBy(_.b).agg(ClassInputAgg.toColumn), ("one", 1)) From d7b2b97ad67f9700fb8c13422c399f2edb72f770 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 13 Nov 2015 11:25:33 -0800 Subject: [PATCH 010/149] [SPARK-11727][SQL] Split ExpressionEncoder into FlatEncoder and ProductEncoder also add more tests for encoders, and fix bugs that I found: * when convert array to catalyst array, we can only skip element conversion for native types(e.g. int, long, boolean), not `AtomicType`(String is AtomicType but we need to convert it) * we should also handle scala `BigDecimal` when convert from catalyst `Decimal`. 
* complex map type should be supported other issues that still in investigation: * encode java `BigDecimal` and decode it back, seems we will loss precision info. * when encode case class that defined inside a object, `ClassNotFound` exception will be thrown. I'll remove unused code in a follow-up PR. Author: Wenchen Fan Closes #9693 from cloud-fan/split. --- .../spark/sql/catalyst/ScalaReflection.scala | 2 +- .../sql/catalyst/encoders/FlatEncoder.scala | 50 ++ .../catalyst/encoders/ProductEncoder.scala | 452 ++++++++++++++++++ .../sql/catalyst/encoders/RowEncoder.scala | 58 +-- .../sql/catalyst/util/DateTimeUtils.scala | 2 +- .../sql/catalyst/util/GenericArrayData.scala | 2 +- .../encoders/ExpressionEncoderSuite.scala | 259 ++-------- .../catalyst/encoders/FlatEncoderSuite.scala | 74 +++ .../encoders/ProductEncoderSuite.scala | 123 +++++ .../org/apache/spark/sql/GroupedDataset.scala | 7 +- .../org/apache/spark/sql/SQLImplicits.scala | 22 +- .../org/apache/spark/sql/functions.scala | 4 +- 12 files changed, 766 insertions(+), 289 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoder.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 6d822261b050a..0b3dd351e38e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -75,7 +75,7 @@ trait ScalaReflection { * * @see SPARK-5281 */ - private def localTypeOf[T: TypeTag]: `Type` = typeTag[T].in(mirror).tpe + def localTypeOf[T: TypeTag]: `Type` = typeTag[T].in(mirror).tpe /** * Returns the Spark SQL DataType for a given scala type. Where this is not an exact mapping diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoder.scala new file mode 100644 index 0000000000000..6d307ab13a9fc --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoder.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.encoders + +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.{typeTag, TypeTag} + +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.catalyst.expressions.{Literal, CreateNamedStruct, BoundReference} +import org.apache.spark.sql.catalyst.ScalaReflection + +object FlatEncoder { + import ScalaReflection.schemaFor + import ScalaReflection.dataTypeFor + + def apply[T : TypeTag]: ExpressionEncoder[T] = { + // We convert the not-serializable TypeTag into StructType and ClassTag. + val tpe = typeTag[T].tpe + val mirror = typeTag[T].mirror + val cls = mirror.runtimeClass(tpe) + assert(!schemaFor(tpe).dataType.isInstanceOf[StructType]) + + val input = BoundReference(0, dataTypeFor(tpe), nullable = true) + val toRowExpression = CreateNamedStruct( + Literal("value") :: ProductEncoder.extractorFor(input, tpe) :: Nil) + val fromRowExpression = ProductEncoder.constructorFor(tpe) + + new ExpressionEncoder[T]( + toRowExpression.dataType, + flat = true, + toRowExpression.flatten, + fromRowExpression, + ClassTag[T](cls)) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala new file mode 100644 index 0000000000000..414adb21168ed --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala @@ -0,0 +1,452 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.encoders + +import org.apache.spark.util.Utils +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.sql.types._ +import org.apache.spark.sql.catalyst.ScalaReflectionLock +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.{DateTimeUtils, ArrayBasedMapData, GenericArrayData} + +import scala.reflect.ClassTag + +object ProductEncoder { + import ScalaReflection.universe._ + import ScalaReflection.localTypeOf + import ScalaReflection.dataTypeFor + import ScalaReflection.Schema + import ScalaReflection.schemaFor + import ScalaReflection.arrayClassFor + + def apply[T <: Product : TypeTag]: ExpressionEncoder[T] = { + // We convert the not-serializable TypeTag into StructType and ClassTag. 
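Editorial aside, not part of the patch: the encoders built by FlatEncoder.apply and ProductEncoder.apply are ordinary ExpressionEncoders, and the suites added later in this patch exercise them with a toRow / resolve / bind / fromRow round trip. A minimal sketch of that pattern, assuming a hypothetical top-level case class Person (the commit message notes that case classes defined inside an object currently fail with ClassNotFoundException):

  // Illustration only; mirrors the encodeDecodeTest harness in ExpressionEncoderSuite.
  import org.apache.spark.sql.catalyst.encoders.{FlatEncoder, ProductEncoder}

  case class Person(name: String, age: Int)

  val stringEncoder = FlatEncoder[String]            // single-column ("flat") encoder
  val personEncoder = ProductEncoder[Person]         // struct encoder derived via reflection

  val row = personEncoder.toRow(Person("ann", 30))   // object -> InternalRow
  val attrs = personEncoder.schema.toAttributes
  val bound = personEncoder.resolve(attrs).bind(attrs)
  val back = bound.fromRow(row)                      // InternalRow -> Person("ann", 30)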
+ val tpe = typeTag[T].tpe + val mirror = typeTag[T].mirror + val cls = mirror.runtimeClass(tpe) + + val inputObject = BoundReference(0, ObjectType(cls), nullable = true) + val toRowExpression = extractorFor(inputObject, tpe).asInstanceOf[CreateNamedStruct] + val fromRowExpression = constructorFor(tpe) + + new ExpressionEncoder[T]( + toRowExpression.dataType, + flat = false, + toRowExpression.flatten, + fromRowExpression, + ClassTag[T](cls)) + } + + // The Predef.Map is scala.collection.immutable.Map. + // Since the map values can be mutable, we explicitly import scala.collection.Map at here. + import scala.collection.Map + + def extractorFor( + inputObject: Expression, + tpe: `Type`): Expression = ScalaReflectionLock.synchronized { + if (!inputObject.dataType.isInstanceOf[ObjectType]) { + inputObject + } else { + tpe match { + case t if t <:< localTypeOf[Option[_]] => + val TypeRef(_, _, Seq(optType)) = t + optType match { + // For primitive types we must manually unbox the value of the object. + case t if t <:< definitions.IntTpe => + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Integer]), inputObject), + "intValue", + IntegerType) + case t if t <:< definitions.LongTpe => + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Long]), inputObject), + "longValue", + LongType) + case t if t <:< definitions.DoubleTpe => + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Double]), inputObject), + "doubleValue", + DoubleType) + case t if t <:< definitions.FloatTpe => + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Float]), inputObject), + "floatValue", + FloatType) + case t if t <:< definitions.ShortTpe => + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Short]), inputObject), + "shortValue", + ShortType) + case t if t <:< definitions.ByteTpe => + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Byte]), inputObject), + "byteValue", + ByteType) + case t if t <:< definitions.BooleanTpe => + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Boolean]), inputObject), + "booleanValue", + BooleanType) + + // For non-primitives, we can just extract the object from the Option and then recurse. + case other => + val className: String = optType.erasure.typeSymbol.asClass.fullName + val classObj = Utils.classForName(className) + val optionObjectType = ObjectType(classObj) + + val unwrapped = UnwrapOption(optionObjectType, inputObject) + expressions.If( + IsNull(unwrapped), + expressions.Literal.create(null, schemaFor(optType).dataType), + extractorFor(unwrapped, optType)) + } + + case t if t <:< localTypeOf[Product] => + val formalTypeArgs = t.typeSymbol.asClass.typeParams + val TypeRef(_, _, actualTypeArgs) = t + val constructorSymbol = t.member(nme.CONSTRUCTOR) + val params = if (constructorSymbol.isMethod) { + constructorSymbol.asMethod.paramss + } else { + // Find the primary constructor, and use its parameter ordering. 
+ val primaryConstructorSymbol: Option[Symbol] = + constructorSymbol.asTerm.alternatives.find(s => + s.isMethod && s.asMethod.isPrimaryConstructor) + + if (primaryConstructorSymbol.isEmpty) { + sys.error("Internal SQL error: Product object did not have a primary constructor.") + } else { + primaryConstructorSymbol.get.asMethod.paramss + } + } + + CreateNamedStruct(params.head.flatMap { p => + val fieldName = p.name.toString + val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) + val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType)) + expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType) :: Nil + }) + + case t if t <:< localTypeOf[Array[_]] => + val TypeRef(_, _, Seq(elementType)) = t + toCatalystArray(inputObject, elementType) + + case t if t <:< localTypeOf[Seq[_]] => + val TypeRef(_, _, Seq(elementType)) = t + toCatalystArray(inputObject, elementType) + + case t if t <:< localTypeOf[Map[_, _]] => + val TypeRef(_, _, Seq(keyType, valueType)) = t + + val keys = + Invoke( + Invoke(inputObject, "keysIterator", + ObjectType(classOf[scala.collection.Iterator[_]])), + "toSeq", + ObjectType(classOf[scala.collection.Seq[_]])) + val convertedKeys = toCatalystArray(keys, keyType) + + val values = + Invoke( + Invoke(inputObject, "valuesIterator", + ObjectType(classOf[scala.collection.Iterator[_]])), + "toSeq", + ObjectType(classOf[scala.collection.Seq[_]])) + val convertedValues = toCatalystArray(values, valueType) + + val Schema(keyDataType, _) = schemaFor(keyType) + val Schema(valueDataType, valueNullable) = schemaFor(valueType) + NewInstance( + classOf[ArrayBasedMapData], + convertedKeys :: convertedValues :: Nil, + dataType = MapType(keyDataType, valueDataType, valueNullable)) + + case t if t <:< localTypeOf[String] => + StaticInvoke( + classOf[UTF8String], + StringType, + "fromString", + inputObject :: Nil) + + case t if t <:< localTypeOf[java.sql.Timestamp] => + StaticInvoke( + DateTimeUtils, + TimestampType, + "fromJavaTimestamp", + inputObject :: Nil) + + case t if t <:< localTypeOf[java.sql.Date] => + StaticInvoke( + DateTimeUtils, + DateType, + "fromJavaDate", + inputObject :: Nil) + + case t if t <:< localTypeOf[BigDecimal] => + StaticInvoke( + Decimal, + DecimalType.SYSTEM_DEFAULT, + "apply", + inputObject :: Nil) + + case t if t <:< localTypeOf[java.math.BigDecimal] => + StaticInvoke( + Decimal, + DecimalType.SYSTEM_DEFAULT, + "apply", + inputObject :: Nil) + + case t if t <:< localTypeOf[java.lang.Integer] => + Invoke(inputObject, "intValue", IntegerType) + case t if t <:< localTypeOf[java.lang.Long] => + Invoke(inputObject, "longValue", LongType) + case t if t <:< localTypeOf[java.lang.Double] => + Invoke(inputObject, "doubleValue", DoubleType) + case t if t <:< localTypeOf[java.lang.Float] => + Invoke(inputObject, "floatValue", FloatType) + case t if t <:< localTypeOf[java.lang.Short] => + Invoke(inputObject, "shortValue", ShortType) + case t if t <:< localTypeOf[java.lang.Byte] => + Invoke(inputObject, "byteValue", ByteType) + case t if t <:< localTypeOf[java.lang.Boolean] => + Invoke(inputObject, "booleanValue", BooleanType) + + case other => + throw new UnsupportedOperationException(s"Extractor for type $other is not supported") + } + } + } + + private def toCatalystArray(input: Expression, elementType: `Type`): Expression = { + val externalDataType = dataTypeFor(elementType) + val Schema(catalystType, nullable) = schemaFor(elementType) + if (RowEncoder.isNativeType(catalystType)) { + NewInstance( + 
classOf[GenericArrayData], + input :: Nil, + dataType = ArrayType(catalystType, nullable)) + } else { + MapObjects(extractorFor(_, elementType), input, externalDataType) + } + } + + def constructorFor( + tpe: `Type`, + path: Option[Expression] = None): Expression = ScalaReflectionLock.synchronized { + + /** Returns the current path with a sub-field extracted. */ + def addToPath(part: String): Expression = path + .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) + .getOrElse(UnresolvedAttribute(part)) + + /** Returns the current path with a field at ordinal extracted. */ + def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = path + .map(p => GetInternalRowField(p, ordinal, dataType)) + .getOrElse(BoundReference(ordinal, dataType, false)) + + /** Returns the current path or `BoundReference`. */ + def getPath: Expression = path.getOrElse(BoundReference(0, schemaFor(tpe).dataType, true)) + + tpe match { + case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath + + case t if t <:< localTypeOf[Option[_]] => + val TypeRef(_, _, Seq(optType)) = t + WrapOption(null, constructorFor(optType, path)) + + case t if t <:< localTypeOf[java.lang.Integer] => + val boxedType = classOf[java.lang.Integer] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.lang.Long] => + val boxedType = classOf[java.lang.Long] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.lang.Double] => + val boxedType = classOf[java.lang.Double] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.lang.Float] => + val boxedType = classOf[java.lang.Float] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.lang.Short] => + val boxedType = classOf[java.lang.Short] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.lang.Byte] => + val boxedType = classOf[java.lang.Byte] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.lang.Boolean] => + val boxedType = classOf[java.lang.Boolean] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.sql.Date] => + StaticInvoke( + DateTimeUtils, + ObjectType(classOf[java.sql.Date]), + "toJavaDate", + getPath :: Nil, + propagateNull = true) + + case t if t <:< localTypeOf[java.sql.Timestamp] => + StaticInvoke( + DateTimeUtils, + ObjectType(classOf[java.sql.Timestamp]), + "toJavaTimestamp", + getPath :: Nil, + propagateNull = true) + + case t if t <:< localTypeOf[java.lang.String] => + Invoke(getPath, "toString", ObjectType(classOf[String])) + + case t if t <:< localTypeOf[java.math.BigDecimal] => + Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) + + case t if t <:< localTypeOf[BigDecimal] => + Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal])) + + case t if t <:< localTypeOf[Array[_]] => + val TypeRef(_, _, Seq(elementType)) = t + val primitiveMethod = elementType match { + case t if t <:< definitions.IntTpe => Some("toIntArray") + 
case t if t <:< definitions.LongTpe => Some("toLongArray") + case t if t <:< definitions.DoubleTpe => Some("toDoubleArray") + case t if t <:< definitions.FloatTpe => Some("toFloatArray") + case t if t <:< definitions.ShortTpe => Some("toShortArray") + case t if t <:< definitions.ByteTpe => Some("toByteArray") + case t if t <:< definitions.BooleanTpe => Some("toBooleanArray") + case _ => None + } + + primitiveMethod.map { method => + Invoke(getPath, method, arrayClassFor(elementType)) + }.getOrElse { + Invoke( + MapObjects( + p => constructorFor(elementType, Some(p)), + getPath, + schemaFor(elementType).dataType), + "array", + arrayClassFor(elementType)) + } + + case t if t <:< localTypeOf[Seq[_]] => + val TypeRef(_, _, Seq(elementType)) = t + val arrayData = + Invoke( + MapObjects( + p => constructorFor(elementType, Some(p)), + getPath, + schemaFor(elementType).dataType), + "array", + ObjectType(classOf[Array[Any]])) + + StaticInvoke( + scala.collection.mutable.WrappedArray, + ObjectType(classOf[Seq[_]]), + "make", + arrayData :: Nil) + + case t if t <:< localTypeOf[Map[_, _]] => + val TypeRef(_, _, Seq(keyType, valueType)) = t + + val keyData = + Invoke( + MapObjects( + p => constructorFor(keyType, Some(p)), + Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)), + schemaFor(keyType).dataType), + "array", + ObjectType(classOf[Array[Any]])) + + val valueData = + Invoke( + MapObjects( + p => constructorFor(valueType, Some(p)), + Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)), + schemaFor(valueType).dataType), + "array", + ObjectType(classOf[Array[Any]])) + + StaticInvoke( + ArrayBasedMapData, + ObjectType(classOf[Map[_, _]]), + "toScalaMap", + keyData :: valueData :: Nil) + + case t if t <:< localTypeOf[Product] => + val formalTypeArgs = t.typeSymbol.asClass.typeParams + val TypeRef(_, _, actualTypeArgs) = t + val constructorSymbol = t.member(nme.CONSTRUCTOR) + val params = if (constructorSymbol.isMethod) { + constructorSymbol.asMethod.paramss + } else { + // Find the primary constructor, and use its parameter ordering. + val primaryConstructorSymbol: Option[Symbol] = + constructorSymbol.asTerm.alternatives.find(s => + s.isMethod && s.asMethod.isPrimaryConstructor) + + if (primaryConstructorSymbol.isEmpty) { + sys.error("Internal SQL error: Product object did not have a primary constructor.") + } else { + primaryConstructorSymbol.get.asMethod.paramss + } + } + + val className: String = t.erasure.typeSymbol.asClass.fullName + val cls = Utils.classForName(className) + + val arguments = params.head.zipWithIndex.map { case (p, i) => + val fieldName = p.name.toString + val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) + val dataType = schemaFor(fieldType).dataType + + // For tuples, we based grab the inner fields by ordinal instead of name. 
+ if (className startsWith "scala.Tuple") { + constructorFor(fieldType, Some(addToPathOrdinal(i, dataType))) + } else { + constructorFor(fieldType, Some(addToPath(fieldName))) + } + } + + val newInstance = NewInstance(cls, arguments, propagateNull = false, ObjectType(cls)) + + if (path.nonEmpty) { + expressions.If( + IsNull(getPath), + expressions.Literal.create(null, ObjectType(cls)), + newInstance + ) + } else { + newInstance + } + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 0b42130a013b2..e0be896bb3548 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -119,9 +119,17 @@ object RowEncoder { CreateStruct(convertedFields) } - private def externalDataTypeFor(dt: DataType): DataType = dt match { + /** + * Returns true if the value of this data type is same between internal and external. + */ + def isNativeType(dt: DataType): Boolean = dt match { case BooleanType | ByteType | ShortType | IntegerType | LongType | - FloatType | DoubleType | BinaryType => dt + FloatType | DoubleType | BinaryType => true + case _ => false + } + + private def externalDataTypeFor(dt: DataType): DataType = dt match { + case _ if isNativeType(dt) => dt case TimestampType => ObjectType(classOf[java.sql.Timestamp]) case DateType => ObjectType(classOf[java.sql.Date]) case _: DecimalType => ObjectType(classOf[java.math.BigDecimal]) @@ -137,13 +145,13 @@ object RowEncoder { If( IsNull(field), Literal.create(null, externalDataTypeFor(f.dataType)), - constructorFor(BoundReference(i, f.dataType, f.nullable), f.dataType) + constructorFor(BoundReference(i, f.dataType, f.nullable)) ) } CreateExternalRow(fields) } - private def constructorFor(input: Expression, dataType: DataType): Expression = dataType match { + private def constructorFor(input: Expression): Expression = input.dataType match { case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | BinaryType => input @@ -170,7 +178,7 @@ object RowEncoder { case ArrayType(et, nullable) => val arrayData = Invoke( - MapObjects(constructorFor(_, et), input, et), + MapObjects(constructorFor, input, et), "array", ObjectType(classOf[Array[_]])) StaticInvoke( @@ -181,10 +189,10 @@ object RowEncoder { case MapType(kt, vt, valueNullable) => val keyArrayType = ArrayType(kt, false) - val keyData = constructorFor(Invoke(input, "keyArray", keyArrayType), keyArrayType) + val keyData = constructorFor(Invoke(input, "keyArray", keyArrayType)) val valueArrayType = ArrayType(vt, valueNullable) - val valueData = constructorFor(Invoke(input, "valueArray", valueArrayType), valueArrayType) + val valueData = constructorFor(Invoke(input, "valueArray", valueArrayType)) StaticInvoke( ArrayBasedMapData, @@ -197,42 +205,8 @@ object RowEncoder { If( Invoke(input, "isNullAt", BooleanType, Literal(i) :: Nil), Literal.create(null, externalDataTypeFor(f.dataType)), - constructorFor(getField(input, i, f.dataType), f.dataType)) + constructorFor(GetInternalRowField(input, i, f.dataType))) } CreateExternalRow(convertedFields) } - - private def getField( - row: Expression, - ordinal: Int, - dataType: DataType): Expression = dataType match { - case BooleanType => - Invoke(row, "getBoolean", dataType, Literal(ordinal) :: Nil) - case ByteType => - Invoke(row, "getByte", dataType, Literal(ordinal) 
:: Nil) - case ShortType => - Invoke(row, "getShort", dataType, Literal(ordinal) :: Nil) - case IntegerType | DateType => - Invoke(row, "getInt", dataType, Literal(ordinal) :: Nil) - case LongType | TimestampType => - Invoke(row, "getLong", dataType, Literal(ordinal) :: Nil) - case FloatType => - Invoke(row, "getFloat", dataType, Literal(ordinal) :: Nil) - case DoubleType => - Invoke(row, "getDouble", dataType, Literal(ordinal) :: Nil) - case t: DecimalType => - Invoke(row, "getDecimal", dataType, Seq(ordinal, t.precision, t.scale).map(Literal(_))) - case StringType => - Invoke(row, "getUTF8String", dataType, Literal(ordinal) :: Nil) - case BinaryType => - Invoke(row, "getBinary", dataType, Literal(ordinal) :: Nil) - case CalendarIntervalType => - Invoke(row, "getInterval", dataType, Literal(ordinal) :: Nil) - case t: StructType => - Invoke(row, "getStruct", dataType, Literal(ordinal) :: Literal(t.size) :: Nil) - case _: ArrayType => - Invoke(row, "getArray", dataType, Literal(ordinal) :: Nil) - case _: MapType => - Invoke(row, "getMap", dataType, Literal(ordinal) :: Nil) - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index f5fff90e5a542..deff8a5378b92 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -110,7 +110,7 @@ object DateTimeUtils { } def stringToTime(s: String): java.util.Date = { - var indexOfGMT = s.indexOf("GMT"); + val indexOfGMT = s.indexOf("GMT") if (indexOfGMT != -1) { // ISO8601 with a weird time zone specifier (2000-01-01T00:00GMT+01:00) val s0 = s.substring(0, indexOfGMT) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala index e9bf7b33e35be..96588bb5dc1bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala @@ -23,7 +23,7 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} class GenericArrayData(val array: Array[Any]) extends ArrayData { - def this(seq: scala.collection.GenIterable[Any]) = this(seq.toArray) + def this(seq: Seq[Any]) = this(seq.toArray) // TODO: This is boxing. We should specialize. 
def this(primitiveArray: Array[Int]) = this(primitiveArray.toSeq) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index b0dacf7f555e0..9fe64b4cf10e4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -17,232 +17,27 @@ package org.apache.spark.sql.catalyst.encoders -import scala.collection.mutable.ArrayBuffer -import scala.reflect.runtime.universe._ +import java.util.Arrays import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst._ +import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.util.ArrayData -import org.apache.spark.sql.types.{StructField, ArrayType} - -case class RepeatedStruct(s: Seq[PrimitiveData]) - -case class NestedArray(a: Array[Array[Int]]) - -case class BoxedData( - intField: java.lang.Integer, - longField: java.lang.Long, - doubleField: java.lang.Double, - floatField: java.lang.Float, - shortField: java.lang.Short, - byteField: java.lang.Byte, - booleanField: java.lang.Boolean) - -case class RepeatedData( - arrayField: Seq[Int], - arrayFieldContainsNull: Seq[java.lang.Integer], - mapField: scala.collection.Map[Int, Long], - mapFieldNull: scala.collection.Map[Int, java.lang.Long], - structField: PrimitiveData) - -case class SpecificCollection(l: List[Int]) - -class ExpressionEncoderSuite extends SparkFunSuite { - - encodeDecodeTest(1) - encodeDecodeTest(1L) - encodeDecodeTest(1.toDouble) - encodeDecodeTest(1.toFloat) - encodeDecodeTest(true) - encodeDecodeTest(false) - encodeDecodeTest(1.toShort) - encodeDecodeTest(1.toByte) - encodeDecodeTest("hello") - - encodeDecodeTest(PrimitiveData(1, 1, 1, 1, 1, 1, true)) - - // TODO: Support creating specific subclasses of Seq. - ignore("Specific collection types") { encodeDecodeTest(SpecificCollection(1 :: Nil)) } - - encodeDecodeTest( - OptionalData( - Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true), - Some(PrimitiveData(1, 1, 1, 1, 1, 1, true)))) - - encodeDecodeTest(OptionalData(None, None, None, None, None, None, None, None)) - - encodeDecodeTest( - BoxedData(1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true)) - - encodeDecodeTest( - BoxedData(null, null, null, null, null, null, null)) - - encodeDecodeTest( - RepeatedStruct(PrimitiveData(1, 1, 1, 1, 1, 1, true) :: Nil)) - - encodeDecodeTest( - RepeatedData( - Seq(1, 2), - Seq(new Integer(1), null, new Integer(2)), - Map(1 -> 2L), - Map(1 -> null), - PrimitiveData(1, 1, 1, 1, 1, 1, true))) - - encodeDecodeTest(("nullable Seq[Integer]", Seq[Integer](1, null))) - - encodeDecodeTest(("Seq[(String, String)]", - Seq(("a", "b")))) - encodeDecodeTest(("Seq[(Int, Int)]", - Seq((1, 2)))) - encodeDecodeTest(("Seq[(Long, Long)]", - Seq((1L, 2L)))) - encodeDecodeTest(("Seq[(Float, Float)]", - Seq((1.toFloat, 2.toFloat)))) - encodeDecodeTest(("Seq[(Double, Double)]", - Seq((1.toDouble, 2.toDouble)))) - encodeDecodeTest(("Seq[(Short, Short)]", - Seq((1.toShort, 2.toShort)))) - encodeDecodeTest(("Seq[(Byte, Byte)]", - Seq((1.toByte, 2.toByte)))) - encodeDecodeTest(("Seq[(Boolean, Boolean)]", - Seq((true, false)))) - - // TODO: Decoding/encoding of complex maps. 
- ignore("complex maps") { - encodeDecodeTest(("Map[Int, (String, String)]", - Map(1 ->("a", "b")))) - } - - encodeDecodeTest(("ArrayBuffer[(String, String)]", - ArrayBuffer(("a", "b")))) - encodeDecodeTest(("ArrayBuffer[(Int, Int)]", - ArrayBuffer((1, 2)))) - encodeDecodeTest(("ArrayBuffer[(Long, Long)]", - ArrayBuffer((1L, 2L)))) - encodeDecodeTest(("ArrayBuffer[(Float, Float)]", - ArrayBuffer((1.toFloat, 2.toFloat)))) - encodeDecodeTest(("ArrayBuffer[(Double, Double)]", - ArrayBuffer((1.toDouble, 2.toDouble)))) - encodeDecodeTest(("ArrayBuffer[(Short, Short)]", - ArrayBuffer((1.toShort, 2.toShort)))) - encodeDecodeTest(("ArrayBuffer[(Byte, Byte)]", - ArrayBuffer((1.toByte, 2.toByte)))) - encodeDecodeTest(("ArrayBuffer[(Boolean, Boolean)]", - ArrayBuffer((true, false)))) - - encodeDecodeTest(("Seq[Seq[(Int, Int)]]", - Seq(Seq((1, 2))))) - - encodeDecodeTestCustom(("Array[Array[(Int, Int)]]", - Array(Array((1, 2))))) - { (l, r) => l._2(0)(0) == r._2(0)(0) } - - encodeDecodeTestCustom(("Array[Array[(Int, Int)]]", - Array(Array(Array((1, 2)))))) - { (l, r) => l._2(0)(0)(0) == r._2(0)(0)(0) } - - encodeDecodeTestCustom(("Array[Array[Array[(Int, Int)]]]", - Array(Array(Array(Array((1, 2))))))) - { (l, r) => l._2(0)(0)(0)(0) == r._2(0)(0)(0)(0) } - - encodeDecodeTestCustom(("Array[Array[Array[Array[(Int, Int)]]]]", - Array(Array(Array(Array(Array((1, 2)))))))) - { (l, r) => l._2(0)(0)(0)(0)(0) == r._2(0)(0)(0)(0)(0) } - - - encodeDecodeTestCustom(("Array[Array[Integer]]", - Array(Array[Integer](1)))) - { (l, r) => l._2(0)(0) == r._2(0)(0) } - - encodeDecodeTestCustom(("Array[Array[Int]]", - Array(Array(1)))) - { (l, r) => l._2(0)(0) == r._2(0)(0) } - - encodeDecodeTestCustom(("Array[Array[Int]]", - Array(Array(Array(1))))) - { (l, r) => l._2(0)(0)(0) == r._2(0)(0)(0) } - - encodeDecodeTestCustom(("Array[Array[Array[Int]]]", - Array(Array(Array(Array(1)))))) - { (l, r) => l._2(0)(0)(0)(0) == r._2(0)(0)(0)(0) } - - encodeDecodeTestCustom(("Array[Array[Array[Array[Int]]]]", - Array(Array(Array(Array(Array(1))))))) - { (l, r) => l._2(0)(0)(0)(0)(0) == r._2(0)(0)(0)(0)(0) } - - encodeDecodeTest(("Array[Byte] null", - null: Array[Byte])) - encodeDecodeTestCustom(("Array[Byte]", - Array[Byte](1, 2, 3))) - { (l, r) => java.util.Arrays.equals(l._2, r._2) } - - encodeDecodeTest(("Array[Int] null", - null: Array[Int])) - encodeDecodeTestCustom(("Array[Int]", - Array[Int](1, 2, 3))) - { (l, r) => java.util.Arrays.equals(l._2, r._2) } - - encodeDecodeTest(("Array[Long] null", - null: Array[Long])) - encodeDecodeTestCustom(("Array[Long]", - Array[Long](1, 2, 3))) - { (l, r) => java.util.Arrays.equals(l._2, r._2) } - - encodeDecodeTest(("Array[Double] null", - null: Array[Double])) - encodeDecodeTestCustom(("Array[Double]", - Array[Double](1, 2, 3))) - { (l, r) => java.util.Arrays.equals(l._2, r._2) } - - encodeDecodeTest(("Array[Float] null", - null: Array[Float])) - encodeDecodeTestCustom(("Array[Float]", - Array[Float](1, 2, 3))) - { (l, r) => java.util.Arrays.equals(l._2, r._2) } - - encodeDecodeTest(("Array[Boolean] null", - null: Array[Boolean])) - encodeDecodeTestCustom(("Array[Boolean]", - Array[Boolean](true, false))) - { (l, r) => java.util.Arrays.equals(l._2, r._2) } - - encodeDecodeTest(("Array[Short] null", - null: Array[Short])) - encodeDecodeTestCustom(("Array[Short]", - Array[Short](1, 2, 3))) - { (l, r) => java.util.Arrays.equals(l._2, r._2) } - - encodeDecodeTestCustom(("java.sql.Timestamp", - new java.sql.Timestamp(1))) - { (l, r) => l._2.toString == r._2.toString } - - 
encodeDecodeTestCustom(("java.sql.Date", new java.sql.Date(1))) - { (l, r) => l._2.toString == r._2.toString } - - /** Simplified encodeDecodeTestCustom, where the comparison function can be `Object.equals`. */ - protected def encodeDecodeTest[T : TypeTag](inputData: T) = - encodeDecodeTestCustom[T](inputData)((l, r) => l == r) - - /** - * Constructs a test that round-trips `t` through an encoder, checking the results to ensure it - * matches the original. - */ - protected def encodeDecodeTestCustom[T : TypeTag]( - inputData: T)( - c: (T, T) => Boolean) = { - test(s"encode/decode: $inputData - ${inputData.getClass.getName}") { - val encoder = try ExpressionEncoder[T]() catch { - case e: Exception => - fail(s"Exception thrown generating encoder", e) - } - val convertedData = encoder.toRow(inputData) +import org.apache.spark.sql.types.ArrayType + +abstract class ExpressionEncoderSuite extends SparkFunSuite { + protected def encodeDecodeTest[T]( + input: T, + encoder: ExpressionEncoder[T], + testName: String): Unit = { + test(s"encode/decode for $testName: $input") { + val row = encoder.toRow(input) val schema = encoder.schema.toAttributes val boundEncoder = encoder.resolve(schema).bind(schema) - val convertedBack = try boundEncoder.fromRow(convertedData) catch { + val convertedBack = try boundEncoder.fromRow(row) catch { case e: Exception => fail( s"""Exception thrown while decoding - |Converted: $convertedData + |Converted: $row |Schema: ${schema.mkString(",")} |${encoder.schema.treeString} | @@ -252,18 +47,27 @@ class ExpressionEncoderSuite extends SparkFunSuite { """.stripMargin, e) } - if (!c(inputData, convertedBack)) { + val isCorrect = (input, convertedBack) match { + case (b1: Array[Byte], b2: Array[Byte]) => Arrays.equals(b1, b2) + case (b1: Array[Int], b2: Array[Int]) => Arrays.equals(b1, b2) + case (b1: Array[Array[_]], b2: Array[Array[_]]) => + Arrays.deepEquals(b1.asInstanceOf[Array[AnyRef]], b2.asInstanceOf[Array[AnyRef]]) + case (b1: Array[_], b2: Array[_]) => + Arrays.equals(b1.asInstanceOf[Array[AnyRef]], b2.asInstanceOf[Array[AnyRef]]) + case _ => input == convertedBack + } + + if (!isCorrect) { val types = convertedBack match { case c: Product => c.productIterator.filter(_ != null).map(_.getClass.getName).mkString(",") case other => other.getClass.getName } - val encodedData = try { - convertedData.toSeq(encoder.schema).zip(encoder.schema).map { - case (a: ArrayData, StructField(_, at: ArrayType, _, _)) => - a.toArray[Any](at.elementType).toSeq + row.toSeq(encoder.schema).zip(schema).map { + case (a: ArrayData, AttributeReference(_, ArrayType(et, _), _, _)) => + a.toArray[Any](et).toSeq case (other, _) => other }.mkString("[", ",", "]") @@ -274,7 +78,7 @@ class ExpressionEncoderSuite extends SparkFunSuite { fail( s"""Encoded/Decoded data does not match input data | - |in: $inputData + |in: $input |out: $convertedBack |types: $types | @@ -282,11 +86,10 @@ class ExpressionEncoderSuite extends SparkFunSuite { |Schema: ${schema.mkString(",")} |${encoder.schema.treeString} | - |Extract Expressions: - |$boundEncoder + |fromRow Expressions: + |${boundEncoder.fromRowExpression.treeString} """.stripMargin) - } } - + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala new file mode 100644 index 0000000000000..55821c4370684 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala @@ 
-0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.encoders + +import java.sql.{Date, Timestamp} + +class FlatEncoderSuite extends ExpressionEncoderSuite { + encodeDecodeTest(false, FlatEncoder[Boolean], "primitive boolean") + encodeDecodeTest(-3.toByte, FlatEncoder[Byte], "primitive byte") + encodeDecodeTest(-3.toShort, FlatEncoder[Short], "primitive short") + encodeDecodeTest(-3, FlatEncoder[Int], "primitive int") + encodeDecodeTest(-3L, FlatEncoder[Long], "primitive long") + encodeDecodeTest(-3.7f, FlatEncoder[Float], "primitive float") + encodeDecodeTest(-3.7, FlatEncoder[Double], "primitive double") + + encodeDecodeTest(new java.lang.Boolean(false), FlatEncoder[java.lang.Boolean], "boxed boolean") + encodeDecodeTest(new java.lang.Byte(-3.toByte), FlatEncoder[java.lang.Byte], "boxed byte") + encodeDecodeTest(new java.lang.Short(-3.toShort), FlatEncoder[java.lang.Short], "boxed short") + encodeDecodeTest(new java.lang.Integer(-3), FlatEncoder[java.lang.Integer], "boxed int") + encodeDecodeTest(new java.lang.Long(-3L), FlatEncoder[java.lang.Long], "boxed long") + encodeDecodeTest(new java.lang.Float(-3.7f), FlatEncoder[java.lang.Float], "boxed float") + encodeDecodeTest(new java.lang.Double(-3.7), FlatEncoder[java.lang.Double], "boxed double") + + encodeDecodeTest(BigDecimal("32131413.211321313"), FlatEncoder[BigDecimal], "scala decimal") + type JDecimal = java.math.BigDecimal + // encodeDecodeTest(new JDecimal("231341.23123"), FlatEncoder[JDecimal], "java decimal") + + encodeDecodeTest("hello", FlatEncoder[String], "string") + encodeDecodeTest(Date.valueOf("2012-12-23"), FlatEncoder[Date], "date") + encodeDecodeTest(Timestamp.valueOf("2016-01-29 10:00:00"), FlatEncoder[Timestamp], "timestamp") + encodeDecodeTest(Array[Byte](13, 21, -23), FlatEncoder[Array[Byte]], "binary") + + encodeDecodeTest(Seq(31, -123, 4), FlatEncoder[Seq[Int]], "seq of int") + encodeDecodeTest(Seq("abc", "xyz"), FlatEncoder[Seq[String]], "seq of string") + encodeDecodeTest(Seq("abc", null, "xyz"), FlatEncoder[Seq[String]], "seq of string with null") + encodeDecodeTest(Seq.empty[Int], FlatEncoder[Seq[Int]], "empty seq of int") + encodeDecodeTest(Seq.empty[String], FlatEncoder[Seq[String]], "empty seq of string") + + encodeDecodeTest(Seq(Seq(31, -123), null, Seq(4, 67)), + FlatEncoder[Seq[Seq[Int]]], "seq of seq of int") + encodeDecodeTest(Seq(Seq("abc", "xyz"), Seq[String](null), null, Seq("1", null, "2")), + FlatEncoder[Seq[Seq[String]]], "seq of seq of string") + + encodeDecodeTest(Array(31, -123, 4), FlatEncoder[Array[Int]], "array of int") + encodeDecodeTest(Array("abc", "xyz"), FlatEncoder[Array[String]], "array of string") + encodeDecodeTest(Array("a", null, "x"), FlatEncoder[Array[String]], "array of string 
with null") + encodeDecodeTest(Array.empty[Int], FlatEncoder[Array[Int]], "empty array of int") + encodeDecodeTest(Array.empty[String], FlatEncoder[Array[String]], "empty array of string") + + encodeDecodeTest(Array(Array(31, -123), null, Array(4, 67)), + FlatEncoder[Array[Array[Int]]], "array of array of int") + encodeDecodeTest(Array(Array("abc", "xyz"), Array[String](null), null, Array("1", null, "2")), + FlatEncoder[Array[Array[String]]], "array of array of string") + + encodeDecodeTest(Map(1 -> "a", 2 -> "b"), FlatEncoder[Map[Int, String]], "map") + encodeDecodeTest(Map(1 -> "a", 2 -> null), FlatEncoder[Map[Int, String]], "map with null") + encodeDecodeTest(Map(1 -> Map("a" -> 1), 2 -> Map("b" -> 2)), + FlatEncoder[Map[Int, Map[String, Int]]], "map of map") +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala new file mode 100644 index 0000000000000..fda978e7055ea --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.encoders + +import scala.collection.mutable.ArrayBuffer +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData} + +case class RepeatedStruct(s: Seq[PrimitiveData]) + +case class NestedArray(a: Array[Array[Int]]) { + override def equals(other: Any): Boolean = other match { + case NestedArray(otherArray) => + java.util.Arrays.deepEquals( + a.asInstanceOf[Array[AnyRef]], + otherArray.asInstanceOf[Array[AnyRef]]) + case _ => false + } +} + +case class BoxedData( + intField: java.lang.Integer, + longField: java.lang.Long, + doubleField: java.lang.Double, + floatField: java.lang.Float, + shortField: java.lang.Short, + byteField: java.lang.Byte, + booleanField: java.lang.Boolean) + +case class RepeatedData( + arrayField: Seq[Int], + arrayFieldContainsNull: Seq[java.lang.Integer], + mapField: scala.collection.Map[Int, Long], + mapFieldNull: scala.collection.Map[Int, java.lang.Long], + structField: PrimitiveData) + +case class SpecificCollection(l: List[Int]) + +class ProductEncoderSuite extends ExpressionEncoderSuite { + + productTest(PrimitiveData(1, 1, 1, 1, 1, 1, true)) + + productTest( + OptionalData(Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true), + Some(PrimitiveData(1, 1, 1, 1, 1, 1, true)))) + + productTest(OptionalData(None, None, None, None, None, None, None, None)) + + productTest(BoxedData(1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true)) + + productTest(BoxedData(null, null, null, null, null, null, null)) + + productTest(RepeatedStruct(PrimitiveData(1, 1, 1, 1, 1, 1, true) :: Nil)) + + productTest((1, "test", PrimitiveData(1, 1, 1, 1, 1, 1, true))) + + productTest( + RepeatedData( + Seq(1, 2), + Seq(new Integer(1), null, new Integer(2)), + Map(1 -> 2L), + Map(1 -> null), + PrimitiveData(1, 1, 1, 1, 1, 1, true))) + + productTest(NestedArray(Array(Array(1, -2, 3), null, Array(4, 5, -6)))) + + productTest(("Seq[(String, String)]", + Seq(("a", "b")))) + productTest(("Seq[(Int, Int)]", + Seq((1, 2)))) + productTest(("Seq[(Long, Long)]", + Seq((1L, 2L)))) + productTest(("Seq[(Float, Float)]", + Seq((1.toFloat, 2.toFloat)))) + productTest(("Seq[(Double, Double)]", + Seq((1.toDouble, 2.toDouble)))) + productTest(("Seq[(Short, Short)]", + Seq((1.toShort, 2.toShort)))) + productTest(("Seq[(Byte, Byte)]", + Seq((1.toByte, 2.toByte)))) + productTest(("Seq[(Boolean, Boolean)]", + Seq((true, false)))) + + productTest(("ArrayBuffer[(String, String)]", + ArrayBuffer(("a", "b")))) + productTest(("ArrayBuffer[(Int, Int)]", + ArrayBuffer((1, 2)))) + productTest(("ArrayBuffer[(Long, Long)]", + ArrayBuffer((1L, 2L)))) + productTest(("ArrayBuffer[(Float, Float)]", + ArrayBuffer((1.toFloat, 2.toFloat)))) + productTest(("ArrayBuffer[(Double, Double)]", + ArrayBuffer((1.toDouble, 2.toDouble)))) + productTest(("ArrayBuffer[(Short, Short)]", + ArrayBuffer((1.toShort, 2.toShort)))) + productTest(("ArrayBuffer[(Byte, Byte)]", + ArrayBuffer((1.toByte, 2.toByte)))) + productTest(("ArrayBuffer[(Boolean, Boolean)]", + ArrayBuffer((true, false)))) + + productTest(("Seq[Seq[(Int, Int)]]", + Seq(Seq((1, 2))))) + + private def productTest[T <: Product : TypeTag](input: T): Unit = { + encodeDecodeTest(input, ProductEncoder[T], input.getClass.getSimpleName) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 9c16940707de9..ebcf4c8bfe7e6 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -22,7 +22,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.function._ -import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor} +import org.apache.spark.sql.catalyst.encoders.{FlatEncoder, ExpressionEncoder, encoderFor} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution @@ -56,9 +56,6 @@ class GroupedDataset[K, T] private[sql]( private val resolvedKEncoder = unresolvedKEncoder.resolve(groupingAttributes) private val resolvedTEncoder = unresolvedTEncoder.resolve(dataAttributes) - /** Encoders for built in aggregations. */ - private implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder[Long](flat = true) - private def logicalPlan = queryExecution.analyzed private def sqlContext = queryExecution.sqlContext @@ -211,7 +208,7 @@ class GroupedDataset[K, T] private[sql]( * Returns a [[Dataset]] that contains a tuple with each key and the number of items present * for that key. */ - def count(): Dataset[(K, Long)] = agg(functions.count("*").as[Long]) + def count(): Dataset[(K, Long)] = agg(functions.count("*").as(FlatEncoder[Long])) /** * Applies the given function to each cogrouped data. For each unique group, the function will diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index 6da46a5f7ef9a..8471eea1b7d9c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -37,17 +37,21 @@ import org.apache.spark.unsafe.types.UTF8String abstract class SQLImplicits { protected def _sqlContext: SQLContext - implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder[T]() + implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ProductEncoder[T] - implicit def newIntEncoder: Encoder[Int] = ExpressionEncoder[Int](flat = true) - implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder[Long](flat = true) - implicit def newDoubleEncoder: Encoder[Double] = ExpressionEncoder[Double](flat = true) - implicit def newFloatEncoder: Encoder[Float] = ExpressionEncoder[Float](flat = true) - implicit def newByteEncoder: Encoder[Byte] = ExpressionEncoder[Byte](flat = true) - implicit def newShortEncoder: Encoder[Short] = ExpressionEncoder[Short](flat = true) - implicit def newBooleanEncoder: Encoder[Boolean] = ExpressionEncoder[Boolean](flat = true) - implicit def newStringEncoder: Encoder[String] = ExpressionEncoder[String](flat = true) + implicit def newIntEncoder: Encoder[Int] = FlatEncoder[Int] + implicit def newLongEncoder: Encoder[Long] = FlatEncoder[Long] + implicit def newDoubleEncoder: Encoder[Double] = FlatEncoder[Double] + implicit def newFloatEncoder: Encoder[Float] = FlatEncoder[Float] + implicit def newByteEncoder: Encoder[Byte] = FlatEncoder[Byte] + implicit def newShortEncoder: Encoder[Short] = FlatEncoder[Short] + implicit def newBooleanEncoder: Encoder[Boolean] = FlatEncoder[Boolean] + implicit def newStringEncoder: Encoder[String] = FlatEncoder[String] + /** + * Creates a [[Dataset]] from an RDD. 
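Editorial aside, not part of the diff: with the implicits above now routed through FlatEncoder and ProductEncoder, user-facing behaviour stays the same. A small usage sketch, assuming the usual sc and sqlContext handles are in scope and Person is a hypothetical top-level case class:

  // Illustration only.
  import sqlContext.implicits._

  case class Person(name: String, age: Int)

  val ints = sc.parallelize(Seq(1, 2, 3)).toDS()           // resolved via newIntEncoder (FlatEncoder[Int])
  val people = sc.parallelize(Seq(Person("a", 1))).toDS()  // resolved via newProductEncoder (ProductEncoder[Person])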
+ * @since 1.6.0 + */ implicit def rddToDatasetHolder[T : Encoder](rdd: RDD[T]): DatasetHolder[T] = { DatasetHolder(_sqlContext.createDataset(rdd)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 53cc6e0cda110..95158de710acf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -26,7 +26,7 @@ import scala.util.Try import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection} import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star} -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.encoders.FlatEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint @@ -267,7 +267,7 @@ object functions extends LegacyFunctions { * @since 1.3.0 */ def count(columnName: String): TypedColumn[Any, Long] = - count(Column(columnName)).as(ExpressionEncoder[Long](flat = true)) + count(Column(columnName)).as(FlatEncoder[Long]) /** * Aggregate function: returns the number of distinct items in a group. From 2d2411faa2dd1b7312c4277b2dd9e5678195cfbb Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Fri, 13 Nov 2015 13:09:28 -0800 Subject: [PATCH 011/149] [SPARK-11672][ML] Set active SQLContext in MLlibTestSparkContext.beforeAll Still saw some error messages caused by `SQLContext.getOrCreate`: https://amplab.cs.berkeley.edu/jenkins/job/Spark-Master-SBT/3997/AMPLAB_JENKINS_BUILD_PROFILE=hadoop2.3,label=spark-test/testReport/junit/org.apache.spark.ml.util/JavaDefaultReadWriteSuite/testDefaultReadWrite/ This PR sets the active SQLContext in beforeAll, which is not automatically set in `new SQLContext`. This makes `SQLContext.getOrCreate` return the right SQLContext. cc: yhuai Author: Xiangrui Meng Closes #9694 from mengxr/SPARK-11672.3. --- .../main/scala/org/apache/spark/ml/util/ReadWrite.scala | 7 +++++-- .../apache/spark/mllib/util/MLlibTestSparkContext.scala | 1 + 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 85f888c9f2f67..ca896ed6106c4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -48,8 +48,11 @@ private[util] sealed trait BaseReadWrite { /** * Returns the user-specified SQL context or the default. 
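Editorial aside on the fix below, not part of the diff: per the commit message, SQLContext.getOrCreate prefers the currently active SQLContext, and `new SQLContext(sc)` does not register itself as active, which is why getOrCreate could resolve to the wrong context in tests. A sketch of the contract the fix relies on, using the Spark-internal setActive/clearActive helpers that appear in the MLlibTestSparkContext change:

  // Sketch only, based on the commit message's description of getOrCreate.
  SQLContext.clearActive()
  val ctx = new SQLContext(sc)   // constructing a context does not make it the active one
  SQLContext.setActive(ctx)      // after this, getOrCreate resolves to ctx
  assert(SQLContext.getOrCreate(sc) eq ctx)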
*/ - protected final def sqlContext: SQLContext = optionSQLContext.getOrElse { - SQLContext.getOrCreate(SparkContext.getOrCreate()) + protected final def sqlContext: SQLContext = { + if (optionSQLContext.isEmpty) { + optionSQLContext = Some(SQLContext.getOrCreate(SparkContext.getOrCreate())) + } + optionSQLContext.get } /** Returns the [[SparkContext]] underlying [[sqlContext]] */ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala index 998ee48186558..378139593b26f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala @@ -34,6 +34,7 @@ trait MLlibTestSparkContext extends BeforeAndAfterAll { self: Suite => sc = new SparkContext(conf) SQLContext.clearActive() sqlContext = new SQLContext(sc) + SQLContext.setActive(sqlContext) } override def afterAll() { From 912b94363bb113ab14024a45f17a4d2b82a09e66 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Fri, 13 Nov 2015 13:14:25 -0800 Subject: [PATCH 012/149] [SPARK-11336] Add links to example codes https://issues.apache.org/jira/browse/SPARK-11336 mengxr I add a hyperlink of Spark on Github and a hint of their existences in Spark code repo in each code example. I remove the config key for changing the example code dir, since we assume all examples should be in spark/examples. The hyperlink, though we cannot use it now, since the Spark v1.6.0 has not been released yet, can be used after the release. So it is not a problem. I add some screen shots, so you can get an instant feeling. screen shot 2015-10-27 at 10 47 18 pm screen shot 2015-10-27 at 10 47 31 pm Author: Xusen Yin Closes #9320 from yinxusen/SPARK-11336. --- docs/_plugins/include_example.rb | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/docs/_plugins/include_example.rb b/docs/_plugins/include_example.rb index 6ee63a5ac69df..549f81fe1b1bc 100644 --- a/docs/_plugins/include_example.rb +++ b/docs/_plugins/include_example.rb @@ -28,7 +28,7 @@ def initialize(tag_name, markup, tokens) def render(context) site = context.registers[:site] - config_dir = (site.config['code_dir'] || '../examples/src/main').sub(/^\//,'') + config_dir = '../examples/src/main' @code_dir = File.join(site.source, config_dir) clean_markup = @markup.strip @@ -38,7 +38,12 @@ def render(context) code = File.open(@file).read.encode("UTF-8") code = select_lines(code) - Pygments.highlight(code, :lexer => @lang) + rendered_code = Pygments.highlight(code, :lexer => @lang) + + hint = "
Find full example code at " \ + "\"examples/src/main/#{clean_markup}\" in the Spark repo.
" + + rendered_code + hint end # Trim the code block so as to have the same indention, regardless of their positions in the From bdfbc1dcaf121a1a1239857adcf54cdfe82c26dc Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Fri, 13 Nov 2015 13:19:04 -0800 Subject: [PATCH 013/149] [MINOR][ML] remove MLlibTestsSparkContext from ImpuritySuite ImpuritySuite doesn't need SparkContext. Author: Xiangrui Meng Closes #9698 from mengxr/remove-mllib-test-context-in-impurity-suite. --- .../test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala index 49aff21fe7914..14152cdd63bc7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala @@ -19,12 +19,11 @@ package org.apache.spark.mllib.tree import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.tree.impurity.{EntropyAggregator, GiniAggregator} -import org.apache.spark.mllib.util.MLlibTestSparkContext /** * Test suites for [[GiniAggregator]] and [[EntropyAggregator]]. */ -class ImpuritySuite extends SparkFunSuite with MLlibTestSparkContext { +class ImpuritySuite extends SparkFunSuite { test("Gini impurity does not support negative labels") { val gini = new GiniAggregator(2) intercept[IllegalArgumentException] { From c939c70ac1ab6a26d9fda0a99c4e837f7e5a7935 Mon Sep 17 00:00:00 2001 From: nitin goyal Date: Fri, 13 Nov 2015 18:09:08 -0800 Subject: [PATCH 014/149] [SPARK-7970] Skip closure cleaning for SQL operations Also introduces new spark private API in RDD.scala with name 'mapPartitionsInternal' which doesn't closure cleans the RDD elements. Author: nitin goyal Author: nitin.goyal Closes #9253 from nitin2goyal/master. --- .../main/scala/org/apache/spark/rdd/RDD.scala | 18 ++++++++++++++++++ .../columnar/InMemoryColumnarTableScan.scala | 4 ++-- .../apache/spark/sql/execution/Exchange.scala | 6 +++--- .../apache/spark/sql/execution/Generate.scala | 4 ++-- .../aggregate/SortBasedAggregate.scala | 2 +- .../spark/sql/execution/basicOperators.scala | 16 ++++++++-------- .../joins/BroadcastLeftSemiJoinHash.scala | 4 ++-- .../sql/execution/joins/CartesianProduct.scala | 2 +- .../org/apache/spark/sql/execution/sort.scala | 2 +- 9 files changed, 38 insertions(+), 20 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 800ef53cbef07..2aeb5eeaad32c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -705,6 +705,24 @@ abstract class RDD[T: ClassTag]( preservesPartitioning) } + /** + * [performance] Spark's internal mapPartitions method which skips closure cleaning. It is a + * performance API to be used carefully only if we are sure that the RDD elements are + * serializable and don't require closure cleaning. + * + * @param preservesPartitioning indicates whether the input function preserves the partitioner, + * which should be `false` unless this is a pair RDD and the input function doesn't modify + * the keys. 
+ */ + private[spark] def mapPartitionsInternal[U: ClassTag]( + f: Iterator[T] => Iterator[U], + preservesPartitioning: Boolean = false): RDD[U] = withScope { + new MapPartitionsRDD( + this, + (context: TaskContext, index: Int, iter: Iterator[T]) => f(iter), + preservesPartitioning) + } + /** * Return a new RDD by applying a function to each partition of this RDD, while tracking the index * of the original partition. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 7eb1ad7cd8198..2cface61e59c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -125,7 +125,7 @@ private[sql] case class InMemoryRelation( private def buildBuffers(): Unit = { val output = child.output - val cached = child.execute().mapPartitions { rowIterator => + val cached = child.execute().mapPartitionsInternal { rowIterator => new Iterator[CachedBatch] { def next(): CachedBatch = { val columnBuilders = output.map { attribute => @@ -292,7 +292,7 @@ private[sql] case class InMemoryColumnarTableScan( val relOutput = relation.output val buffers = relation.cachedColumnBuffers - buffers.mapPartitions { cachedBatchIterator => + buffers.mapPartitionsInternal { cachedBatchIterator => val partitionFilter = newPredicate( partitionFilters.reduceOption(And).getOrElse(Literal(true)), schema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index bc252d98e7144..a161cf0a3185b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -168,7 +168,7 @@ case class Exchange( case RangePartitioning(sortingExpressions, numPartitions) => // Internally, RangePartitioner runs a job on the RDD that samples keys to compute // partition bounds. To get accurate samples, we need to copy the mutable keys. 
- val rddForSampling = rdd.mapPartitions { iter => + val rddForSampling = rdd.mapPartitionsInternal { iter => val mutablePair = new MutablePair[InternalRow, Null]() iter.map(row => mutablePair.update(row.copy(), null)) } @@ -200,12 +200,12 @@ case class Exchange( } val rddWithPartitionIds: RDD[Product2[Int, InternalRow]] = { if (needToCopyObjectsBeforeShuffle(part, serializer)) { - rdd.mapPartitions { iter => + rdd.mapPartitionsInternal { iter => val getPartitionKey = getPartitionKeyExtractor() iter.map { row => (part.getPartition(getPartitionKey(row)), row.copy()) } } } else { - rdd.mapPartitions { iter => + rdd.mapPartitionsInternal { iter => val getPartitionKey = getPartitionKeyExtractor() val mutablePair = new MutablePair[Int, InternalRow]() iter.map { row => mutablePair.update(part.getPartition(getPartitionKey(row)), row) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala index 78e33d9f233a6..54b8cb58285c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala @@ -59,7 +59,7 @@ case class Generate( protected override def doExecute(): RDD[InternalRow] = { // boundGenerator.terminate() should be triggered after all of the rows in the partition if (join) { - child.execute().mapPartitions { iter => + child.execute().mapPartitionsInternal { iter => val generatorNullRow = InternalRow.fromSeq(Seq.fill[Any](generator.elementTypes.size)(null)) val joinedRow = new JoinedRow @@ -79,7 +79,7 @@ case class Generate( } } } else { - child.execute().mapPartitions { iter => + child.execute().mapPartitionsInternal { iter => iter.flatMap(row => boundGenerator.eval(row)) ++ LazyIterator(() => boundGenerator.terminate()) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala index c8ccbb933df61..ee982453c3287 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala @@ -69,7 +69,7 @@ case class SortBasedAggregate( protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { val numInputRows = longMetric("numInputRows") val numOutputRows = longMetric("numOutputRows") - child.execute().mapPartitions { iter => + child.execute().mapPartitionsInternal { iter => // Because the constructor of an aggregation iterator will read at least the first row, // we need to get the value of iter.hasNext first. 
val hasInput = iter.hasNext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index ed82c9a6a3770..07925c62cd386 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -43,7 +43,7 @@ case class TungstenProject(projectList: Seq[NamedExpression], child: SparkPlan) protected override def doExecute(): RDD[InternalRow] = { val numRows = longMetric("numRows") - child.execute().mapPartitions { iter => + child.execute().mapPartitionsInternal { iter => val project = UnsafeProjection.create(projectList, child.output, subexpressionEliminationEnabled) iter.map { row => @@ -67,7 +67,7 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { protected override def doExecute(): RDD[InternalRow] = { val numInputRows = longMetric("numInputRows") val numOutputRows = longMetric("numOutputRows") - child.execute().mapPartitions { iter => + child.execute().mapPartitionsInternal { iter => val predicate = newPredicate(condition, child.output) iter.filter { row => numInputRows += 1 @@ -161,11 +161,11 @@ case class Limit(limit: Int, child: SparkPlan) protected override def doExecute(): RDD[InternalRow] = { val rdd: RDD[_ <: Product2[Boolean, InternalRow]] = if (sortBasedShuffleOn) { - child.execute().mapPartitions { iter => + child.execute().mapPartitionsInternal { iter => iter.take(limit).map(row => (false, row.copy())) } } else { - child.execute().mapPartitions { iter => + child.execute().mapPartitionsInternal { iter => val mutablePair = new MutablePair[Boolean, InternalRow]() iter.take(limit).map(row => mutablePair.update(false, row)) } @@ -173,7 +173,7 @@ case class Limit(limit: Int, child: SparkPlan) val part = new HashPartitioner(1) val shuffled = new ShuffledRDD[Boolean, InternalRow, InternalRow](rdd, part) shuffled.setSerializer(new SparkSqlSerializer(child.sqlContext.sparkContext.getConf)) - shuffled.mapPartitions(_.take(limit).map(_._2)) + shuffled.mapPartitionsInternal(_.take(limit).map(_._2)) } } @@ -294,7 +294,7 @@ case class MapPartitions[T, U]( child: SparkPlan) extends UnaryNode { override protected def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitions { iter => + child.execute().mapPartitionsInternal { iter => val tBoundEncoder = tEncoder.bind(child.output) func(iter.map(tBoundEncoder.fromRow)).map(uEncoder.toRow) } @@ -318,7 +318,7 @@ case class AppendColumns[T, U]( override def output: Seq[Attribute] = child.output ++ newColumns override protected def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitions { iter => + child.execute().mapPartitionsInternal { iter => val tBoundEncoder = tEncoder.bind(child.output) val combiner = GenerateUnsafeRowJoiner.create(tEncoder.schema, uEncoder.schema) iter.map { row => @@ -350,7 +350,7 @@ case class MapGroups[K, T, U]( Seq(groupingAttributes.map(SortOrder(_, Ascending))) override protected def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitions { iter => + child.execute().mapPartitionsInternal { iter => val grouped = GroupedIterator(iter, groupingAttributes, child.output) val groupKeyEncoder = kEncoder.bind(groupingAttributes) val groupDataEncoder = tEncoder.bind(child.output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala index c5cd6a2fd6372..004407b2e6925 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala @@ -54,7 +54,7 @@ case class BroadcastLeftSemiJoinHash( val hashSet = buildKeyHashSet(input.toIterator, SQLMetrics.nullLongMetric) val broadcastedRelation = sparkContext.broadcast(hashSet) - left.execute().mapPartitions { streamIter => + left.execute().mapPartitionsInternal { streamIter => hashSemiJoin(streamIter, numLeftRows, broadcastedRelation.value, numOutputRows) } } else { @@ -62,7 +62,7 @@ case class BroadcastLeftSemiJoinHash( HashedRelation(input.toIterator, SQLMetrics.nullLongMetric, rightKeyGenerator, input.size) val broadcastedRelation = sparkContext.broadcast(hashRelation) - left.execute().mapPartitions { streamIter => + left.execute().mapPartitionsInternal { streamIter => val hashedRelation = broadcastedRelation.value hashedRelation match { case unsafe: UnsafeHashedRelation => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala index 0243e196dbc37..f467519b802a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala @@ -46,7 +46,7 @@ case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNod row.copy() } - leftResults.cartesian(rightResults).mapPartitions { iter => + leftResults.cartesian(rightResults).mapPartitionsInternal { iter => val joinedRow = new JoinedRow iter.map { r => numOutputRows += 1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala index 47fe70ab154ec..52ef00ef5b283 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala @@ -47,7 +47,7 @@ case class Sort( if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { - child.execute().mapPartitions( { iterator => + child.execute().mapPartitionsInternal( { iterator => val ordering = newOrdering(sortOrder, child.output) val sorter = new ExternalSorter[InternalRow, Null, InternalRow]( TaskContext.get(), ordering = Some(ordering)) From 139c15b624c88b376ffdd05d78795295c8c4fc17 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 14 Nov 2015 18:36:01 +0800 Subject: [PATCH 015/149] [SPARK-11694][SQL] Parquet logical types are not being tested properly All the physical types are properly tested at `ParquetIOSuite` but logical type mapping is not being tested. Author: hyukjinkwon Author: Hyukjin Kwon Closes #9660 from HyukjinKwon/SPARK-11694. 
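The test added below boils down to checking the mapping from annotated Parquet primitives to Catalyst types. A minimal sketch of that correspondence, taken from the test itself (the write-out/read-back plumbing through `writeMetadata` and `sqlContext.read.parquet` is elided here):

```scala
import org.apache.parquet.schema.MessageTypeParser
import org.apache.spark.sql.types._

// An annotated Parquet schema: each primitive column carries a logical type annotation.
val parquetSchema = MessageTypeParser.parseMessageType(
  """message root {
    |  required int32 a(INT_8);
    |  required int32 b(INT_16);
    |  required int32 c(DATE);
    |  required binary d(UTF8);
    |}
  """.stripMargin)

// The Catalyst types Spark is expected to infer for columns a..d once the
// metadata is written out and read back through sqlContext.read.parquet.
val expectedSparkTypes = Seq(ByteType, ShortType, DateType, StringType)
```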
--- .../datasources/parquet/ParquetIOSuite.scala | 39 ++++++++++++++----- .../datasources/parquet/ParquetTest.scala | 17 ++++++++ 2 files changed, 47 insertions(+), 9 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 82a42d68fedc1..78df363ade5c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -91,6 +91,33 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } } + test("SPARK-11694 Parquet logical types are not being tested properly") { + val parquetSchema = MessageTypeParser.parseMessageType( + """message root { + | required int32 a(INT_8); + | required int32 b(INT_16); + | required int32 c(DATE); + | required int32 d(DECIMAL(1,0)); + | required int64 e(DECIMAL(10,0)); + | required binary f(UTF8); + | required binary g(ENUM); + | required binary h(DECIMAL(32,0)); + | required fixed_len_byte_array(32) i(DECIMAL(32,0)); + |} + """.stripMargin) + + val expectedSparkTypes = Seq(ByteType, ShortType, DateType, DecimalType(1, 0), + DecimalType(10, 0), StringType, StringType, DecimalType(32, 0), DecimalType(32, 0)) + + withTempPath { location => + val path = new Path(location.getCanonicalPath) + val conf = sparkContext.hadoopConfiguration + writeMetadata(parquetSchema, path, conf) + val sparkTypes = sqlContext.read.parquet(path.toString).schema.map(_.dataType) + assert(sparkTypes === expectedSparkTypes) + } + } + test("string") { val data = (1 to 4).map(i => Tuple1(i.toString)) // Property spark.sql.parquet.binaryAsString shouldn't affect Parquet files written by Spark SQL @@ -374,16 +401,10 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { """.stripMargin) withTempPath { location => - val extraMetadata = Collections.singletonMap( - CatalystReadSupport.SPARK_METADATA_KEY, sparkSchema.toString) - val fileMetadata = new FileMetaData(parquetSchema, extraMetadata, "Spark") + val extraMetadata = Map(CatalystReadSupport.SPARK_METADATA_KEY -> sparkSchema.toString) val path = new Path(location.getCanonicalPath) - - ParquetFileWriter.writeMetadataFile( - sparkContext.hadoopConfiguration, - path, - Collections.singletonList( - new Footer(path, new ParquetMetadata(fileMetadata, Collections.emptyList())))) + val conf = sparkContext.hadoopConfiguration + writeMetadata(parquetSchema, path, conf, extraMetadata) assertResult(sqlContext.read.parquet(path.toString).schema) { StructType( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala index 8ffb01fc5b584..fdd7697c91f5e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.datasources.parquet import java.io.File +import org.apache.parquet.schema.MessageType + import scala.collection.JavaConverters._ import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag @@ -117,6 +119,21 @@ private[sql] trait ParquetTest extends SQLTestUtils { ParquetFileWriter.writeMetadataFile(configuration, path, 
Seq(footer).asJava) } + /** + * This is an overloaded version of `writeMetadata` above to allow writing customized + * Parquet schema. + */ + protected def writeMetadata( + parquetSchema: MessageType, path: Path, configuration: Configuration, + extraMetadata: Map[String, String] = Map.empty[String, String]): Unit = { + val extraMetadataAsJava = extraMetadata.asJava + val createdBy = s"Apache Spark ${org.apache.spark.SPARK_VERSION}" + val fileMetadata = new FileMetaData(parquetSchema, extraMetadataAsJava, createdBy) + val parquetMetadata = new ParquetMetadata(fileMetadata, Seq.empty[BlockMetaData].asJava) + val footer = new Footer(path, parquetMetadata) + ParquetFileWriter.writeMetadataFile(configuration, path, Seq(footer).asJava) + } + protected def readAllFootersWithoutSummaryFiles( path: Path, configuration: Configuration): Seq[Footer] = { val fs = path.getFileSystem(configuration) From 9a73b33a9a440d7312b92df9f6a9b9e17917b582 Mon Sep 17 00:00:00 2001 From: Kai Jiang Date: Sat, 14 Nov 2015 11:59:37 +0000 Subject: [PATCH 016/149] [MINOR][DOCS] typo in docs/configuration.md `<\code>` end tag missing backslash in docs/configuration.md{L308-L339} ref #8795 Author: Kai Jiang Closes #9715 from vectorijk/minor-typo-docs. --- docs/configuration.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index c276e8e90decf..d961f43acf4ab 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -305,7 +305,7 @@ Apart from these, the following properties are also available, and may be useful daily Set the time interval by which the executor logs will be rolled over. - Rolling is disabled by default. Valid values are daily, hourly, minutely or + Rolling is disabled by default. Valid values are daily, hourly, minutely or any interval in seconds. See spark.executor.logs.rolling.maxRetainedFiles for automatic cleaning of old logs. @@ -330,13 +330,13 @@ Apart from these, the following properties are also available, and may be useful spark.python.profile false - Enable profiling in Python worker, the profile result will show up by sc.show_profiles(), + Enable profiling in Python worker, the profile result will show up by sc.show_profiles(), or it will be displayed before the driver exiting. It also can be dumped into disk by - sc.dump_profiles(path). If some of the profile results had been displayed manually, + sc.dump_profiles(path). If some of the profile results had been displayed manually, they will not be displayed automatically before driver exiting. - By default the pyspark.profiler.BasicProfiler will be used, but this can be overridden by - passing a profiler class in as a parameter to the SparkContext constructor. + By default the pyspark.profiler.BasicProfiler will be used, but this can be overridden by + passing a profiler class in as a parameter to the SparkContext constructor. From 9461f5ee80e51fd709d6f573c333936cb3c2acc5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=A1bor=20Lipt=C3=A1k?= Date: Sat, 14 Nov 2015 12:02:02 +0000 Subject: [PATCH 017/149] =?UTF-8?q?[SPARK-11573]=20Correct=20'reflective?= =?UTF-8?q?=20access=20of=20structural=20type=20member=20meth=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …od should be enabled' Scala warnings Author: Gábor Lipták Closes #9550 from gliptak/SPARK-11573. 
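The warning being silenced is Scala's feature warning for structural types: selecting a member through a structural type is implemented with runtime reflection, and under -feature the compiler asks for an explicit opt-in. A small self-contained illustration (the names below are invented for the example, not taken from the suite):

```scala
// Without this import, the call through the structural type below triggers
// "reflective access of structural type member method ... should be enabled".
import scala.language.reflectiveCalls

object ReflectiveCallsExample {
  // A structural type: any value exposing stop() qualifies, no shared trait required.
  def shutdown(resource: { def stop(): Unit }): Unit = {
    resource.stop() // dispatched via reflection at runtime, hence the feature warning
  }

  object FakeGenerator {
    def stop(): Unit = println("stopped")
  }

  def main(args: Array[String]): Unit = {
    shutdown(FakeGenerator)
  }
}
```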
--- .../apache/spark/streaming/receiver/BlockGeneratorSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala index 2f11b255f1104..92ad9fe52b777 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.streaming.receiver import scala.collection.mutable +import scala.language.reflectiveCalls import org.scalatest.BeforeAndAfter import org.scalatest.Matchers._ From 22e96b87fb0a0eb4f2f1a8fc29a742ceabff952a Mon Sep 17 00:00:00 2001 From: Rohan Bhanderi Date: Sat, 14 Nov 2015 13:38:53 +0000 Subject: [PATCH 018/149] Typo in comment: use 2 seconds instead of 1 Use 2 seconds batch size as duration specified in JavaStreamingContext constructor is 2000 ms Author: Rohan Bhanderi Closes #9714 from RohanBhanderi/patch-2. --- .../org/apache/spark/examples/streaming/JavaKafkaWordCount.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java index 16ae9a3319ee2..337f8ffb5bfb0 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java @@ -66,7 +66,7 @@ public static void main(String[] args) { StreamingExamples.setStreamingLogLevels(); SparkConf sparkConf = new SparkConf().setAppName("JavaKafkaWordCount"); - // Create the context with a 1 second batch size + // Create the context with 2 seconds batch size JavaStreamingContext jssc = new JavaStreamingContext(sparkConf, new Duration(2000)); int numThreads = Integer.parseInt(args[3]); From d83c2f9f0b08d6d5d369d9fae04cdb15448e7f0d Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Sat, 14 Nov 2015 21:04:18 -0800 Subject: [PATCH 019/149] [SPARK-11736][SQL] Add monotonically_increasing_id to function registry. https://issues.apache.org/jira/browse/SPARK-11736 Author: Yin Huai Closes #9703 from yhuai/MonotonicallyIncreasingID. 
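Registering the expression makes it callable by name from SQL and `expr()`, in addition to the existing Scala DSL function. A spark-shell style sketch (assumes a `sqlContext` in scope; column aliases are illustrative):

```scala
import org.apache.spark.sql.functions._

val df = sqlContext.range(0, 4).toDF("value")

df.select(
  monotonicallyIncreasingId().as("dsl_id"),           // Scala DSL form, available before this patch
  expr("monotonically_increasing_id()").as("sql_id")  // by-name form, enabled by the registry entry
).show()
```

Both columns should produce the same partition-encoded ids, which is what the added ColumnExpressionSuite case checks.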
--- .../apache/spark/sql/catalyst/analysis/FunctionRegistry.scala | 3 ++- .../scala/org/apache/spark/sql/ColumnExpressionSuite.scala | 4 ++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 870808aa560e5..a8f4d257acd0a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -281,7 +281,8 @@ object FunctionRegistry { expression[Sha1]("sha1"), expression[Sha2]("sha2"), expression[SparkPartitionID]("spark_partition_id"), - expression[InputFileName]("input_file_name") + expression[InputFileName]("input_file_name"), + expression[MonotonicallyIncreasingID]("monotonically_increasing_id") ) val builtin: SimpleFunctionRegistry = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 010df2a341589..8674da7a79c81 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -563,6 +563,10 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { df.select(monotonicallyIncreasingId()), Row(0L) :: Row(1L) :: Row((1L << 33) + 0L) :: Row((1L << 33) + 1L) :: Nil ) + checkAnswer( + df.select(expr("monotonically_increasing_id()")), + Row(0L) :: Row(1L) :: Row((1L << 33) + 0L) :: Row((1L << 33) + 1L) :: Nil + ) } test("sparkPartitionId") { From d22fc10887fdc6a86f6122648a823d0d37d4d795 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 15 Nov 2015 10:33:53 -0800 Subject: [PATCH 020/149] [SPARK-11734][SQL] Rename TungstenProject -> Project, TungstenSort -> Sort I didn't remove the old Sort operator, since we still use it in randomized tests. I moved it into test module and renamed it ReferenceSort. Author: Reynold Xin Closes #9700 from rxin/SPARK-11734. 
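A practical consequence of the rename is that physical-plan assertions now match on a single `execution.Sort` operator. A hedged sketch of that kind of check (spark-shell style; assumes a `sqlContext` in scope):

```scala
import org.apache.spark.sql.execution

val df = sqlContext.range(0, 100).sort("id")

// After this patch the executed plan contains the unified Sort operator;
// the same assertion previously had to look for TungstenSort as well.
val sorts = df.queryExecution.executedPlan.collect { case s: execution.Sort => s }
assert(sorts.nonEmpty)
```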
--- .../apache/spark/sql/execution/Exchange.scala | 7 +- .../sql/execution/{sort.scala => Sort.scala} | 55 +----------- .../spark/sql/execution/SparkPlanner.scala | 2 +- .../spark/sql/execution/SparkStrategies.scala | 16 +--- .../spark/sql/execution/basicOperators.scala | 2 +- .../datasources/DataSourceStrategy.scala | 2 +- .../spark/sql/ColumnExpressionSuite.scala | 4 +- .../spark/sql/execution/PlannerSuite.scala | 6 +- .../spark/sql/execution/ReferenceSort.scala | 61 +++++++++++++ .../execution/RowFormatConvertersSuite.scala | 4 +- .../spark/sql/execution/SortSuite.scala | 69 +++++++++++++-- .../sql/execution/TungstenSortSuite.scala | 86 ------------------- .../execution/metric/SQLMetricsSuite.scala | 12 +-- .../execution/HiveTypeCoercionSuite.scala | 4 +- .../ParquetHadoopFsRelationSuite.scala | 2 +- 15 files changed, 148 insertions(+), 184 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/execution/{sort.scala => Sort.scala} (65%) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/ReferenceSort.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index a161cf0a3185b..62cbc518e02af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -51,7 +51,7 @@ case class Exchange( } val simpleNodeName = if (tungstenMode) "TungstenExchange" else "Exchange" - s"${simpleNodeName}${extraInfo}" + s"$simpleNodeName$extraInfo" } /** @@ -475,10 +475,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ if (requiredOrdering.nonEmpty) { // If child.outputOrdering is [a, b] and requiredOrdering is [a], we do not need to sort. if (requiredOrdering != child.outputOrdering.take(requiredOrdering.length)) { - sqlContext.planner.BasicOperators.getSortOperator( - requiredOrdering, - global = false, - child) + Sort(requiredOrdering, global = false, child = child) } else { child } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala similarity index 65% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala index 52ef00ef5b283..24207cb46fd29 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala @@ -17,68 +17,22 @@ package org.apache.spark.sql.execution +import org.apache.spark.{InternalAccumulator, SparkEnv, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{Distribution, OrderedDistribution, UnspecifiedDistribution} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.types.StructType -import org.apache.spark.util.CompletionIterator -import org.apache.spark.util.collection.ExternalSorter -import org.apache.spark.{InternalAccumulator, SparkEnv, TaskContext} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// This file defines various sort operators. 
-//////////////////////////////////////////////////////////////////////////////////////////////////// /** - * Performs a sort, spilling to disk as needed. - * @param global when true performs a global sort of all partitions by shuffling the data first - * if necessary. - */ -case class Sort( - sortOrder: Seq[SortOrder], - global: Boolean, - child: SparkPlan) - extends UnaryNode { - - override def requiredChildDistribution: Seq[Distribution] = - if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil - - protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { - child.execute().mapPartitionsInternal( { iterator => - val ordering = newOrdering(sortOrder, child.output) - val sorter = new ExternalSorter[InternalRow, Null, InternalRow]( - TaskContext.get(), ordering = Some(ordering)) - sorter.insertAll(iterator.map(r => (r.copy(), null))) - val baseIterator = sorter.iterator.map(_._1) - val context = TaskContext.get() - context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) - context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) - context.internalMetricsToAccumulators( - InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes) - // TODO(marmbrus): The complex type signature below thwarts inference for no reason. - CompletionIterator[InternalRow, Iterator[InternalRow]](baseIterator, sorter.stop()) - }, preservesPartitioning = true) - } - - override def output: Seq[Attribute] = child.output - - override def outputOrdering: Seq[SortOrder] = sortOrder -} - -/** - * Optimized version of [[Sort]] that operates on binary data (implemented as part of - * Project Tungsten). + * Performs (external) sorting. * * @param global when true performs a global sort of all partitions by shuffling the data first * if necessary. * @param testSpillFrequency Method for configuring periodic spilling in unit tests. If set, will * spill every `frequency` records. 
*/ - -case class TungstenSort( +case class Sort( sortOrder: Seq[SortOrder], global: Boolean, child: SparkPlan, @@ -107,7 +61,7 @@ case class TungstenSort( val dataSize = longMetric("dataSize") val spillSize = longMetric("spillSize") - child.execute().mapPartitions { iter => + child.execute().mapPartitionsInternal { iter => val ordering = newOrdering(sortOrder, childOutput) // The comparator for comparing prefix @@ -143,5 +97,4 @@ case class TungstenSort( sortedIterator } } - } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index b7c5476346b2a..6e9a4df828246 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -80,7 +80,7 @@ class SparkPlanner(val sqlContext: SQLContext) extends SparkStrategies { filterCondition.map(Filter(_, scan)).getOrElse(scan) } else { val scan = scanBuilder((projectSet ++ filterSet).toSeq) - TungstenProject(projectList, filterCondition.map(Filter(_, scan)).getOrElse(scan)) + Project(projectList, filterCondition.map(Filter(_, scan)).getOrElse(scan)) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 67201a2c191cd..3d4ce633c07c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -302,16 +302,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object BasicOperators extends Strategy { def numPartitions: Int = self.numPartitions - /** - * Picks an appropriate sort operator. - * - * @param global when true performs a global sort of all partitions by shuffling the data first - * if necessary. - */ - def getSortOperator(sortExprs: Seq[SortOrder], global: Boolean, child: SparkPlan): SparkPlan = { - execution.TungstenSort(sortExprs, global, child) - } - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case r: RunnableCommand => ExecutedCommand(r) :: Nil @@ -339,11 +329,11 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.SortPartitions(sortExprs, child) => // This sort only sorts tuples within a partition. Its requiredDistribution will be // an UnspecifiedDistribution. 
- getSortOperator(sortExprs, global = false, planLater(child)) :: Nil + execution.Sort(sortExprs, global = false, child = planLater(child)) :: Nil case logical.Sort(sortExprs, global, child) => - getSortOperator(sortExprs, global, planLater(child)):: Nil + execution.Sort(sortExprs, global, planLater(child)) :: Nil case logical.Project(projectList, child) => - execution.TungstenProject(projectList, planLater(child)) :: Nil + execution.Project(projectList, planLater(child)) :: Nil case logical.Filter(condition, child) => execution.Filter(condition, planLater(child)) :: Nil case e @ logical.Expand(_, _, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 07925c62cd386..e79092efdaa3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -30,7 +30,7 @@ import org.apache.spark.util.random.PoissonSampler import org.apache.spark.{HashPartitioner, SparkEnv} -case class TungstenProject(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { +case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { override private[sql] lazy val metrics = Map( "numRows" -> SQLMetrics.createLongMetric(sparkContext, "number of rows")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 824c89a90eb8a..9bbbfa7c77cba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -343,7 +343,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { requestedColumns, scanBuilder(requestedColumns, candidatePredicates, pushedFilters), relation.relation) - execution.TungstenProject( + execution.Project( projects, filterCondition.map(execution.Filter(_, scan)).getOrElse(scan)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 8674da7a79c81..3eae3f6d85066 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.expressions.NamedExpression import org.scalatest.Matchers._ -import org.apache.spark.sql.execution.TungstenProject +import org.apache.spark.sql.execution.Project import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -619,7 +619,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { def checkNumProjects(df: DataFrame, expectedNumProjects: Int): Unit = { val projects = df.queryExecution.executedPlan.collect { - case tungstenProject: TungstenProject => tungstenProject + case tungstenProject: Project => tungstenProject } assert(projects.size === expectedNumProjects) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 8c41d79dae817..be53ec3e271c6 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -365,7 +365,7 @@ class PlannerSuite extends SharedSQLContext { ) val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case s: TungstenSort => true; case s: Sort => true }.isEmpty) { + if (outputPlan.collect { case s: Sort => true }.isEmpty) { fail(s"Sort should have been added:\n$outputPlan") } } @@ -381,7 +381,7 @@ class PlannerSuite extends SharedSQLContext { ) val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case s: TungstenSort => true; case s: Sort => true }.nonEmpty) { + if (outputPlan.collect { case s: Sort => true }.nonEmpty) { fail(s"No sorts should have been added:\n$outputPlan") } } @@ -398,7 +398,7 @@ class PlannerSuite extends SharedSQLContext { ) val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case s: TungstenSort => true; case s: Sort => true }.isEmpty) { + if (outputPlan.collect { case s: Sort => true }.isEmpty) { fail(s"Sort should have been added:\n$outputPlan") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReferenceSort.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReferenceSort.scala new file mode 100644 index 0000000000000..9575d26fd123f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReferenceSort.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.{InternalAccumulator, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder} +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.util.CompletionIterator +import org.apache.spark.util.collection.ExternalSorter + + +/** + * A reference sort implementation used to compare against our normal sort. 
+ */ +case class ReferenceSort( + sortOrder: Seq[SortOrder], + global: Boolean, + child: SparkPlan) + extends UnaryNode { + + override def requiredChildDistribution: Seq[Distribution] = + if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil + + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { + child.execute().mapPartitions( { iterator => + val ordering = newOrdering(sortOrder, child.output) + val sorter = new ExternalSorter[InternalRow, Null, InternalRow]( + TaskContext.get(), ordering = Some(ordering)) + sorter.insertAll(iterator.map(r => (r.copy(), null))) + val baseIterator = sorter.iterator.map(_._1) + val context = TaskContext.get() + context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) + context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) + context.internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes) + CompletionIterator[InternalRow, Iterator[InternalRow]](baseIterator, sorter.stop()) + }, preservesPartitioning = true) + } + + override def output: Seq[Attribute] = child.output + + override def outputOrdering: Seq[SortOrder] = sortOrder +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala index b3fceeab64cfe..6876ab0f02b10 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -33,9 +33,9 @@ class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext { case c: ConvertToSafe => c } - private val outputsSafe = Sort(Nil, false, PhysicalRDD(Seq.empty, null, "name")) + private val outputsSafe = ReferenceSort(Nil, false, PhysicalRDD(Seq.empty, null, "name")) assert(!outputsSafe.outputsUnsafeRows) - private val outputsUnsafe = TungstenSort(Nil, false, PhysicalRDD(Seq.empty, null, "name")) + private val outputsUnsafe = Sort(Nil, false, PhysicalRDD(Seq.empty, null, "name")) assert(outputsUnsafe.outputsUnsafeRows) test("planner should insert unsafe->safe conversions when required") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala index 847c188a30333..e5d34be4c65e8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala @@ -17,15 +17,22 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.Row +import scala.util.Random + +import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{RandomDataGenerator, Row} + +/** + * Test sorting. Many of the test cases generate random data and compares the sorted result with one + * sorted by a reference implementation ([[ReferenceSort]]). + */ class SortSuite extends SparkPlanTest with SharedSQLContext { import testImplicits.localSeqToDataFrameHolder - // This test was originally added as an example of how to use [[SparkPlanTest]]; - // it's not designed to be a comprehensive test of ExternalSort. 
test("basic sorting using ExternalSort") { val input = Seq( @@ -36,14 +43,66 @@ class SortSuite extends SparkPlanTest with SharedSQLContext { checkAnswer( input.toDF("a", "b", "c"), - Sort('a.asc :: 'b.asc :: Nil, global = true, _: SparkPlan), + (child: SparkPlan) => Sort('a.asc :: 'b.asc :: Nil, global = true, child = child), input.sortBy(t => (t._1, t._2)).map(Row.fromTuple), sortAnswers = false) checkAnswer( input.toDF("a", "b", "c"), - Sort('b.asc :: 'a.asc :: Nil, global = true, _: SparkPlan), + (child: SparkPlan) => Sort('b.asc :: 'a.asc :: Nil, global = true, child = child), input.sortBy(t => (t._2, t._1)).map(Row.fromTuple), sortAnswers = false) } + + test("sort followed by limit") { + checkThatPlansAgree( + (1 to 100).map(v => Tuple1(v)).toDF("a"), + (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child = child)), + (child: SparkPlan) => Limit(10, ReferenceSort('a.asc :: Nil, global = true, child)), + sortAnswers = false + ) + } + + test("sorting does not crash for large inputs") { + val sortOrder = 'a.asc :: Nil + val stringLength = 1024 * 1024 * 2 + checkThatPlansAgree( + Seq(Tuple1("a" * stringLength), Tuple1("b" * stringLength)).toDF("a").repartition(1), + Sort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 1), + ReferenceSort(sortOrder, global = true, _: SparkPlan), + sortAnswers = false + ) + } + + test("sorting updates peak execution memory") { + AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "unsafe external sort") { + checkThatPlansAgree( + (1 to 100).map(v => Tuple1(v)).toDF("a"), + (child: SparkPlan) => Sort('a.asc :: Nil, global = true, child = child), + (child: SparkPlan) => ReferenceSort('a.asc :: Nil, global = true, child), + sortAnswers = false) + } + } + + // Test sorting on different data types + for ( + dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType); + nullable <- Seq(true, false); + sortOrder <- Seq('a.asc :: Nil, 'a.desc :: Nil); + randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable) + ) { + test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") { + val inputData = Seq.fill(1000)(randomDataGenerator()) + val inputDf = sqlContext.createDataFrame( + sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), + StructType(StructField("a", dataType, nullable = true) :: Nil) + ) + checkThatPlansAgree( + inputDf, + p => ConvertToSafe(Sort(sortOrder, global = true, p: SparkPlan, testSpillFrequency = 23)), + ReferenceSort(sortOrder, global = true, _: SparkPlan), + sortAnswers = false + ) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala deleted file mode 100644 index 7c860d1d58d5a..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala +++ /dev/null @@ -1,86 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import scala.util.Random - -import org.apache.spark.AccumulatorSuite -import org.apache.spark.sql.{RandomDataGenerator, Row, SQLConf} -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types._ - -/** - * A test suite that generates randomized data to test the [[TungstenSort]] operator. - */ -class TungstenSortSuite extends SparkPlanTest with SharedSQLContext { - import testImplicits.localSeqToDataFrameHolder - - test("sort followed by limit") { - checkThatPlansAgree( - (1 to 100).map(v => Tuple1(v)).toDF("a"), - (child: SparkPlan) => Limit(10, TungstenSort('a.asc :: Nil, true, child)), - (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)), - sortAnswers = false - ) - } - - test("sorting does not crash for large inputs") { - val sortOrder = 'a.asc :: Nil - val stringLength = 1024 * 1024 * 2 - checkThatPlansAgree( - Seq(Tuple1("a" * stringLength), Tuple1("b" * stringLength)).toDF("a").repartition(1), - TungstenSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 1), - Sort(sortOrder, global = true, _: SparkPlan), - sortAnswers = false - ) - } - - test("sorting updates peak execution memory") { - AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "unsafe external sort") { - checkThatPlansAgree( - (1 to 100).map(v => Tuple1(v)).toDF("a"), - (child: SparkPlan) => TungstenSort('a.asc :: Nil, true, child), - (child: SparkPlan) => Sort('a.asc :: Nil, global = true, child), - sortAnswers = false) - } - } - - // Test sorting on different data types - for ( - dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType); - nullable <- Seq(true, false); - sortOrder <- Seq('a.asc :: Nil, 'a.desc :: Nil); - randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable) - ) { - test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") { - val inputData = Seq.fill(1000)(randomDataGenerator()) - val inputDf = sqlContext.createDataFrame( - sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), - StructType(StructField("a", dataType, nullable = true) :: Nil) - ) - checkThatPlansAgree( - inputDf, - plan => ConvertToSafe( - TungstenSort(sortOrder, global = true, plan: SparkPlan, testSpillFrequency = 23)), - Sort(sortOrder, global = true, _: SparkPlan), - sortAnswers = false - ) - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 486bfbbd70887..5e2b4154dd7ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -114,17 +114,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { // PhysicalRDD(nodeId = 1) -> Project(nodeId = 0) val df = person.select('name) testSparkPlanMetrics(df, 1, Map( - 0L ->("TungstenProject", Map( - "number of rows" -> 2L))) - ) - } - - test("TungstenProject 
metrics") { - // Assume the execution plan is - // PhysicalRDD(nodeId = 1) -> TungstenProject(nodeId = 0) - val df = person.select('name) - testSparkPlanMetrics(df, 1, Map( - 0L ->("TungstenProject", Map( + 0L ->("Project", Map( "number of rows" -> 2L))) ) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala index 4cf4e13890294..5bd323ea096a4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.catalyst.expressions.{Cast, EqualTo} -import org.apache.spark.sql.execution.TungstenProject +import org.apache.spark.sql.execution.Project import org.apache.spark.sql.hive.test.TestHive /** @@ -44,7 +44,7 @@ class HiveTypeCoercionSuite extends HiveComparisonTest { test("[SPARK-2210] boolean cast on boolean value should be removed") { val q = "select cast(cast(key=0 as boolean) as boolean) from src" val project = TestHive.sql(q).queryExecution.executedPlan.collect { - case e: TungstenProject => e + case e: Project => e }.head // No cast expression introduced diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala index b6db6225805a1..e866493ee6c96 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala @@ -151,7 +151,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { val df = sqlContext.read.parquet(path).filter('a === 0).select('b) val physicalPlan = df.queryExecution.executedPlan - assert(physicalPlan.collect { case p: execution.TungstenProject => p }.length === 1) + assert(physicalPlan.collect { case p: execution.Project => p }.length === 1) assert(physicalPlan.collect { case p: execution.Filter => p }.length === 1) } } From 64e55511033afb6ef42be142eb371bfbc31f5230 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Sun, 15 Nov 2015 13:23:05 -0800 Subject: [PATCH 021/149] [SPARK-11672][ML] set active SQLContext in JavaDefaultReadWriteSuite The same as #9694, but for Java test suite. yhuai Author: Xiangrui Meng Closes #9719 from mengxr/SPARK-11672.4. 
--- .../apache/spark/ml/util/JavaDefaultReadWriteSuite.java | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java index c39538014be81..01ff1ea658610 100644 --- a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java @@ -32,17 +32,23 @@ public class JavaDefaultReadWriteSuite { JavaSparkContext jsc = null; + SQLContext sqlContext = null; File tempDir = null; @Before public void setUp() { jsc = new JavaSparkContext("local[2]", "JavaDefaultReadWriteSuite"); + SQLContext.clearActive(); + sqlContext = new SQLContext(jsc); + SQLContext.setActive(sqlContext); tempDir = Utils.createTempDir( System.getProperty("java.io.tmpdir"), "JavaDefaultReadWriteSuite"); } @After public void tearDown() { + sqlContext = null; + SQLContext.clearActive(); if (jsc != null) { jsc.stop(); jsc = null; @@ -64,7 +70,6 @@ public void testDefaultReadWrite() throws IOException { } catch (IOException e) { // expected } - SQLContext sqlContext = new SQLContext(jsc); instance.write().context(sqlContext).overwrite().save(outputPath); MyParams newInstance = MyParams.load(outputPath); Assert.assertEquals("UID should match.", instance.uid(), newInstance.uid()); From 3e2e1873b2762d07e49de8f9ea709bf3fa2d171c Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Sun, 15 Nov 2015 13:59:59 -0800 Subject: [PATCH 022/149] [SPARK-11738] [SQL] Making ArrayType orderable https://issues.apache.org/jira/browse/SPARK-11738 Author: Yin Huai Closes #9718 from yhuai/makingArrayOrderable. --- .../sql/catalyst/analysis/CheckAnalysis.scala | 32 +---- .../expressions/codegen/CodeGenerator.scala | 43 ++++++ .../expressions/collectionOperations.scala | 2 + .../sql/catalyst/expressions/ordering.scala | 6 + .../spark/sql/catalyst/util/TypeUtils.scala | 1 + .../spark/sql/types/AbstractDataType.scala | 1 + .../apache/spark/sql/types/ArrayType.scala | 48 +++++++ .../analysis/AnalysisErrorSuite.scala | 37 ++++-- .../ExpressionTypeCheckingSuite.scala | 32 ++--- .../sql/catalyst/analysis/TestRelations.scala | 3 + .../expressions/CodeGenerationSuite.scala | 36 ----- .../catalyst/expressions/OrderingSuite.scala | 124 ++++++++++++++++++ .../spark/sql/DataFrameFunctionsSuite.scala | 12 +- .../execution/AggregationQuerySuite.scala | 52 ++++++++ 14 files changed, 335 insertions(+), 94 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 5a4b0c1e39ce1..7b2c93d63d673 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -137,32 +137,14 @@ trait CheckAnalysis { case e => e.children.foreach(checkValidAggregateExpression) } - def checkSupportedGroupingDataType( - expressionString: String, - dataType: DataType): Unit = dataType match { - case BinaryType => - failAnalysis(s"expression $expressionString cannot be used in " + - s"grouping expression because it is in binary type or its inner field is " + - s"in binary type") - case a: ArrayType => - failAnalysis(s"expression $expressionString cannot be used in " + 
- s"grouping expression because it is in array type or its inner field is " + - s"in array type") - case m: MapType => - failAnalysis(s"expression $expressionString cannot be used in " + - s"grouping expression because it is in map type or its inner field is " + - s"in map type") - case s: StructType => - s.fields.foreach { f => - checkSupportedGroupingDataType(expressionString, f.dataType) - } - case udt: UserDefinedType[_] => - checkSupportedGroupingDataType(expressionString, udt.sqlType) - case _ => // OK - } - def checkValidGroupingExprs(expr: Expression): Unit = { - checkSupportedGroupingDataType(expr.prettyString, expr.dataType) + // Check if the data type of expr is orderable. + if (!RowOrdering.isOrderable(expr.dataType)) { + failAnalysis( + s"expression ${expr.prettyString} cannot be used as a grouping expression " + + s"because its data type ${expr.dataType.simpleString} is not a orderable " + + s"data type.") + } if (!expr.deterministic) { // This is just a sanity check, our analysis rule PullOutNondeterministic should diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index ccd91d3549b53..1718cfbd35332 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -267,6 +267,49 @@ class CodeGenContext { case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)" case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)" case NullType => "0" + case array: ArrayType => + val elementType = array.elementType + val elementA = freshName("elementA") + val isNullA = freshName("isNullA") + val elementB = freshName("elementB") + val isNullB = freshName("isNullB") + val compareFunc = freshName("compareArray") + val minLength = freshName("minLength") + val funcCode: String = + s""" + public int $compareFunc(ArrayData a, ArrayData b) { + int lengthA = a.numElements(); + int lengthB = b.numElements(); + int $minLength = (lengthA > lengthB) ? 
lengthB : lengthA; + for (int i = 0; i < $minLength; i++) { + boolean $isNullA = a.isNullAt(i); + boolean $isNullB = b.isNullAt(i); + if ($isNullA && $isNullB) { + // Nothing + } else if ($isNullA) { + return -1; + } else if ($isNullB) { + return 1; + } else { + ${javaType(elementType)} $elementA = ${getValue("a", elementType, "i")}; + ${javaType(elementType)} $elementB = ${getValue("b", elementType, "i")}; + int comp = ${genComp(elementType, elementA, elementB)}; + if (comp != 0) { + return comp; + } + } + } + + if (lengthA < lengthB) { + return -1; + } else if (lengthA > lengthB) { + return 1; + } + return 0; + } + """ + addNewFunction(compareFunc, funcCode) + s"this.$compareFunc($c1, $c2)" case schema: StructType => val comparisons = GenerateOrdering.genComparisons(this, schema) val compareFunc = freshName("compareStruct") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 2cf19b939f734..741ad1f3efd8a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -68,6 +68,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression) private lazy val lt: Comparator[Any] = { val ordering = base.dataType match { case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] + case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]] } @@ -90,6 +91,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression) private lazy val gt: Comparator[Any] = { val ordering = base.dataType match { case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] + case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala index 6407c73bc97d9..6112259fed619 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala @@ -48,6 +48,10 @@ class InterpretedOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow dt.ordering.asInstanceOf[Ordering[Any]].compare(left, right) case dt: AtomicType if order.direction == Descending => dt.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right) + case a: ArrayType if order.direction == Ascending => + a.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right) + case a: ArrayType if order.direction == Descending => + a.interpretedOrdering.asInstanceOf[Ordering[Any]].reverse.compare(left, right) case s: StructType if order.direction == Ascending => s.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right) case s: StructType if order.direction == Descending => @@ -86,6 +90,8 @@ object RowOrdering { case NullType => true case dt: AtomicType => true case struct: StructType => struct.fields.forall(f => isOrderable(f.dataType)) + case array: ArrayType => isOrderable(array.elementType) + case udt: UserDefinedType[_] => 
isOrderable(udt.sqlType) case _ => false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index bcf4d78fb9371..f603cbfb0cc21 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -57,6 +57,7 @@ object TypeUtils { def getInterpretedOrdering(t: DataType): Ordering[Any] = { t match { case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] + case a: ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]] case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index 1d2d007c2b4d2..a5ae8bb0e5eb6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -84,6 +84,7 @@ private[sql] object TypeCollection { * Types that can be ordered/compared. In the long run we should probably make this a trait * that can be mixed into each data type, and perhaps create an [[AbstractDataType]]. */ + // TODO: Should we consolidate this with RowOrdering.isOrderable? val Ordered = TypeCollection( BooleanType, ByteType, ShortType, IntegerType, LongType, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index 5770f59b53077..a001eadcc61d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -17,10 +17,13 @@ package org.apache.spark.sql.types +import org.apache.spark.sql.catalyst.util.ArrayData import org.json4s.JsonDSL._ import org.apache.spark.annotation.DeveloperApi +import scala.math.Ordering + object ArrayType extends AbstractDataType { /** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. */ @@ -81,4 +84,49 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = { f(this) || elementType.existsRecursively(f) } + + @transient + private[sql] lazy val interpretedOrdering: Ordering[ArrayData] = new Ordering[ArrayData] { + private[this] val elementOrdering: Ordering[Any] = elementType match { + case dt: AtomicType => dt.ordering.asInstanceOf[Ordering[Any]] + case a : ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]] + case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]] + case other => + throw new IllegalArgumentException(s"Type $other does not support ordered operations") + } + + def compare(x: ArrayData, y: ArrayData): Int = { + val leftArray = x + val rightArray = y + val minLength = scala.math.min(leftArray.numElements(), rightArray.numElements()) + var i = 0 + while (i < minLength) { + val isNullLeft = leftArray.isNullAt(i) + val isNullRight = rightArray.isNullAt(i) + if (isNullLeft && isNullRight) { + // Do nothing. 
+ } else if (isNullLeft) { + return -1 + } else if (isNullRight) { + return 1 + } else { + val comp = + elementOrdering.compare( + leftArray.get(i, elementType), + rightArray.get(i, elementType)) + if (comp != 0) { + return comp + } + } + i += 1 + } + if (leftArray.numElements() < rightArray.numElements()) { + return -1 + } else if (leftArray.numElements() > rightArray.numElements()) { + return 1 + } else { + return 0 + } + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 2e7c3bd67b554..ee435578743fc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData} +import org.apache.spark.sql.catalyst.util.{MapData, ArrayBasedMapData, GenericArrayData, ArrayData} import org.apache.spark.sql.types._ import scala.beans.{BeanProperty, BeanInfo} @@ -53,21 +53,29 @@ private[sql] class GroupableUDT extends UserDefinedType[GroupableData] { } @BeanInfo -private[sql] case class UngroupableData(@BeanProperty data: Array[Int]) +private[sql] case class UngroupableData(@BeanProperty data: Map[Int, Int]) private[sql] class UngroupableUDT extends UserDefinedType[UngroupableData] { - override def sqlType: DataType = ArrayType(IntegerType) + override def sqlType: DataType = MapType(IntegerType, IntegerType) - override def serialize(obj: Any): ArrayData = { + override def serialize(obj: Any): MapData = { obj match { - case groupableData: UngroupableData => new GenericArrayData(groupableData.data) + case groupableData: UngroupableData => + val keyArray = new GenericArrayData(groupableData.data.keys.toSeq) + val valueArray = new GenericArrayData(groupableData.data.values.toSeq) + new ArrayBasedMapData(keyArray, valueArray) } } override def deserialize(datum: Any): UngroupableData = { datum match { - case data: Array[Int] => UngroupableData(data) + case data: MapData => + val keyArray = data.keyArray().array + val valueArray = data.valueArray().array + assert(keyArray.length == valueArray.length) + val mapData = keyArray.zip(valueArray).toMap.asInstanceOf[Map[Int, Int]] + UngroupableData(mapData) } } @@ -154,8 +162,8 @@ class AnalysisErrorSuite extends AnalysisTest { errorTest( "sorting by unsupported column types", - listRelation.orderBy('list.asc), - "sort" :: "type" :: "array" :: Nil) + mapRelation.orderBy('map.asc), + "sort" :: "type" :: "map" :: Nil) errorTest( "non-boolean filters", @@ -259,32 +267,33 @@ class AnalysisErrorSuite extends AnalysisTest { case true => assertAnalysisSuccess(plan, true) case false => - assertAnalysisError(plan, "expression a cannot be used in grouping expression" :: Nil) + assertAnalysisError(plan, "expression a cannot be used as a grouping expression" :: Nil) } - } val supportedDataTypes = Seq( - StringType, + StringType, BinaryType, NullType, BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), DateType, TimestampType, + ArrayType(IntegerType), new StructType() .add("f1", FloatType, nullable = true) .add("f2", StringType, 
nullable = true), + new StructType() + .add("f1", FloatType, nullable = true) + .add("f2", ArrayType(BooleanType, containsNull = true), nullable = true), new GroupableUDT()) supportedDataTypes.foreach { dataType => checkDataType(dataType, shouldSuccess = true) } val unsupportedDataTypes = Seq( - BinaryType, - ArrayType(IntegerType), MapType(StringType, LongType), new StructType() .add("f1", FloatType, nullable = true) - .add("f2", ArrayType(BooleanType, containsNull = true), nullable = true), + .add("f2", MapType(StringType, LongType), nullable = true), new UngroupableUDT()) unsupportedDataTypes.foreach { dataType => checkDataType(dataType, shouldSuccess = false) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index b902982add8ff..ba1866efc84e1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.types.{TypeCollection, StringType} +import org.apache.spark.sql.types.{LongType, TypeCollection, StringType} class ExpressionTypeCheckingSuite extends SparkFunSuite { @@ -32,7 +32,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { 'intField.int, 'stringField.string, 'booleanField.boolean, - 'complexField.array(StringType)) + 'arrayField.array(StringType), + 'mapField.map(StringType, LongType)) def assertError(expr: Expression, errorMessage: String): Unit = { val e = intercept[AnalysisException] { @@ -90,9 +91,9 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertError(BitwiseOr('booleanField, 'booleanField), "requires integral type") assertError(BitwiseXor('booleanField, 'booleanField), "requires integral type") - assertError(MaxOf('complexField, 'complexField), + assertError(MaxOf('mapField, 'mapField), s"requires ${TypeCollection.Ordered.simpleString} type") - assertError(MinOf('complexField, 'complexField), + assertError(MinOf('mapField, 'mapField), s"requires ${TypeCollection.Ordered.simpleString} type") } @@ -109,20 +110,20 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertSuccess(EqualTo('intField, 'booleanField)) assertSuccess(EqualNullSafe('intField, 'booleanField)) - assertErrorForDifferingTypes(EqualTo('intField, 'complexField)) - assertErrorForDifferingTypes(EqualNullSafe('intField, 'complexField)) + assertErrorForDifferingTypes(EqualTo('intField, 'mapField)) + assertErrorForDifferingTypes(EqualNullSafe('intField, 'mapField)) assertErrorForDifferingTypes(LessThan('intField, 'booleanField)) assertErrorForDifferingTypes(LessThanOrEqual('intField, 'booleanField)) assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField)) assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField)) - assertError(LessThan('complexField, 'complexField), + assertError(LessThan('mapField, 'mapField), s"requires ${TypeCollection.Ordered.simpleString} type") - assertError(LessThanOrEqual('complexField, 'complexField), + assertError(LessThanOrEqual('mapField, 'mapField), s"requires ${TypeCollection.Ordered.simpleString} type") - 
assertError(GreaterThan('complexField, 'complexField), + assertError(GreaterThan('mapField, 'mapField), s"requires ${TypeCollection.Ordered.simpleString} type") - assertError(GreaterThanOrEqual('complexField, 'complexField), + assertError(GreaterThanOrEqual('mapField, 'mapField), s"requires ${TypeCollection.Ordered.simpleString} type") assertError(If('intField, 'stringField, 'stringField), @@ -130,10 +131,10 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertErrorForDifferingTypes(If('booleanField, 'intField, 'booleanField)) assertError( - CaseWhen(Seq('booleanField, 'intField, 'booleanField, 'complexField)), + CaseWhen(Seq('booleanField, 'intField, 'booleanField, 'mapField)), "THEN and ELSE expressions should all be same type or coercible to a common type") assertError( - CaseKeyWhen('intField, Seq('intField, 'stringField, 'intField, 'complexField)), + CaseKeyWhen('intField, Seq('intField, 'stringField, 'intField, 'mapField)), "THEN and ELSE expressions should all be same type or coercible to a common type") assertError( CaseWhen(Seq('booleanField, 'intField, 'intField, 'intField)), @@ -147,9 +148,10 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { // We will cast String to Double for sum and average assertSuccess(Sum('stringField)) assertSuccess(Average('stringField)) + assertSuccess(Min('arrayField)) - assertError(Min('complexField), "min does not support ordering on type") - assertError(Max('complexField), "max does not support ordering on type") + assertError(Min('mapField), "min does not support ordering on type") + assertError(Max('mapField), "max does not support ordering on type") assertError(Sum('booleanField), "function sum requires numeric type") assertError(Average('booleanField), "function average requires numeric type") } @@ -184,7 +186,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertError(Round('intField, 'intField), "Only foldable Expression is allowed") assertError(Round('intField, 'booleanField), "requires int type") - assertError(Round('intField, 'complexField), "requires int type") + assertError(Round('intField, 'mapField), "requires int type") assertError(Round('booleanField, 'intField), "requires numeric type") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala index 05b870705e7ea..bc07b609a3413 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala @@ -48,4 +48,7 @@ object TestRelations { val listRelation = LocalRelation( AttributeReference("list", ArrayType(IntegerType))()) + + val mapRelation = LocalRelation( + AttributeReference("map", MapType(IntegerType, IntegerType))()) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index e323467af5f4a..002ed16dcfe7e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import scala.math._ - import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{Row, RandomDataGenerator} import 
org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} @@ -49,40 +47,6 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { futures.foreach(Await.result(_, 10.seconds)) } - // Test GenerateOrdering for all common types. For each type, we construct random input rows that - // contain two columns of that type, then for pairs of randomly-generated rows we check that - // GenerateOrdering agrees with RowOrdering. - (DataTypeTestUtils.atomicTypes ++ Set(NullType)).foreach { dataType => - test(s"GenerateOrdering with $dataType") { - val rowOrdering = InterpretedOrdering.forSchema(Seq(dataType, dataType)) - val genOrdering = GenerateOrdering.generate( - BoundReference(0, dataType, nullable = true).asc :: - BoundReference(1, dataType, nullable = true).asc :: Nil) - val rowType = StructType( - StructField("a", dataType, nullable = true) :: - StructField("b", dataType, nullable = true) :: Nil) - val maybeDataGenerator = RandomDataGenerator.forType(rowType, nullable = false) - assume(maybeDataGenerator.isDefined) - val randGenerator = maybeDataGenerator.get - val toCatalyst = CatalystTypeConverters.createToCatalystConverter(rowType) - for (_ <- 1 to 50) { - val a = toCatalyst(randGenerator()).asInstanceOf[InternalRow] - val b = toCatalyst(randGenerator()).asInstanceOf[InternalRow] - withClue(s"a = $a, b = $b") { - assert(genOrdering.compare(a, a) === 0) - assert(genOrdering.compare(b, b) === 0) - assert(rowOrdering.compare(a, a) === 0) - assert(rowOrdering.compare(b, b) === 0) - assert(signum(genOrdering.compare(a, b)) === -1 * signum(genOrdering.compare(b, a))) - assert(signum(rowOrdering.compare(a, b)) === -1 * signum(rowOrdering.compare(b, a))) - assert( - signum(rowOrdering.compare(a, b)) === signum(genOrdering.compare(a, b)), - "Generated and non-generated orderings should agree") - } - } - } - } - test("SPARK-8443: split wide projections into blocks due to JVM code size limit") { val length = 5000 val expressions = List.fill(length)(EqualTo(Literal(1), Literal(1))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala new file mode 100644 index 0000000000000..7ad8657bde128 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import scala.math._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{Row, RandomDataGenerator} +import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering +import org.apache.spark.sql.types._ + +class OrderingSuite extends SparkFunSuite with ExpressionEvalHelper { + + def compareArrays(a: Seq[Any], b: Seq[Any], expected: Int): Unit = { + test(s"compare two arrays: a = $a, b = $b") { + val dataType = ArrayType(IntegerType) + val rowType = StructType(StructField("array", dataType, nullable = true) :: Nil) + val toCatalyst = CatalystTypeConverters.createToCatalystConverter(rowType) + val rowA = toCatalyst(Row(a)).asInstanceOf[InternalRow] + val rowB = toCatalyst(Row(b)).asInstanceOf[InternalRow] + Seq(Ascending, Descending).foreach { direction => + val sortOrder = direction match { + case Ascending => BoundReference(0, dataType, nullable = true).asc + case Descending => BoundReference(0, dataType, nullable = true).desc + } + val expectedCompareResult = direction match { + case Ascending => signum(expected) + case Descending => -1 * signum(expected) + } + val intOrdering = new InterpretedOrdering(sortOrder :: Nil) + val genOrdering = GenerateOrdering.generate(sortOrder :: Nil) + Seq(intOrdering, genOrdering).foreach { ordering => + assert(ordering.compare(rowA, rowA) === 0) + assert(ordering.compare(rowB, rowB) === 0) + assert(signum(ordering.compare(rowA, rowB)) === expectedCompareResult) + assert(signum(ordering.compare(rowB, rowA)) === -1 * expectedCompareResult) + } + } + } + } + + // Two arrays have the same size. + compareArrays(Seq[Any](), Seq[Any](), 0) + compareArrays(Seq[Any](1), Seq[Any](1), 0) + compareArrays(Seq[Any](1, 2), Seq[Any](1, 2), 0) + compareArrays(Seq[Any](1, 2, 2), Seq[Any](1, 2, 3), -1) + + // Two arrays have different sizes. + compareArrays(Seq[Any](), Seq[Any](1), -1) + compareArrays(Seq[Any](1, 2, 3), Seq[Any](1, 2, 3, 4), -1) + compareArrays(Seq[Any](1, 2, 3), Seq[Any](1, 2, 3, 2), -1) + compareArrays(Seq[Any](1, 2, 3), Seq[Any](1, 2, 2, 2), 1) + + // Arrays having nulls. + compareArrays(Seq[Any](1, 2, 3), Seq[Any](1, 2, 3, null), -1) + compareArrays(Seq[Any](), Seq[Any](null), -1) + compareArrays(Seq[Any](null), Seq[Any](null), 0) + compareArrays(Seq[Any](null, null), Seq[Any](null, null), 0) + compareArrays(Seq[Any](null), Seq[Any](null, null), -1) + compareArrays(Seq[Any](null), Seq[Any](1), -1) + compareArrays(Seq[Any](null), Seq[Any](null, 1), -1) + compareArrays(Seq[Any](null, 1), Seq[Any](1, 1), -1) + compareArrays(Seq[Any](1, null, 1), Seq[Any](1, null, 1), 0) + compareArrays(Seq[Any](1, null, 1), Seq[Any](1, null, 2), -1) + + // Test GenerateOrdering for all common types. For each type, we construct random input rows that + // contain two columns of that type, then for pairs of randomly-generated rows we check that + // GenerateOrdering agrees with RowOrdering. 
+ { + val structType = + new StructType() + .add("f1", FloatType, nullable = true) + .add("f2", ArrayType(BooleanType, containsNull = true), nullable = true) + val arrayOfStructType = ArrayType(structType) + val complexTypes = ArrayType(IntegerType) :: structType :: arrayOfStructType :: Nil + (DataTypeTestUtils.atomicTypes ++ complexTypes ++ Set(NullType)).foreach { dataType => + test(s"GenerateOrdering with $dataType") { + val rowOrdering = InterpretedOrdering.forSchema(Seq(dataType, dataType)) + val genOrdering = GenerateOrdering.generate( + BoundReference(0, dataType, nullable = true).asc :: + BoundReference(1, dataType, nullable = true).asc :: Nil) + val rowType = StructType( + StructField("a", dataType, nullable = true) :: + StructField("b", dataType, nullable = true) :: Nil) + val maybeDataGenerator = RandomDataGenerator.forType(rowType, nullable = false) + assume(maybeDataGenerator.isDefined) + val randGenerator = maybeDataGenerator.get + val toCatalyst = CatalystTypeConverters.createToCatalystConverter(rowType) + for (_ <- 1 to 50) { + val a = toCatalyst(randGenerator()).asInstanceOf[InternalRow] + val b = toCatalyst(randGenerator()).asInstanceOf[InternalRow] + withClue(s"a = $a, b = $b") { + assert(genOrdering.compare(a, a) === 0) + assert(genOrdering.compare(b, b) === 0) + assert(rowOrdering.compare(a, a) === 0) + assert(rowOrdering.compare(b, b) === 0) + assert(signum(genOrdering.compare(a, b)) === -1 * signum(genOrdering.compare(b, a))) + assert(signum(rowOrdering.compare(a, b)) === -1 * signum(rowOrdering.compare(b, a))) + assert( + signum(rowOrdering.compare(a, b)) === signum(genOrdering.compare(a, b)), + "Generated and non-generated orderings should agree") + } + } + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 3a3f19af1473b..aff9efe4b2b16 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -308,10 +308,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(null, null)) ) - val df2 = Seq((Array[Array[Int]](Array(2)), "x")).toDF("a", "b") - assert(intercept[AnalysisException] { - df2.selectExpr("sort_array(a)").collect() - }.getMessage().contains("does not support sorting array of type array")) + val df2 = Seq((Array[Array[Int]](Array(2), Array(1), Array(2, 4), null), "x")).toDF("a", "b") + checkAnswer( + df2.selectExpr("sort_array(a, true)", "sort_array(a, false)"), + Seq( + Row( + Seq[Seq[Int]](null, Seq(1), Seq(2), Seq(2, 4)), + Seq[Seq[Int]](Seq(2, 4), Seq(2), Seq(1), null))) + ) val df3 = Seq(("xxx", "x")).toDF("a", "b") assert(intercept[AnalysisException] { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 61e3e913c23ea..6dde79f74d3d8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -132,6 +132,22 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te (3, null, null)).toDF("key", "value1", "value2") data2.write.saveAsTable("agg2") + val data3 = Seq[(Seq[Integer], Integer, Integer)]( + (Seq[Integer](1, 1), 10, -10), + (Seq[Integer](null), -60, 60), + (Seq[Integer](1, 1), 
30, -30), + (Seq[Integer](1), 30, 30), + (Seq[Integer](2), 1, 1), + (null, -10, 10), + (Seq[Integer](2, 3), -1, null), + (Seq[Integer](2, 3), 1, 1), + (Seq[Integer](2, 3, 4), null, 1), + (Seq[Integer](null), 100, -10), + (Seq[Integer](3), null, 3), + (null, null, null), + (Seq[Integer](3), null, null)).toDF("key", "value1", "value2") + data3.write.saveAsTable("agg3") + val emptyDF = sqlContext.createDataFrame( sparkContext.emptyRDD[Row], StructType(StructField("key", StringType) :: StructField("value", IntegerType) :: Nil)) @@ -146,6 +162,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te override def afterAll(): Unit = { sqlContext.sql("DROP TABLE IF EXISTS agg1") sqlContext.sql("DROP TABLE IF EXISTS agg2") + sqlContext.sql("DROP TABLE IF EXISTS agg3") sqlContext.dropTempTable("emptyTable") } @@ -266,6 +283,41 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(100, null) :: Row(null, 3) :: Row(null, null) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT DISTINCT key + |FROM agg3 + """.stripMargin), + Row(Seq[Integer](1, 1)) :: + Row(Seq[Integer](null)) :: + Row(Seq[Integer](1)) :: + Row(Seq[Integer](2)) :: + Row(null) :: + Row(Seq[Integer](2, 3)) :: + Row(Seq[Integer](2, 3, 4)) :: + Row(Seq[Integer](3)) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT value1, key + |FROM agg3 + |GROUP BY value1, key + """.stripMargin), + Row(10, Seq[Integer](1, 1)) :: + Row(-60, Seq[Integer](null)) :: + Row(30, Seq[Integer](1, 1)) :: + Row(30, Seq[Integer](1)) :: + Row(1, Seq[Integer](2)) :: + Row(-10, null) :: + Row(-1, Seq[Integer](2, 3)) :: + Row(1, Seq[Integer](2, 3)) :: + Row(null, Seq[Integer](2, 3, 4)) :: + Row(100, Seq[Integer](null)) :: + Row(null, Seq[Integer](3)) :: + Row(null, null) :: Nil) } test("case in-sensitive resolution") { From 72c1d68b4ab6acb3f85971e10947caabb4bd846d Mon Sep 17 00:00:00 2001 From: Yu Gao Date: Sun, 15 Nov 2015 14:53:59 -0800 Subject: [PATCH 023/149] [SPARK-10181][SQL] Do kerberos login for credentials during hive client initialization On driver process start up, UserGroupInformation.loginUserFromKeytab is called with the principal and keytab passed in, and therefore static var UserGroupInfomation,loginUser is set to that principal with kerberos credentials saved in its private credential set, and all threads within the driver process are supposed to see and use this login credentials to authenticate with Hive and Hadoop. However, because of IsolatedClientLoader, UserGroupInformation class is not shared for hive metastore clients, and instead it is loaded separately and of course not able to see the prepared kerberos login credentials in the main thread. The first proposed fix would cause other classloader conflict errors, and is not an appropriate solution. This new change does kerberos login during hive client initialization, which will make credentials ready for the particular hive client instance. yhuai Please take a look and let me know. If you are not the right person to talk to, could you point me to someone responsible for this? Author: Yu Gao Author: gaoyu Author: Yu Gao Closes #9272 from yolandagao/master. 
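A minimal sketch of the login step described above, for illustration only. It assumes Hadoop's UserGroupInformation is on the classpath and that SparkSubmit has already copied --principal/--keytab into spark.yarn.principal and spark.yarn.keytab (as the diff below does); the variable names here are made up.

    import java.io.File

    import org.apache.hadoop.security.UserGroupInformation
    import org.apache.spark.{SparkConf, SparkException}

    // Re-read the keytab settings from a fresh SparkConf so the
    // UserGroupInformation class loaded by the isolated class loader
    // gets its own valid Kerberos credentials.
    val conf = new SparkConf
    if (conf.contains("spark.yarn.principal") && conf.contains("spark.yarn.keytab")) {
      val principal = conf.get("spark.yarn.principal")
      val keytab = conf.get("spark.yarn.keytab")
      if (!new File(keytab).exists()) {
        throw new SparkException(s"Keytab file: $keytab does not exist")
      }
      UserGroupInformation.loginUserFromKeytab(principal, keytab)
    }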
--- .../org/apache/spark/deploy/SparkSubmit.scala | 17 ++++++++++--- .../spark/sql/hive/client/ClientWrapper.scala | 24 ++++++++++++++++++- 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 84ae122f44370..09d2ec90c9333 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -39,7 +39,7 @@ import org.apache.ivy.plugins.matcher.GlobPatternMatcher import org.apache.ivy.plugins.repository.file.FileRepository import org.apache.ivy.plugins.resolver.{FileSystemResolver, ChainResolver, IBiblioResolver} -import org.apache.spark.{SparkUserAppException, SPARK_VERSION} +import org.apache.spark.{SparkException, SparkUserAppException, SPARK_VERSION} import org.apache.spark.api.r.RUtils import org.apache.spark.deploy.rest._ import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} @@ -521,8 +521,19 @@ object SparkSubmit { sysProps.put("spark.yarn.isPython", "true") } if (args.principal != null) { - require(args.keytab != null, "Keytab must be specified when the keytab is specified") - UserGroupInformation.loginUserFromKeytab(args.principal, args.keytab) + require(args.keytab != null, "Keytab must be specified when principal is specified") + if (!new File(args.keytab).exists()) { + throw new SparkException(s"Keytab file: ${args.keytab} does not exist") + } else { + // Add keytab and principal configurations in sysProps to make them available + // for later use; e.g. in spark sql, the isolated class loader used to talk + // to HiveMetastore will use these settings. They will be set as Java system + // properties and then loaded by SparkConf + sysProps.put("spark.yarn.keytab", args.keytab) + sysProps.put("spark.yarn.principal", args.principal) + + UserGroupInformation.loginUserFromKeytab(args.principal, args.keytab) + } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index f1c2489b38271..598ccdeee4ad2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -32,9 +32,10 @@ import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.ql.{Driver, metadata} import org.apache.hadoop.hive.shims.{HadoopShims, ShimLoader} +import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.util.VersionInfo -import org.apache.spark.Logging +import org.apache.spark.{SparkConf, SparkException, Logging} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.util.{CircularBuffer, Utils} @@ -149,6 +150,27 @@ private[hive] class ClientWrapper( val original = Thread.currentThread().getContextClassLoader // Switch to the initClassLoader. 
     Thread.currentThread().setContextClassLoader(initClassLoader)
+
+    // Set up kerberos credentials for UserGroupInformation.loginUser within
+    // current class loader
+    // Instead of using the spark conf of the current spark context, a new
+    // instance of SparkConf is needed for the original value of spark.yarn.keytab
+    // and spark.yarn.principal set in SparkSubmit, as yarn.Client resets the
+    // keytab configuration for the link name in distributed cache
+    val sparkConf = new SparkConf
+    if (sparkConf.contains("spark.yarn.principal") && sparkConf.contains("spark.yarn.keytab")) {
+      val principalName = sparkConf.get("spark.yarn.principal")
+      val keytabFileName = sparkConf.get("spark.yarn.keytab")
+      if (!new File(keytabFileName).exists()) {
+        throw new SparkException(s"Keytab file: ${keytabFileName}" +
+          " specified in spark.yarn.keytab does not exist")
+      } else {
+        logInfo("Attempting to login to Kerberos" +
+          s" using principal: ${principalName} and keytab: ${keytabFileName}")
+        UserGroupInformation.loginUserFromKeytab(principalName, keytabFileName)
+      }
+    }
+
     val ret = try {
       val initialConf = new HiveConf(classOf[SessionState])
       // HiveConf is a Hadoop Configuration, which has a field of classLoader and

From d7d9fa0b8750166f8b74f9bc321df26908683a8b Mon Sep 17 00:00:00 2001
From: zero323
Date: Sun, 15 Nov 2015 19:15:27 -0800
Subject: [PATCH 024/149] [SPARK-11086][SPARKR] Use dropFactors column-wise instead of nested loop when createDataFrame

Use `dropFactors` column-wise instead of a nested loop when `createDataFrame`
is called on a local `data.frame`.

At the moment, SparkR's createDataFrame uses a nested loop to convert factors
to character when called on a local data.frame. It works, but it is incredibly
slow, especially with data.table (~2 orders of magnitude slower than the
PySpark / Pandas version on a DataFrame of 1M rows x 2 columns).

A simple improvement is to apply `dropFactor` column-wise and then reshape the
output list. It should at least partially address
[SPARK-8277](https://issues.apache.org/jira/browse/SPARK-8277).

Author: zero323

Closes #9099 from zero323/SPARK-11086.
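A language-neutral illustration of the reshaping idea, sketched in Scala rather than R (the actual patch uses lapply over columns plus mapply(list, ...)); the data and names below are made up. Each column is converted once, then the columns are transposed back into per-row lists.

    // Hypothetical data: two already-converted columns of equal length.
    val columns: Seq[Vector[Any]] = Seq(
      Vector(1, 2, 3),            // column "a"
      Vector("x", "y", "z"))      // column "b", e.g. a factor turned into strings
    // One list per row, which is what the row-based conversion path expects.
    val rows: Seq[Seq[Any]] = columns.transpose
    // rows == Seq(Seq(1, "x"), Seq(2, "y"), Seq(3, "z"))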
--- R/pkg/R/SQLContext.R | 54 +++++++++++++++++++------------- R/pkg/inst/tests/test_sparkSQL.R | 16 ++++++++++ 2 files changed, 49 insertions(+), 21 deletions(-) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index fd013fdb304df..a62b25fde926d 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -17,27 +17,33 @@ # SQLcontext.R: SQLContext-driven functions + +# Map top level R type to SQL type +getInternalType <- function(x) { + # class of POSIXlt is c("POSIXlt" "POSIXt") + switch(class(x)[[1]], + integer = "integer", + character = "string", + logical = "boolean", + double = "double", + numeric = "double", + raw = "binary", + list = "array", + struct = "struct", + environment = "map", + Date = "date", + POSIXlt = "timestamp", + POSIXct = "timestamp", + stop(paste("Unsupported type for DataFrame:", class(x)))) +} + #' infer the SQL type infer_type <- function(x) { if (is.null(x)) { stop("can not infer type from NULL") } - # class of POSIXlt is c("POSIXlt" "POSIXt") - type <- switch(class(x)[[1]], - integer = "integer", - character = "string", - logical = "boolean", - double = "double", - numeric = "double", - raw = "binary", - list = "array", - struct = "struct", - environment = "map", - Date = "date", - POSIXlt = "timestamp", - POSIXct = "timestamp", - stop(paste("Unsupported type for DataFrame:", class(x)))) + type <- getInternalType(x) if (type == "map") { stopifnot(length(x) > 0) @@ -90,19 +96,25 @@ createDataFrame <- function(sqlContext, data, schema = NULL, samplingRatio = 1.0 if (is.null(schema)) { schema <- names(data) } - n <- nrow(data) - m <- ncol(data) + # get rid of factor type - dropFactor <- function(x) { + cleanCols <- function(x) { if (is.factor(x)) { as.character(x) } else { x } } - data <- lapply(1:n, function(i) { - lapply(1:m, function(j) { dropFactor(data[i,j]) }) - }) + + # drop factors and wrap lists + data <- setNames(lapply(data, cleanCols), NULL) + + # check if all columns have supported type + lapply(data, getInternalType) + + # convert to rows + args <- list(FUN = list, SIMPLIFY = FALSE, USE.NAMES = FALSE) + data <- do.call(mapply, append(args, data)) } if (is.list(data)) { sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sqlContext) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index af024e6183a37..8ff06276599e2 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -242,6 +242,14 @@ test_that("create DataFrame from list or data.frame", { expect_equal(count(df), 3) ldf2 <- collect(df) expect_equal(ldf$a, ldf2$a) + + irisdf <- createDataFrame(sqlContext, iris) + iris_collected <- collect(irisdf) + expect_equivalent(iris_collected[,-5], iris[,-5]) + expect_equal(iris_collected$Species, as.character(iris$Species)) + + mtcarsdf <- createDataFrame(sqlContext, mtcars) + expect_equivalent(collect(mtcarsdf), mtcars) }) test_that("create DataFrame with different data types", { @@ -283,6 +291,14 @@ test_that("create DataFrame with complex types", { expect_equal(s$b, 3L) }) +test_that("create DataFrame from a data.frame with complex types", { + ldf <- data.frame(row.names=1:2) + ldf$a_list <- list(list(1, 2), list(3, 4)) + sdf <- createDataFrame(sqlContext, ldf) + + expect_equivalent(ldf, collect(sdf)) +}) + # For test map type and struct type in DataFrame mockLinesMapType <- c("{\"name\":\"Bob\",\"info\":{\"age\":16,\"height\":176.5}}", "{\"name\":\"Alice\",\"info\":{\"age\":20,\"height\":164.3}}", From 835a79d78ee879a3c36dde85e5b3591243bf3957 Mon Sep 
17 00:00:00 2001
From: Sun Rui
Date: Sun, 15 Nov 2015 19:29:09 -0800
Subject: [PATCH 025/149] [SPARK-10500][SPARKR] sparkr.zip cannot be created if /R/lib is unwritable

The basic idea is:

The archive of the SparkR package itself, sparkr.zip, is created during the
build process and is included in the Spark binary distribution. It is not
modified after the distribution is installed, because the directory it resides
in ($SPARK_HOME/R/lib) may not be writable.

When R source code is contained in jars or in Spark packages specified with the
"--jars" or "--packages" command line options, a temporary directory is created
by calling Utils.createTempDir(), and the R packages built from that source
code are installed into it. The temporary directory is writable, does not
interfere with the temporary directories of other concurrent SparkR sessions,
and is deleted when the SparkR session ends. The R binary packages installed in
the temporary directory are then packed into an archive named rpkg.zip.

sparkr.zip and rpkg.zip are distributed to the cluster in YARN modes. Distributing
rpkg.zip in standalone modes is not supported in this PR and will be addressed in
another PR.

Various R files are updated to accept multiple lib paths (one for the SparkR
package itself, the other for additional R packages) so that these packages can
be accessed from R.

Author: Sun Rui

Closes #9390 from sun-rui/SPARK-10500.
---
 R/install-dev.bat | 6 +++
 R/install-dev.sh | 4 ++
 R/pkg/R/sparkR.R | 14 +++++-
 R/pkg/inst/profile/general.R | 3 +-
 R/pkg/inst/worker/daemon.R | 5 ++-
 R/pkg/inst/worker/worker.R | 3 +-
 .../org/apache/spark/api/r/RBackend.scala | 1 +
 .../scala/org/apache/spark/api/r/RRDD.scala | 4 +-
 .../scala/org/apache/spark/api/r/RUtils.scala | 37 +++++++++++++---
 .../apache/spark/deploy/RPackageUtils.scala | 26 ++++++++---
 .../org/apache/spark/deploy/RRunner.scala | 5 ++-
 .../org/apache/spark/deploy/SparkSubmit.scala | 43 +++++++++++++++----
 .../spark/deploy/SparkSubmitSuite.scala | 5 +--
 make-distribution.sh | 1 +
 14 files changed, 121 insertions(+), 36 deletions(-)

diff --git a/R/install-dev.bat b/R/install-dev.bat
index 008a5c668bc45..ed1c91ae3a0ff 100644
--- a/R/install-dev.bat
+++ b/R/install-dev.bat
@@ -25,3 +25,9 @@ set SPARK_HOME=%~dp0..
MKDIR %SPARK_HOME%\R\lib R.exe CMD INSTALL --library="%SPARK_HOME%\R\lib" %SPARK_HOME%\R\pkg\ + +rem Zip the SparkR package so that it can be distributed to worker nodes on YARN +pushd %SPARK_HOME%\R\lib +%JAVA_HOME%\bin\jar.exe cfM "%SPARK_HOME%\R\lib\sparkr.zip" SparkR +popd + diff --git a/R/install-dev.sh b/R/install-dev.sh index 59d98c9c7a646..4972bb9217072 100755 --- a/R/install-dev.sh +++ b/R/install-dev.sh @@ -42,4 +42,8 @@ Rscript -e ' if("devtools" %in% rownames(installed.packages())) { library(devtoo # Install SparkR to $LIB_DIR R CMD INSTALL --library=$LIB_DIR $FWDIR/pkg/ +# Zip the SparkR package so that it can be distributed to worker nodes on YARN +cd $LIB_DIR +jar cfM "$LIB_DIR/sparkr.zip" SparkR + popd > /dev/null diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index ebe2b2b8dc1d0..7ff3fa628b9ca 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -48,6 +48,12 @@ sparkR.stop <- function() { } } + # Remove the R package lib path from .libPaths() + if (exists(".libPath", envir = env)) { + libPath <- get(".libPath", envir = env) + .libPaths(.libPaths()[.libPaths() != libPath]) + } + if (exists(".backendLaunched", envir = env)) { callJStatic("SparkRHandler", "stopBackend") } @@ -155,14 +161,20 @@ sparkR.init <- function( f <- file(path, open="rb") backendPort <- readInt(f) monitorPort <- readInt(f) + rLibPath <- readString(f) close(f) file.remove(path) if (length(backendPort) == 0 || backendPort == 0 || - length(monitorPort) == 0 || monitorPort == 0) { + length(monitorPort) == 0 || monitorPort == 0 || + length(rLibPath) != 1) { stop("JVM failed to launch") } assign(".monitorConn", socketConnection(port = monitorPort), envir = .sparkREnv) assign(".backendLaunched", 1, envir = .sparkREnv) + if (rLibPath != "") { + assign(".libPath", rLibPath, envir = .sparkREnv) + .libPaths(c(rLibPath, .libPaths())) + } } .sparkREnv$backendPort <- backendPort diff --git a/R/pkg/inst/profile/general.R b/R/pkg/inst/profile/general.R index 2a8a8213d0849..c55fe9ba7af7a 100644 --- a/R/pkg/inst/profile/general.R +++ b/R/pkg/inst/profile/general.R @@ -17,6 +17,7 @@ .First <- function() { packageDir <- Sys.getenv("SPARKR_PACKAGE_DIR") - .libPaths(c(packageDir, .libPaths())) + dirs <- strsplit(packageDir, ",")[[1]] + .libPaths(c(dirs, .libPaths())) Sys.setenv(NOAWT=1) } diff --git a/R/pkg/inst/worker/daemon.R b/R/pkg/inst/worker/daemon.R index 3584b418a71a9..f55beac6c8c07 100644 --- a/R/pkg/inst/worker/daemon.R +++ b/R/pkg/inst/worker/daemon.R @@ -18,10 +18,11 @@ # Worker daemon rLibDir <- Sys.getenv("SPARKR_RLIBDIR") -script <- paste(rLibDir, "SparkR/worker/worker.R", sep = "/") +dirs <- strsplit(rLibDir, ",")[[1]] +script <- file.path(dirs[[1]], "SparkR", "worker", "worker.R") # preload SparkR package, speedup worker -.libPaths(c(rLibDir, .libPaths())) +.libPaths(c(dirs, .libPaths())) suppressPackageStartupMessages(library(SparkR)) port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R index 0c3b0d1f4be20..3ae072beca11b 100644 --- a/R/pkg/inst/worker/worker.R +++ b/R/pkg/inst/worker/worker.R @@ -35,10 +35,11 @@ bootTime <- currentTimeSecs() bootElap <- elapsedSecs() rLibDir <- Sys.getenv("SPARKR_RLIBDIR") +dirs <- strsplit(rLibDir, ",")[[1]] # Set libPaths to include SparkR package as loadNamespace needs this # TODO: Figure out if we can avoid this by not loading any objects that require # SparkR namespace -.libPaths(c(rLibDir, .libPaths())) +.libPaths(c(dirs, .libPaths())) suppressPackageStartupMessages(library(SparkR)) port <- 
as.integer(Sys.getenv("SPARKR_WORKER_PORT")) diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala index b7e72d4d0ed0b..8b3be0da2c8c4 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala @@ -113,6 +113,7 @@ private[spark] object RBackend extends Logging { val dos = new DataOutputStream(new FileOutputStream(f)) dos.writeInt(boundPort) dos.writeInt(listenPort) + SerDe.writeString(dos, RUtils.rPackages.getOrElse("")) dos.close() f.renameTo(new File(path)) diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index 6b418e908cb53..7509b3d3f44bb 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -400,14 +400,14 @@ private[r] object RRDD { val rOptions = "--vanilla" val rLibDir = RUtils.sparkRPackagePath(isDriver = false) - val rExecScript = rLibDir + "/SparkR/worker/" + script + val rExecScript = rLibDir(0) + "/SparkR/worker/" + script val pb = new ProcessBuilder(Arrays.asList(rCommand, rOptions, rExecScript)) // Unset the R_TESTS environment variable for workers. // This is set by R CMD check as startup.Rs // (http://svn.r-project.org/R/trunk/src/library/tools/R/testing.R) // and confuses worker script which tries to load a non-existent file pb.environment().put("R_TESTS", "") - pb.environment().put("SPARKR_RLIBDIR", rLibDir) + pb.environment().put("SPARKR_RLIBDIR", rLibDir.mkString(",")) pb.environment().put("SPARKR_WORKER_PORT", port.toString) pb.redirectErrorStream(true) // redirect stderr into stdout val proc = pb.start() diff --git a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala index fd5646b5b6372..16157414fd120 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala @@ -23,6 +23,10 @@ import java.util.Arrays import org.apache.spark.{SparkEnv, SparkException} private[spark] object RUtils { + // Local path where R binary packages built from R source code contained in the spark + // packages specified with "--packages" or "--jars" command line option reside. + var rPackages: Option[String] = None + /** * Get the SparkR package path in the local spark distribution. */ @@ -34,11 +38,15 @@ private[spark] object RUtils { } /** - * Get the SparkR package path in various deployment modes. + * Get the list of paths for R packages in various deployment modes, of which the first + * path is for the SparkR package itself. The second path is for R packages built as + * part of Spark Packages, if any exist. Spark Packages can be provided through the + * "--packages" or "--jars" command line options. + * * This assumes that Spark properties `spark.master` and `spark.submit.deployMode` * and environment variable `SPARK_HOME` are set. */ - def sparkRPackagePath(isDriver: Boolean): String = { + def sparkRPackagePath(isDriver: Boolean): Seq[String] = { val (master, deployMode) = if (isDriver) { (sys.props("spark.master"), sys.props("spark.submit.deployMode")) @@ -51,15 +59,30 @@ private[spark] object RUtils { val isYarnClient = master != null && master.contains("yarn") && deployMode == "client" // In YARN mode, the SparkR package is distributed as an archive symbolically - // linked to the "sparkr" file in the current directory. 
Note that this does not apply - // to the driver in client mode because it is run outside of the cluster. + // linked to the "sparkr" file in the current directory and additional R packages + // are distributed as an archive symbolically linked to the "rpkg" file in the + // current directory. + // + // Note that this does not apply to the driver in client mode because it is run + // outside of the cluster. if (isYarnCluster || (isYarnClient && !isDriver)) { - new File("sparkr").getAbsolutePath + val sparkRPkgPath = new File("sparkr").getAbsolutePath + val rPkgPath = new File("rpkg") + if (rPkgPath.exists()) { + Seq(sparkRPkgPath, rPkgPath.getAbsolutePath) + } else { + Seq(sparkRPkgPath) + } } else { // Otherwise, assume the package is local // TODO: support this for Mesos - localSparkRPackagePath.getOrElse { - throw new SparkException("SPARK_HOME not set. Can't locate SparkR package.") + val sparkRPkgPath = localSparkRPackagePath.getOrElse { + throw new SparkException("SPARK_HOME not set. Can't locate SparkR package.") + } + if (!rPackages.isEmpty) { + Seq(sparkRPkgPath, rPackages.get) + } else { + Seq(sparkRPkgPath) } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala b/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala index 7d160b6790eaa..d46dc87a92c97 100644 --- a/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala @@ -100,20 +100,29 @@ private[deploy] object RPackageUtils extends Logging { * Runs the standard R package installation code to build the R package from source. * Multiple runs don't cause problems. */ - private def rPackageBuilder(dir: File, printStream: PrintStream, verbose: Boolean): Boolean = { + private def rPackageBuilder( + dir: File, + printStream: PrintStream, + verbose: Boolean, + libDir: String): Boolean = { // this code should be always running on the driver. - val pathToSparkR = RUtils.localSparkRPackagePath.getOrElse( - throw new SparkException("SPARK_HOME not set. Can't locate SparkR package.")) val pathToPkg = Seq(dir, "R", "pkg").mkString(File.separator) - val installCmd = baseInstallCmd ++ Seq(pathToSparkR, pathToPkg) + val installCmd = baseInstallCmd ++ Seq(libDir, pathToPkg) if (verbose) { print(s"Building R package with the command: $installCmd", printStream) } try { val builder = new ProcessBuilder(installCmd.asJava) builder.redirectErrorStream(true) + + // Put the SparkR package directory into R library search paths in case this R package + // may depend on SparkR. val env = builder.environment() - env.clear() + val rPackageDir = RUtils.sparkRPackagePath(isDriver = true) + env.put("SPARKR_PACKAGE_DIR", rPackageDir.mkString(",")) + env.put("R_PROFILE_USER", + Seq(rPackageDir(0), "SparkR", "profile", "general.R").mkString(File.separator)) + val process = builder.start() new RedirectThread(process.getInputStream, printStream, "redirect R packaging").start() process.waitFor() == 0 @@ -170,8 +179,11 @@ private[deploy] object RPackageUtils extends Logging { if (checkManifestForR(jar)) { print(s"$file contains R source code. 
Now installing package.", printStream, Level.INFO) val rSource = extractRFolder(jar, printStream, verbose) + if (RUtils.rPackages.isEmpty) { + RUtils.rPackages = Some(Utils.createTempDir().getAbsolutePath) + } try { - if (!rPackageBuilder(rSource, printStream, verbose)) { + if (!rPackageBuilder(rSource, printStream, verbose, RUtils.rPackages.get)) { print(s"ERROR: Failed to build R package in $file.", printStream) print(RJarDoc, printStream) } @@ -208,7 +220,7 @@ private[deploy] object RPackageUtils extends Logging { } } - /** Zips all the libraries found with SparkR in the R/lib directory for distribution with Yarn. */ + /** Zips all the R libraries built for distribution to the cluster. */ private[deploy] def zipRLibraries(dir: File, name: String): File = { val filesToBundle = listFilesRecursively(dir, Seq(".zip")) // create a zip file from scratch, do not append to existing file. diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala index ed183cf16a9cb..661f7317c674b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala @@ -82,9 +82,10 @@ object RRunner { val env = builder.environment() env.put("EXISTING_SPARKR_BACKEND_PORT", sparkRBackendPort.toString) val rPackageDir = RUtils.sparkRPackagePath(isDriver = true) - env.put("SPARKR_PACKAGE_DIR", rPackageDir) + // Put the R package directories into an env variable of comma-separated paths + env.put("SPARKR_PACKAGE_DIR", rPackageDir.mkString(",")) env.put("R_PROFILE_USER", - Seq(rPackageDir, "SparkR", "profile", "general.R").mkString(File.separator)) + Seq(rPackageDir(0), "SparkR", "profile", "general.R").mkString(File.separator)) builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize val process = builder.start() diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 09d2ec90c9333..2e912b59afdb8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -83,6 +83,7 @@ object SparkSubmit { private val PYSPARK_SHELL = "pyspark-shell" private val SPARKR_SHELL = "sparkr-shell" private val SPARKR_PACKAGE_ARCHIVE = "sparkr.zip" + private val R_PACKAGE_ARCHIVE = "rpkg.zip" private val CLASS_NOT_FOUND_EXIT_STATUS = 101 @@ -362,22 +363,46 @@ object SparkSubmit { } } - // In YARN mode for an R app, add the SparkR package archive to archives - // that can be distributed with the job + // In YARN mode for an R app, add the SparkR package archive and the R package + // archive containing all of the built R libraries to archives so that they can + // be distributed with the job if (args.isR && clusterManager == YARN) { - val rPackagePath = RUtils.localSparkRPackagePath - if (rPackagePath.isEmpty) { + val sparkRPackagePath = RUtils.localSparkRPackagePath + if (sparkRPackagePath.isEmpty) { printErrorAndExit("SPARK_HOME does not exist for R application in YARN mode.") } - val rPackageFile = - RPackageUtils.zipRLibraries(new File(rPackagePath.get), SPARKR_PACKAGE_ARCHIVE) - if (!rPackageFile.exists()) { + val sparkRPackageFile = new File(sparkRPackagePath.get, SPARKR_PACKAGE_ARCHIVE) + if (!sparkRPackageFile.exists()) { printErrorAndExit(s"$SPARKR_PACKAGE_ARCHIVE does not exist for R application in YARN mode.") } - val localURI = Utils.resolveURI(rPackageFile.getAbsolutePath) + val 
sparkRPackageURI = Utils.resolveURI(sparkRPackageFile.getAbsolutePath).toString + // Distribute the SparkR package. // Assigns a symbol link name "sparkr" to the shipped package. - args.archives = mergeFileLists(args.archives, localURI.toString + "#sparkr") + args.archives = mergeFileLists(args.archives, sparkRPackageURI + "#sparkr") + + // Distribute the R package archive containing all the built R packages. + if (!RUtils.rPackages.isEmpty) { + val rPackageFile = + RPackageUtils.zipRLibraries(new File(RUtils.rPackages.get), R_PACKAGE_ARCHIVE) + if (!rPackageFile.exists()) { + printErrorAndExit("Failed to zip all the built R packages.") + } + + val rPackageURI = Utils.resolveURI(rPackageFile.getAbsolutePath).toString + // Assigns a symbol link name "rpkg" to the shipped package. + args.archives = mergeFileLists(args.archives, rPackageURI + "#rpkg") + } + } + + // TODO: Support distributing R packages with standalone cluster + if (args.isR && clusterManager == STANDALONE && !RUtils.rPackages.isEmpty) { + printErrorAndExit("Distributing R packages with standalone cluster is not supported.") + } + + // TODO: Support SparkR with mesos cluster + if (args.isR && clusterManager == MESOS) { + printErrorAndExit("SparkR is not supported for Mesos cluster.") } // If we're running a R app, set the main class to our specific R runner diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 66a50512003dc..42e748ec6d528 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -28,6 +28,7 @@ import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ import org.apache.spark._ +import org.apache.spark.api.r.RUtils import org.apache.spark.deploy.SparkSubmit._ import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate import org.apache.spark.util.{ResetSystemProperties, Utils} @@ -369,9 +370,6 @@ class SparkSubmitSuite } test("correctly builds R packages included in a jar with --packages") { - // TODO(SPARK-9603): Building a package to $SPARK_HOME/R/lib is unavailable on Jenkins. - // It's hard to write the test in SparkR (because we can't create the repository dynamically) - /* assume(RUtils.isRInstalled, "R isn't installed on this machine.") val main = MavenCoordinate("my.great.lib", "mylib", "0.1") val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) @@ -389,7 +387,6 @@ class SparkSubmitSuite rScriptDir) runSparkSubmit(args) } - */ } test("resolves command line argument paths correctly") { diff --git a/make-distribution.sh b/make-distribution.sh index e1c2afdbc6d87..d7d27e253f721 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -220,6 +220,7 @@ cp -r "$SPARK_HOME/ec2" "$DISTDIR" if [ -d "$SPARK_HOME"/R/lib/SparkR ]; then mkdir -p "$DISTDIR"/R/lib cp -r "$SPARK_HOME/R/lib/SparkR" "$DISTDIR"/R/lib + cp "$SPARK_HOME/R/lib/sparkr.zip" "$DISTDIR"/R/lib fi # Download and copy in tachyon, if requested From b58765caa6d7e6933050565c5d423c45e7e70ba6 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sun, 15 Nov 2015 21:10:46 -0800 Subject: [PATCH 026/149] [SPARK-9928][SQL] Removal of LogicalLocalTable LogicalLocalTable in ExistingRDD.scala is replaced by localRelation in LocalRelation.scala? Do you know any reason why we still keep this class? Author: gatorsmile Closes #9717 from gatorsmile/LogicalLocalTable. 
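For context, a hedged sketch of the LocalRelation node that already covers this use case, mirroring the pattern TestRelations uses earlier in this series; the column name below is made up.

    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
    import org.apache.spark.sql.types.IntegerType

    // A logical plan for a local, in-memory table with a single int column.
    // Local collections are already planned through LocalRelation, which is
    // why the parallel LogicalLocalTable node removed below was unused.
    val localTable = LocalRelation(AttributeReference("value", IntegerType)())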
--- .../spark/sql/execution/ExistingRDD.scala | 22 ------------------- 1 file changed, 22 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 7a466cf6a0a94..8b41d3d3d892e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -110,25 +110,3 @@ private[sql] object PhysicalRDD { PhysicalRDD(output, rdd, relation.toString, relation.isInstanceOf[HadoopFsRelation]) } } - -/** Logical plan node for scanning data from a local collection. */ -private[sql] -case class LogicalLocalTable(output: Seq[Attribute], rows: Seq[InternalRow])(sqlContext: SQLContext) - extends LogicalPlan with MultiInstanceRelation { - - override def children: Seq[LogicalPlan] = Nil - - override def newInstance(): this.type = - LogicalLocalTable(output.map(_.newInstance()), rows)(sqlContext).asInstanceOf[this.type] - - override def sameResult(plan: LogicalPlan): Boolean = plan match { - case LogicalRDD(_, otherRDD) => rows == rows - case _ => false - } - - @transient override lazy val statistics: Statistics = Statistics( - // TODO: Improve the statistics estimation. - // This is made small enough so it can be broadcasted. - sizeInBytes = sqlContext.conf.autoBroadcastJoinThreshold - 1 - ) -} From fd50fa4c3eff42e8adeeabe399ddba0edac930c8 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 15 Nov 2015 22:38:30 -0800 Subject: [PATCH 027/149] Revert "[SPARK-11572] Exit AsynchronousListenerBus thread when stop() is called" This reverts commit 3e0a6cf1e02a19b37c68d3026415d53bb57a576b. --- .../org/apache/spark/util/AsynchronousListenerBus.scala | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala b/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala index b3b54af972cb4..c20627b056bef 100644 --- a/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala @@ -66,12 +66,15 @@ private[spark] abstract class AsynchronousListenerBus[L <: AnyRef, E](name: Stri processingEvent = true } try { - if (stopped.get()) { + val event = eventQueue.poll + if (event == null) { // Get out of the while loop and shutdown the daemon thread + if (!stopped.get) { + throw new IllegalStateException("Polling `null` from eventQueue means" + + " the listener bus has been stopped. So `stopped` must be true") + } return } - val event = eventQueue.poll - assert(event != null, "event queue was empty but the listener bus was not stopped") postToAll(event) } finally { self.synchronized { From 42de5253f327bd7ee258b0efb5024f3847fa3b51 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 16 Nov 2015 00:06:14 -0800 Subject: [PATCH 028/149] [SPARK-11745][SQL] Enable more JSON parsing options This patch adds the following options to the JSON data source, for dealing with non-standard JSON files: * `allowComments` (default `false`): ignores Java/C++ style comment in JSON records * `allowUnquotedFieldNames` (default `false`): allows unquoted JSON field names * `allowSingleQuotes` (default `true`): allows single quotes in addition to double quotes * `allowNumericLeadingZeros` (default `false`): allows leading zeros in numbers (e.g. 
00012) To avoid passing a lot of options throughout the json package, I introduced a new JSONOptions case class to define all JSON config options. Also updated documentation to explain these options. Scala ![screen shot 2015-11-15 at 6 12 12 pm](https://cloud.githubusercontent.com/assets/323388/11172965/e3ace6ec-8bc4-11e5-805e-2d78f80d0ed6.png) Python ![screen shot 2015-11-15 at 6 11 28 pm](https://cloud.githubusercontent.com/assets/323388/11172964/e23ed6ee-8bc4-11e5-8216-312f5983acd5.png) Author: Reynold Xin Closes #9724 from rxin/SPARK-11745. --- python/pyspark/sql/readwriter.py | 10 ++ .../apache/spark/sql/DataFrameReader.scala | 22 ++-- .../spark/sql/execution/SparkPlan.scala | 17 +-- .../datasources/json/InferSchema.scala | 34 +++--- .../datasources/json/JSONOptions.scala | 64 ++++++++++ .../datasources/json/JSONRelation.scala | 20 ++- .../datasources/json/JacksonParser.scala | 82 +++++++------ .../json/JsonParsingOptionsSuite.scala | 114 ++++++++++++++++++ .../datasources/json/JsonSuite.scala | 29 ++--- 9 files changed, 286 insertions(+), 106 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 927f4077424dc..7b8ddb9feba34 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -153,6 +153,16 @@ def json(self, path, schema=None): or RDD of Strings storing JSON objects. :param schema: an optional :class:`StructType` for the input schema. + You can set the following JSON-specific options to deal with non-standard JSON files: + * ``primitivesAsString`` (default ``false``): infers all primitive values as a string \ + type + * ``allowComments`` (default ``false``): ignores Java/C++ style comment in JSON records + * ``allowUnquotedFieldNames`` (default ``false``): allows unquoted JSON field names + * ``allowSingleQuotes`` (default ``true``): allows single quotes in addition to double \ + quotes + * ``allowNumericLeadingZeros`` (default ``false``): allows leading zeros in numbers \ + (e.g. 00012) + >>> df1 = sqlContext.read.json('python/test_support/sql/people.json') >>> df1.dtypes [('age', 'bigint'), ('name', 'string')] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 6a194a443ab17..5872fbded3833 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -29,7 +29,7 @@ import org.apache.spark.api.java.JavaRDD import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD import org.apache.spark.sql.execution.datasources.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} -import org.apache.spark.sql.execution.datasources.json.JSONRelation +import org.apache.spark.sql.execution.datasources.json.{JSONOptions, JSONRelation} import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} import org.apache.spark.sql.types.StructType @@ -227,6 +227,15 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * This function goes through the input once to determine the input schema. 
If you know the * schema in advance, use the version that specifies the schema to avoid the extra scan. * + * You can set the following JSON-specific options to deal with non-standard JSON files: + *
  • `primitivesAsString` (default `false`): infers all primitive values as a string type
  • `allowComments` (default `false`): ignores Java/C++ style comment in JSON records
  • `allowUnquotedFieldNames` (default `false`): allows unquoted JSON field names
  • `allowSingleQuotes` (default `true`): allows single quotes in addition to double quotes
  • `allowNumericLeadingZeros` (default `false`): allows leading zeros in numbers (e.g. 00012); see the usage sketch below
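
A minimal usage sketch of these options (illustrative only, not part of the patch itself): it follows the `sqlContext.read.option(...).json(...)` pattern exercised by the JsonParsingOptionsSuite added in this commit; the input path is a placeholder.

```
// Reads non-standard JSON by relaxing the parser through DataFrameReader options.
// The path below is a placeholder for any JSON file containing comments,
// unquoted field names, or numbers with leading zeros.
val df = sqlContext.read
  .option("allowComments", "true")            // tolerate /* ... */ comments
  .option("allowUnquotedFieldNames", "true")  // tolerate {name: "value"}
  .option("allowNumericLeadingZeros", "true") // tolerate {"age": 0018}
  .json("path/to/non-standard.json")
df.printSchema()
```
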
  • + * * @param path input path * @since 1.4.0 */ @@ -255,16 +264,13 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * @since 1.4.0 */ def json(jsonRDD: RDD[String]): DataFrame = { - val samplingRatio = extraOptions.getOrElse("samplingRatio", "1.0").toDouble - val primitivesAsString = extraOptions.getOrElse("primitivesAsString", "false").toBoolean sqlContext.baseRelationToDataFrame( new JSONRelation( Some(jsonRDD), - samplingRatio, - primitivesAsString, - userSpecifiedSchema, - None, - None)(sqlContext) + maybeDataSchema = userSpecifiedSchema, + maybePartitionSpec = None, + userDefinedPartitionColumns = None, + parameters = extraOptions.toMap)(sqlContext) ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 1b833002f434c..534a3bcb8364d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -221,22 +221,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ private[this] def isTesting: Boolean = sys.props.contains("spark.testing") - protected def newProjection( - expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = { - log.debug(s"Creating Projection: $expressions, inputSchema: $inputSchema") - try { - GenerateProjection.generate(expressions, inputSchema) - } catch { - case e: Exception => - if (isTesting) { - throw e - } else { - log.error("Failed to generate projection, fallback to interpret", e) - new InterpretedProjection(expressions, inputSchema) - } - } - } - protected def newMutableProjection( expressions: Seq[Expression], inputSchema: Seq[Attribute]): () => MutableProjection = { log.debug(s"Creating MutableProj: $expressions, inputSchema: $inputSchema") @@ -282,6 +266,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ } } } + /** * Creates a row ordering for the given schema, in natural ascending order. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala index b9914c581a657..922fd5b21167b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala @@ -25,33 +25,36 @@ import org.apache.spark.sql.execution.datasources.json.JacksonUtils.nextUntil import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -private[sql] object InferSchema { + +private[json] object InferSchema { + /** * Infer the type of a collection of json records in three stages: * 1. Infer the type of each record * 2. Merge types by choosing the lowest type necessary to cover equal keys * 3. 
Replace any remaining null fields with string, the top type */ - def apply( + def infer( json: RDD[String], - samplingRatio: Double = 1.0, columnNameOfCorruptRecords: String, - primitivesAsString: Boolean = false): StructType = { - require(samplingRatio > 0, s"samplingRatio ($samplingRatio) should be greater than 0") - val schemaData = if (samplingRatio > 0.99) { + configOptions: JSONOptions): StructType = { + require(configOptions.samplingRatio > 0, + s"samplingRatio (${configOptions.samplingRatio}) should be greater than 0") + val schemaData = if (configOptions.samplingRatio > 0.99) { json } else { - json.sample(withReplacement = false, samplingRatio, 1) + json.sample(withReplacement = false, configOptions.samplingRatio, 1) } // perform schema inference on each row and merge afterwards val rootType = schemaData.mapPartitions { iter => val factory = new JsonFactory() + configOptions.setJacksonOptions(factory) iter.map { row => try { Utils.tryWithResource(factory.createParser(row)) { parser => parser.nextToken() - inferField(parser, primitivesAsString) + inferField(parser, configOptions) } } catch { case _: JsonParseException => @@ -71,14 +74,14 @@ private[sql] object InferSchema { /** * Infer the type of a json document from the parser's token stream */ - private def inferField(parser: JsonParser, primitivesAsString: Boolean): DataType = { + private def inferField(parser: JsonParser, configOptions: JSONOptions): DataType = { import com.fasterxml.jackson.core.JsonToken._ parser.getCurrentToken match { case null | VALUE_NULL => NullType case FIELD_NAME => parser.nextToken() - inferField(parser, primitivesAsString) + inferField(parser, configOptions) case VALUE_STRING if parser.getTextLength < 1 => // Zero length strings and nulls have special handling to deal @@ -95,7 +98,7 @@ private[sql] object InferSchema { while (nextUntil(parser, END_OBJECT)) { builder += StructField( parser.getCurrentName, - inferField(parser, primitivesAsString), + inferField(parser, configOptions), nullable = true) } @@ -107,14 +110,15 @@ private[sql] object InferSchema { // the type as we pass through all JSON objects. var elementType: DataType = NullType while (nextUntil(parser, END_ARRAY)) { - elementType = compatibleType(elementType, inferField(parser, primitivesAsString)) + elementType = compatibleType( + elementType, inferField(parser, configOptions)) } ArrayType(elementType) - case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT) if primitivesAsString => StringType + case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT) if configOptions.primitivesAsString => StringType - case (VALUE_TRUE | VALUE_FALSE) if primitivesAsString => StringType + case (VALUE_TRUE | VALUE_FALSE) if configOptions.primitivesAsString => StringType case VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT => import JsonParser.NumberType._ @@ -178,7 +182,7 @@ private[sql] object InferSchema { /** * Returns the most general data type for two given data types. */ - private[json] def compatibleType(t1: DataType, t2: DataType): DataType = { + def compatibleType(t1: DataType, t2: DataType): DataType = { HiveTypeCoercion.findTightestCommonTypeOfTwo(t1, t2).getOrElse { // t1 or t2 is a StructType, ArrayType, or an unexpected type. 
(t1, t2) match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala new file mode 100644 index 0000000000000..c132ead20e7d6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.json + +import com.fasterxml.jackson.core.{JsonParser, JsonFactory} + +/** + * Options for the JSON data source. + * + * Most of these map directly to Jackson's internal options, specified in [[JsonParser.Feature]]. + */ +case class JSONOptions( + samplingRatio: Double = 1.0, + primitivesAsString: Boolean = false, + allowComments: Boolean = false, + allowUnquotedFieldNames: Boolean = false, + allowSingleQuotes: Boolean = true, + allowNumericLeadingZeros: Boolean = false, + allowNonNumericNumbers: Boolean = false) { + + /** Sets config options on a Jackson [[JsonFactory]]. 
*/ + def setJacksonOptions(factory: JsonFactory): Unit = { + factory.configure(JsonParser.Feature.ALLOW_COMMENTS, allowComments) + factory.configure(JsonParser.Feature.ALLOW_UNQUOTED_FIELD_NAMES, allowUnquotedFieldNames) + factory.configure(JsonParser.Feature.ALLOW_SINGLE_QUOTES, allowSingleQuotes) + factory.configure(JsonParser.Feature.ALLOW_NUMERIC_LEADING_ZEROS, allowNumericLeadingZeros) + factory.configure(JsonParser.Feature.ALLOW_NON_NUMERIC_NUMBERS, allowNonNumericNumbers) + } +} + + +object JSONOptions { + def createFromConfigMap(parameters: Map[String, String]): JSONOptions = JSONOptions( + samplingRatio = + parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0), + primitivesAsString = + parameters.get("primitivesAsString").map(_.toBoolean).getOrElse(false), + allowComments = + parameters.get("allowComments").map(_.toBoolean).getOrElse(false), + allowUnquotedFieldNames = + parameters.get("allowUnquotedFieldNames").map(_.toBoolean).getOrElse(false), + allowSingleQuotes = + parameters.get("allowSingleQuotes").map(_.toBoolean).getOrElse(true), + allowNumericLeadingZeros = + parameters.get("allowNumericLeadingZeros").map(_.toBoolean).getOrElse(false), + allowNonNumericNumbers = + parameters.get("allowNonNumericNumbers").map(_.toBoolean).getOrElse(true) + ) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala index dca638b7f67a5..3e61ba35bea8e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala @@ -52,13 +52,9 @@ class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { dataSchema: Option[StructType], partitionColumns: Option[StructType], parameters: Map[String, String]): HadoopFsRelation = { - val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) - val primitivesAsString = parameters.get("primitivesAsString").map(_.toBoolean).getOrElse(false) new JSONRelation( inputRDD = None, - samplingRatio = samplingRatio, - primitivesAsString = primitivesAsString, maybeDataSchema = dataSchema, maybePartitionSpec = None, userDefinedPartitionColumns = partitionColumns, @@ -69,8 +65,6 @@ class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { private[sql] class JSONRelation( val inputRDD: Option[RDD[String]], - val samplingRatio: Double, - val primitivesAsString: Boolean, val maybeDataSchema: Option[StructType], val maybePartitionSpec: Option[PartitionSpec], override val userDefinedPartitionColumns: Option[StructType], @@ -79,6 +73,8 @@ private[sql] class JSONRelation( (@transient val sqlContext: SQLContext) extends HadoopFsRelation(maybePartitionSpec, parameters) { + val options: JSONOptions = JSONOptions.createFromConfigMap(parameters) + /** Constraints to be imposed on schema to be stored. 
*/ private def checkConstraints(schema: StructType): Unit = { if (schema.fieldNames.length != schema.fieldNames.distinct.length) { @@ -109,17 +105,16 @@ private[sql] class JSONRelation( classOf[Text]).map(_._2.toString) // get the text line } - override lazy val dataSchema = { + override lazy val dataSchema: StructType = { val jsonSchema = maybeDataSchema.getOrElse { val files = cachedLeafStatuses().filterNot { status => val name = status.getPath.getName name.startsWith("_") || name.startsWith(".") }.toArray - InferSchema( + InferSchema.infer( inputRDD.getOrElse(createBaseRdd(files)), - samplingRatio, sqlContext.conf.columnNameOfCorruptRecord, - primitivesAsString) + options) } checkConstraints(jsonSchema) @@ -132,10 +127,11 @@ private[sql] class JSONRelation( inputPaths: Array[FileStatus], broadcastedConf: Broadcast[SerializableConfiguration]): RDD[InternalRow] = { val requiredDataSchema = StructType(requiredColumns.map(dataSchema(_))) - val rows = JacksonParser( + val rows = JacksonParser.parse( inputRDD.getOrElse(createBaseRdd(inputPaths)), requiredDataSchema, - sqlContext.conf.columnNameOfCorruptRecord) + sqlContext.conf.columnNameOfCorruptRecord, + options) rows.mapPartitions { iterator => val unsafeProjection = UnsafeProjection.create(requiredDataSchema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala index 4f53eeb081b93..bfa1405041058 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala @@ -18,11 +18,10 @@ package org.apache.spark.sql.execution.datasources.json import java.io.ByteArrayOutputStream +import scala.collection.mutable.ArrayBuffer import com.fasterxml.jackson.core._ -import scala.collection.mutable.ArrayBuffer - import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -32,18 +31,23 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils -private[sql] object JacksonParser { - def apply( - json: RDD[String], +object JacksonParser { + + def parse( + input: RDD[String], schema: StructType, - columnNameOfCorruptRecords: String): RDD[InternalRow] = { - parseJson(json, schema, columnNameOfCorruptRecords) + columnNameOfCorruptRecords: String, + configOptions: JSONOptions): RDD[InternalRow] = { + + input.mapPartitions { iter => + parseJson(iter, schema, columnNameOfCorruptRecords, configOptions) + } } /** * Parse the current token (and related children) according to a desired schema */ - private[sql] def convertField( + def convertField( factory: JsonFactory, parser: JsonParser, schema: DataType): Any = { @@ -226,9 +230,10 @@ private[sql] object JacksonParser { } private def parseJson( - json: RDD[String], + input: Iterator[String], schema: StructType, - columnNameOfCorruptRecords: String): RDD[InternalRow] = { + columnNameOfCorruptRecords: String, + configOptions: JSONOptions): Iterator[InternalRow] = { def failedRecord(record: String): Seq[InternalRow] = { // create a row even if no corrupt record column is present @@ -241,37 +246,36 @@ private[sql] object JacksonParser { Seq(row) } - json.mapPartitions { iter => - val factory = new JsonFactory() - - iter.flatMap { record => - if (record.trim.isEmpty) { - Nil - } else { - try { - 
Utils.tryWithResource(factory.createParser(record)) { parser => - parser.nextToken() - - convertField(factory, parser, schema) match { - case null => failedRecord(record) - case row: InternalRow => row :: Nil - case array: ArrayData => - if (array.numElements() == 0) { - Nil - } else { - array.toArray[InternalRow](schema) - } - case _ => - sys.error( - s"Failed to parse record $record. Please make sure that each line of " + - "the file (or each string in the RDD) is a valid JSON object or " + - "an array of JSON objects.") - } + val factory = new JsonFactory() + configOptions.setJacksonOptions(factory) + + input.flatMap { record => + if (record.trim.isEmpty) { + Nil + } else { + try { + Utils.tryWithResource(factory.createParser(record)) { parser => + parser.nextToken() + + convertField(factory, parser, schema) match { + case null => failedRecord(record) + case row: InternalRow => row :: Nil + case array: ArrayData => + if (array.numElements() == 0) { + Nil + } else { + array.toArray[InternalRow](schema) + } + case _ => + sys.error( + s"Failed to parse record $record. Please make sure that each line of " + + "the file (or each string in the RDD) is a valid JSON object or " + + "an array of JSON objects.") } - } catch { - case _: JsonProcessingException => - failedRecord(record) } + } catch { + case _: JsonProcessingException => + failedRecord(record) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala new file mode 100644 index 0000000000000..4cc0a3a9585d9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.json + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.test.SharedSQLContext + +/** + * Test cases for various [[JSONOptions]]. 
+ */ +class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext { + + test("allowComments off") { + val str = """{'name': /* hello */ 'Reynold Xin'}""" + val rdd = sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.json(rdd) + + assert(df.schema.head.name == "_corrupt_record") + } + + test("allowComments on") { + val str = """{'name': /* hello */ 'Reynold Xin'}""" + val rdd = sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.option("allowComments", "true").json(rdd) + + assert(df.schema.head.name == "name") + assert(df.first().getString(0) == "Reynold Xin") + } + + test("allowSingleQuotes off") { + val str = """{'name': 'Reynold Xin'}""" + val rdd = sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.option("allowSingleQuotes", "false").json(rdd) + + assert(df.schema.head.name == "_corrupt_record") + } + + test("allowSingleQuotes on") { + val str = """{'name': 'Reynold Xin'}""" + val rdd = sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.json(rdd) + + assert(df.schema.head.name == "name") + assert(df.first().getString(0) == "Reynold Xin") + } + + test("allowUnquotedFieldNames off") { + val str = """{name: 'Reynold Xin'}""" + val rdd = sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.json(rdd) + + assert(df.schema.head.name == "_corrupt_record") + } + + test("allowUnquotedFieldNames on") { + val str = """{name: 'Reynold Xin'}""" + val rdd = sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.option("allowUnquotedFieldNames", "true").json(rdd) + + assert(df.schema.head.name == "name") + assert(df.first().getString(0) == "Reynold Xin") + } + + test("allowNumericLeadingZeros off") { + val str = """{"age": 0018}""" + val rdd = sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.json(rdd) + + assert(df.schema.head.name == "_corrupt_record") + } + + test("allowNumericLeadingZeros on") { + val str = """{"age": 0018}""" + val rdd = sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.option("allowNumericLeadingZeros", "true").json(rdd) + + assert(df.schema.head.name == "age") + assert(df.first().getLong(0) == 18) + } + + // The following two tests are not really working - need to look into Jackson's + // JsonParser.Feature.ALLOW_NON_NUMERIC_NUMBERS. 
+ ignore("allowNonNumericNumbers off") { + val str = """{"age": NaN}""" + val rdd = sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.json(rdd) + + assert(df.schema.head.name == "_corrupt_record") + } + + ignore("allowNonNumericNumbers on") { + val str = """{"age": NaN}""" + val rdd = sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.option("allowNonNumericNumbers", "true").json(rdd) + + assert(df.schema.head.name == "age") + assert(df.first().getDouble(0).isNaN) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 28b8f02bdf87f..6042b1178affe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -588,7 +588,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { relation.isInstanceOf[JSONRelation], "The DataFrame returned by jsonFile should be based on JSONRelation.") assert(relation.asInstanceOf[JSONRelation].paths === Array(path)) - assert(relation.asInstanceOf[JSONRelation].samplingRatio === (0.49 +- 0.001)) + assert(relation.asInstanceOf[JSONRelation].options.samplingRatio === (0.49 +- 0.001)) val schema = StructType(StructField("a", LongType, true) :: Nil) val logicalRelation = @@ -597,7 +597,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val relationWithSchema = logicalRelation.relation.asInstanceOf[JSONRelation] assert(relationWithSchema.paths === Array(path)) assert(relationWithSchema.schema === schema) - assert(relationWithSchema.samplingRatio > 0.99) + assert(relationWithSchema.options.samplingRatio > 0.99) } test("Loading a JSON dataset from a text file") { @@ -1165,31 +1165,28 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("JSONRelation equality test") { val relation0 = new JSONRelation( Some(empty), - 1.0, - false, Some(StructType(StructField("a", IntegerType, true) :: Nil)), - None, None)(sqlContext) + None, + None)(sqlContext) val logicalRelation0 = LogicalRelation(relation0) val relation1 = new JSONRelation( Some(singleRow), - 1.0, - false, Some(StructType(StructField("a", IntegerType, true) :: Nil)), - None, None)(sqlContext) + None, + None)(sqlContext) val logicalRelation1 = LogicalRelation(relation1) val relation2 = new JSONRelation( Some(singleRow), - 0.5, - false, Some(StructType(StructField("a", IntegerType, true) :: Nil)), - None, None)(sqlContext) + None, + None, + parameters = Map("samplingRatio" -> "0.5"))(sqlContext) val logicalRelation2 = LogicalRelation(relation2) val relation3 = new JSONRelation( Some(singleRow), - 1.0, - false, Some(StructType(StructField("b", IntegerType, true) :: Nil)), - None, None)(sqlContext) + None, + None)(sqlContext) val logicalRelation3 = LogicalRelation(relation3) assert(relation0 !== relation1) @@ -1232,7 +1229,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("SPARK-6245 JsonRDD.inferSchema on empty RDD") { // This is really a test that it doesn't throw an exception - val emptySchema = InferSchema(empty, 1.0, "") + val emptySchema = InferSchema.infer(empty, "", JSONOptions()) assert(StructType(Seq()) === emptySchema) } @@ -1256,7 +1253,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("SPARK-8093 Erase empty structs") { - val 
emptySchema = InferSchema(emptyRecords, 1.0, "") + val emptySchema = InferSchema.infer(emptyRecords, "", JSONOptions()) assert(StructType(Seq()) === emptySchema) } From 7f8eb3bf6ed64eefc5472f5c5fb02e2db1e3f618 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 16 Nov 2015 21:30:10 +0800 Subject: [PATCH 029/149] [SPARK-11044][SQL] Parquet writer version fixed as version1 https://issues.apache.org/jira/browse/SPARK-11044 Spark writes a parquet file only with writer version1 ignoring the writer version given by user. So, in this PR, it keeps the writer version if given or sets version1 as default. Author: hyukjinkwon Author: HyukjinKwon Closes #9060 from HyukjinKwon/SPARK-11044. --- .../parquet/CatalystWriteSupport.scala | 2 +- .../datasources/parquet/ParquetIOSuite.scala | 34 +++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala index 483363d2c1a21..6862dea5e6c3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala @@ -429,7 +429,7 @@ private[parquet] object CatalystWriteSupport { def setSchema(schema: StructType, configuration: Configuration): Unit = { schema.map(_.name).foreach(CatalystSchemaConverter.checkFieldName) configuration.set(SPARK_ROW_SCHEMA, schema.json) - configuration.set( + configuration.setIfUnset( ParquetOutputFormat.WRITER_VERSION, ParquetProperties.WriterVersion.PARQUET_1_0.toString) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 78df363ade5c9..2aa5dca847c8f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.datasources.parquet import java.util.Collections +import org.apache.parquet.column.{Encoding, ParquetProperties} + import scala.collection.JavaConverters._ import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag @@ -534,6 +536,38 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } } + test("SPARK-11044 Parquet writer version fixed as version1 ") { + // For dictionary encoding, Parquet changes the encoding types according to its writer + // version. So, this test checks one of the encoding types in order to ensure that + // the file is written with writer version2. + withTempPath { dir => + val clonedConf = new Configuration(hadoopConfiguration) + try { + // Write a Parquet file with writer version2. + hadoopConfiguration.set(ParquetOutputFormat.WRITER_VERSION, + ParquetProperties.WriterVersion.PARQUET_2_0.toString) + + // By default, dictionary encoding is enabled from Parquet 1.2.0 but + // it is enabled just in case. 
+ hadoopConfiguration.setBoolean(ParquetOutputFormat.ENABLE_DICTIONARY, true) + val path = s"${dir.getCanonicalPath}/part-r-0.parquet" + sqlContext.range(1 << 16).selectExpr("(id % 4) AS i") + .coalesce(1).write.mode("overwrite").parquet(path) + + val blockMetadata = readFooter(new Path(path), hadoopConfiguration).getBlocks.asScala.head + val columnChunkMetadata = blockMetadata.getColumns.asScala.head + + // If the file is written with version2, this should include + // Encoding.RLE_DICTIONARY type. For version1, it is Encoding.PLAIN_DICTIONARY + assert(columnChunkMetadata.getEncodings.contains(Encoding.RLE_DICTIONARY)) + } finally { + // Manually clear the hadoop configuration for other tests. + hadoopConfiguration.clear() + clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) + } + } + } + test("read dictionary encoded decimals written as INT32") { checkAnswer( // Decimal column in this file is encoded using plain dictionary From e388b39d10fc269cdd3d630ea7d4ae80fd0efa97 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 16 Nov 2015 21:59:33 +0800 Subject: [PATCH 030/149] [SPARK-11692][SQL] Support for Parquet logical types, JSON and BSON (embedded types) Parquet supports some JSON and BSON datatypes. They are represented as binary for BSON and string (UTF-8) for JSON internally. I searched a bit and found Apache drill also supports both in this way, [link](https://drill.apache.org/docs/parquet-format/). Author: hyukjinkwon Author: Hyukjin Kwon Closes #9658 from HyukjinKwon/SPARK-11692. --- .../parquet/CatalystSchemaConverter.scala | 3 ++- .../datasources/parquet/ParquetIOSuite.scala | 25 +++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala index f28a18e2756e4..5f9f9083098a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala @@ -170,9 +170,10 @@ private[parquet] class CatalystSchemaConverter( case BINARY => originalType match { - case UTF8 | ENUM => StringType + case UTF8 | ENUM | JSON => StringType case null if assumeBinaryIsString => StringType case null => BinaryType + case BSON => BinaryType case DECIMAL => makeDecimalType() case _ => illegalType() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 2aa5dca847c8f..a148facd056a5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -259,6 +259,31 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } } + test("SPARK-11692 Support for Parquet logical types, JSON and BSON (embedded types)") { + val parquetSchema = MessageTypeParser.parseMessageType( + """message root { + | required binary a(JSON); + | required binary b(BSON); + |} + """.stripMargin) + + withTempPath { location => + val extraMetadata = Map.empty[String, String].asJava + val fileMetadata = new FileMetaData(parquetSchema, extraMetadata, "Spark") + val path = new 
Path(location.getCanonicalPath) + val footer = List( + new Footer(path, new ParquetMetadata(fileMetadata, Collections.emptyList())) + ).asJava + + ParquetFileWriter.writeMetadataFile(sparkContext.hadoopConfiguration, path, footer) + + val jsonDataType = sqlContext.read.parquet(path.toString).schema(0).dataType + assert(jsonDataType === StringType) + val bsonDataType = sqlContext.read.parquet(path.toString).schema(1).dataType + assert(bsonDataType === BinaryType) + } + } + test("compression codec") { def compressionCodecFor(path: String, codecName: String): String = { val codecs = for { From 0e79604aed116bdcb40e03553a2d103b5b1cdbae Mon Sep 17 00:00:00 2001 From: xin Wu Date: Mon, 16 Nov 2015 08:10:48 -0800 Subject: [PATCH 031/149] [SPARK-11522][SQL] input_file_name() returns "" for external tables When computing partition for non-parquet relation, `HadoopRDD.compute` is used. but it does not set the thread local variable `inputFileName` in `NewSqlHadoopRDD`, like `NewSqlHadoopRDD.compute` does.. Yet, when getting the `inputFileName`, `NewSqlHadoopRDD.inputFileName` is exptected, which is empty now. Adding the setting inputFileName in HadoopRDD.compute resolves this issue. Author: xin Wu Closes #9542 from xwu0226/SPARK-11522. --- .../org/apache/spark/rdd/HadoopRDD.scala | 7 ++ .../sql/hive/execution/HiveUDFSuite.scala | 93 ++++++++++++++++++- 2 files changed, 98 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 0453614f6a1d3..7db583468792e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -213,6 +213,12 @@ class HadoopRDD[K, V]( val inputMetrics = context.taskMetrics.getInputMetricsForReadMethod(DataReadMethod.Hadoop) + // Sets the thread local variable for the file's name + split.inputSplit.value match { + case fs: FileSplit => SqlNewHadoopRDD.setInputFileName(fs.getPath.toString) + case _ => SqlNewHadoopRDD.unsetInputFileName() + } + // Find a function that will return the FileSystem bytes read by this thread. Do this before // creating RecordReader, because RecordReader's constructor might read some bytes val bytesReadCallback = inputMetrics.bytesReadCallback.orElse { @@ -250,6 +256,7 @@ class HadoopRDD[K, V]( override def close() { if (reader != null) { + SqlNewHadoopRDD.unsetInputFileName() // Close the reader and release it. Note: it's very important that we don't close the // reader more than once, since that exposes us to MAPREDUCE-5918 when running against // Hadoop 1.x and older Hadoop 2.x releases. 
That bug can lead to non-deterministic diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index 5ab477efc4ee0..9deb1a6db15ad 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive.execution -import java.io.{DataInput, DataOutput} +import java.io.{PrintWriter, File, DataInput, DataOutput} import java.util.{ArrayList, Arrays, Properties} import org.apache.hadoop.conf.Configuration @@ -28,6 +28,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory} import org.apache.hadoop.hive.serde2.{AbstractSerDe, SerDeStats} import org.apache.hadoop.io.Writable +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.util.Utils @@ -44,7 +45,7 @@ case class ListStringCaseClass(l: Seq[String]) /** * A test suite for Hive custom UDFs. */ -class HiveUDFSuite extends QueryTest with TestHiveSingleton { +class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { import hiveContext.{udf, sql} import hiveContext.implicits._ @@ -348,6 +349,94 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton { sqlContext.dropTempTable("testUDF") } + + test("SPARK-11522 select input_file_name from non-parquet table"){ + + withTempDir { tempDir => + + // EXTERNAL OpenCSVSerde table pointing to LOCATION + + val file1 = new File(tempDir + "/data1") + val writer1 = new PrintWriter(file1) + writer1.write("1,2") + writer1.close() + + val file2 = new File(tempDir + "/data2") + val writer2 = new PrintWriter(file2) + writer2.write("1,2") + writer2.close() + + sql( + s"""CREATE EXTERNAL TABLE csv_table(page_id INT, impressions INT) + ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.OpenCSVSerde' + WITH SERDEPROPERTIES ( + \"separatorChar\" = \",\", + \"quoteChar\" = \"\\\"\", + \"escapeChar\" = \"\\\\\") + LOCATION '$tempDir' + """) + + val answer1 = + sql("SELECT input_file_name() FROM csv_table").head().getString(0) + assert(answer1.contains("data1") || answer1.contains("data2")) + + val count1 = sql("SELECT input_file_name() FROM csv_table").distinct().count() + assert(count1 == 2) + sql("DROP TABLE csv_table") + + // EXTERNAL pointing to LOCATION + + sql( + s"""CREATE EXTERNAL TABLE external_t5 (c1 int, c2 int) + ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' + LOCATION '$tempDir' + """) + + val answer2 = + sql("SELECT input_file_name() as file FROM external_t5").head().getString(0) + assert(answer1.contains("data1") || answer1.contains("data2")) + + val count2 = sql("SELECT input_file_name() as file FROM external_t5").distinct().count + assert(count2 == 2) + sql("DROP TABLE external_t5") + } + + withTempDir { tempDir => + + // External parquet pointing to LOCATION + + val parquetLocation = tempDir + "/external_parquet" + sql("SELECT 1, 2").write.parquet(parquetLocation) + + sql( + s"""CREATE EXTERNAL TABLE external_parquet(c1 int, c2 int) + STORED AS PARQUET + LOCATION '$parquetLocation' + """) + + val answer3 = + sql("SELECT input_file_name() as file FROM external_parquet").head().getString(0) + assert(answer3.contains("external_parquet")) + + val count3 = 
sql("SELECT input_file_name() as file FROM external_parquet").distinct().count + assert(count3 == 1) + sql("DROP TABLE external_parquet") + } + + // Non-External parquet pointing to /tmp/... + + sql("CREATE TABLE parquet_tmp(c1 int, c2 int) " + + " STORED AS parquet " + + " AS SELECT 1, 2") + + val answer4 = + sql("SELECT input_file_name() as file FROM parquet_tmp").head().getString(0) + assert(answer4.contains("parquet_tmp")) + + val count4 = sql("SELECT input_file_name() as file FROM parquet_tmp").distinct().count + assert(count4 == 1) + sql("DROP TABLE parquet_tmp") + } } class TestPair(x: Int, y: Int) extends Writable with Serializable { From 06f1fdba6d1425afddfc1d45a20dbe9bede15e7a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 16 Nov 2015 08:58:40 -0800 Subject: [PATCH 032/149] [SPARK-11752] [SQL] fix timezone problem for DateTimeUtils.getSeconds code snippet to reproduce it: ``` TimeZone.setDefault(TimeZone.getTimeZone("Asia/Shanghai")) val t = Timestamp.valueOf("1900-06-11 12:14:50.789") val us = fromJavaTimestamp(t) assert(getSeconds(us) === t.getSeconds) ``` it will be good to add a regression test for it, but the reproducing code need to change the default timezone, and even we change it back, the `lazy val defaultTimeZone` in `DataTimeUtils` is fixed. Author: Wenchen Fan Closes #9728 from cloud-fan/seconds. --- .../spark/sql/catalyst/util/DateTimeUtils.scala | 14 ++++++++------ .../sql/catalyst/util/DateTimeUtilsSuite.scala | 2 +- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index deff8a5378b92..8fb3f41f1bd6a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -395,16 +395,19 @@ object DateTimeUtils { /** * Returns the microseconds since year zero (-17999) from microseconds since epoch. */ - def absoluteMicroSecond(microsec: SQLTimestamp): SQLTimestamp = { + private def absoluteMicroSecond(microsec: SQLTimestamp): SQLTimestamp = { microsec + toYearZero * MICROS_PER_DAY } + private def localTimestamp(microsec: SQLTimestamp): SQLTimestamp = { + absoluteMicroSecond(microsec) + defaultTimeZone.getOffset(microsec / 1000) * 1000L + } + /** * Returns the hour value of a given timestamp value. The timestamp is expressed in microseconds. */ def getHours(microsec: SQLTimestamp): Int = { - val localTs = absoluteMicroSecond(microsec) + defaultTimeZone.getOffset(microsec / 1000) * 1000L - ((localTs / MICROS_PER_SECOND / 3600) % 24).toInt + ((localTimestamp(microsec) / MICROS_PER_SECOND / 3600) % 24).toInt } /** @@ -412,8 +415,7 @@ object DateTimeUtils { * microseconds. */ def getMinutes(microsec: SQLTimestamp): Int = { - val localTs = absoluteMicroSecond(microsec) + defaultTimeZone.getOffset(microsec / 1000) * 1000L - ((localTs / MICROS_PER_SECOND / 60) % 60).toInt + ((localTimestamp(microsec) / MICROS_PER_SECOND / 60) % 60).toInt } /** @@ -421,7 +423,7 @@ object DateTimeUtils { * microseconds. 
*/ def getSeconds(microsec: SQLTimestamp): Int = { - ((absoluteMicroSecond(microsec) / MICROS_PER_SECOND) % 60).toInt + ((localTimestamp(microsec) / MICROS_PER_SECOND) % 60).toInt } private[this] def isLeapYear(year: Int): Boolean = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index 64d15e6b910c1..60d45422bc9b8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -358,7 +358,7 @@ class DateTimeUtilsSuite extends SparkFunSuite { assert(getSeconds(c.getTimeInMillis * 1000) === 9) } - test("hours / miniute / seconds") { + test("hours / minutes / seconds") { Seq(Timestamp.valueOf("2015-06-11 10:12:35.789"), Timestamp.valueOf("2015-06-11 20:13:40.789"), Timestamp.valueOf("1900-06-11 12:14:50.789"), From b0c3fd34e4cfa3f0472d83e71ffe774430cfdc87 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 16 Nov 2015 09:03:42 -0800 Subject: [PATCH 033/149] [SPARK-11743] [SQL] Add UserDefinedType support to RowEncoder JIRA: https://issues.apache.org/jira/browse/SPARK-11743 RowEncoder doesn't support UserDefinedType now. We should add the support for it. Author: Liang-Chi Hsieh Closes #9712 from viirya/rowencoder-udt. --- .../main/scala/org/apache/spark/sql/Row.scala | 14 +++- .../sql/catalyst/encoders/RowEncoder.scala | 24 +++++- .../sql/catalyst/expressions/objects.scala | 48 +++++------ .../catalyst/encoders/RowEncoderSuite.scala | 82 ++++++++++++++++++- 4 files changed, 139 insertions(+), 29 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index ed2fdf9f2f7cf..0f0f200122c34 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -152,7 +152,7 @@ trait Row extends Serializable { * BinaryType -> byte array * ArrayType -> scala.collection.Seq (use getList for java.util.List) * MapType -> scala.collection.Map (use getJavaMap for java.util.Map) - * StructType -> org.apache.spark.sql.Row + * StructType -> org.apache.spark.sql.Row (or Product) * }}} */ def apply(i: Int): Any = get(i) @@ -177,7 +177,7 @@ trait Row extends Serializable { * BinaryType -> byte array * ArrayType -> scala.collection.Seq (use getList for java.util.List) * MapType -> scala.collection.Map (use getJavaMap for java.util.Map) - * StructType -> org.apache.spark.sql.Row + * StructType -> org.apache.spark.sql.Row (or Product) * }}} */ def get(i: Int): Any @@ -306,7 +306,15 @@ trait Row extends Serializable { * * @throws ClassCastException when data type does not match. */ - def getStruct(i: Int): Row = getAs[Row](i) + def getStruct(i: Int): Row = { + // Product and Row both are recoginized as StructType in a Row + val t = get(i) + if (t.isInstanceOf[Product]) { + Row.fromTuple(t.asInstanceOf[Product]) + } else { + t.asInstanceOf[Row] + } + } /** * Returns the value at position i. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index e0be896bb3548..9bb1602494b68 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -50,6 +50,14 @@ object RowEncoder { case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | BinaryType => inputObject + case udt: UserDefinedType[_] => + val obj = NewInstance( + udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), + Nil, + false, + dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) + Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil) + case TimestampType => StaticInvoke( DateTimeUtils, @@ -109,11 +117,16 @@ object RowEncoder { case StructType(fields) => val convertedFields = fields.zipWithIndex.map { case (f, i) => + val method = if (f.dataType.isInstanceOf[StructType]) { + "getStruct" + } else { + "get" + } If( Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil), Literal.create(null, f.dataType), extractorsFor( - Invoke(inputObject, "get", externalDataTypeFor(f.dataType), Literal(i) :: Nil), + Invoke(inputObject, method, externalDataTypeFor(f.dataType), Literal(i) :: Nil), f.dataType)) } CreateStruct(convertedFields) @@ -137,6 +150,7 @@ object RowEncoder { case _: ArrayType => ObjectType(classOf[scala.collection.Seq[_]]) case _: MapType => ObjectType(classOf[scala.collection.Map[_, _]]) case _: StructType => ObjectType(classOf[Row]) + case udt: UserDefinedType[_] => ObjectType(udt.userClass) } private def constructorFor(schema: StructType): Expression = { @@ -155,6 +169,14 @@ object RowEncoder { case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | BinaryType => input + case udt: UserDefinedType[_] => + val obj = NewInstance( + udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), + Nil, + false, + dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) + Invoke(obj, "deserialize", ObjectType(udt.userClass), input :: Nil) + case TimestampType => StaticInvoke( DateTimeUtils, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 4f58464221b4b..5cd19de68391c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -113,7 +113,7 @@ case class Invoke( arguments: Seq[Expression] = Nil) extends Expression { override def nullable: Boolean = true - override def children: Seq[Expression] = targetObject :: Nil + override def children: Seq[Expression] = arguments.+:(targetObject) override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported.") @@ -343,33 +343,35 @@ case class MapObjects( private lazy val loopAttribute = AttributeReference("loopVar", elementType)() private lazy val completeFunction = function(loopAttribute) + private def itemAccessorMethod(dataType: DataType): String => String = dataType match { + case IntegerType => (i: String) => s".getInt($i)" + case LongType => (i: String) => s".getLong($i)" + case FloatType => (i: String) => s".getFloat($i)" + case 
DoubleType => (i: String) => s".getDouble($i)" + case ByteType => (i: String) => s".getByte($i)" + case ShortType => (i: String) => s".getShort($i)" + case BooleanType => (i: String) => s".getBoolean($i)" + case StringType => (i: String) => s".getUTF8String($i)" + case s: StructType => (i: String) => s".getStruct($i, ${s.size})" + case a: ArrayType => (i: String) => s".getArray($i)" + case _: MapType => (i: String) => s".getMap($i)" + case udt: UserDefinedType[_] => itemAccessorMethod(udt.sqlType) + } + private lazy val (lengthFunction, itemAccessor, primitiveElement) = inputData.dataType match { case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) => (".size()", (i: String) => s".apply($i)", false) case ObjectType(cls) if cls.isArray => (".length", (i: String) => s"[$i]", false) - case ArrayType(s: StructType, _) => - (".numElements()", (i: String) => s".getStruct($i, ${s.size})", false) - case ArrayType(a: ArrayType, _) => - (".numElements()", (i: String) => s".getArray($i)", true) - case ArrayType(IntegerType, _) => - (".numElements()", (i: String) => s".getInt($i)", true) - case ArrayType(LongType, _) => - (".numElements()", (i: String) => s".getLong($i)", true) - case ArrayType(FloatType, _) => - (".numElements()", (i: String) => s".getFloat($i)", true) - case ArrayType(DoubleType, _) => - (".numElements()", (i: String) => s".getDouble($i)", true) - case ArrayType(ByteType, _) => - (".numElements()", (i: String) => s".getByte($i)", true) - case ArrayType(ShortType, _) => - (".numElements()", (i: String) => s".getShort($i)", true) - case ArrayType(BooleanType, _) => - (".numElements()", (i: String) => s".getBoolean($i)", true) - case ArrayType(StringType, _) => - (".numElements()", (i: String) => s".getUTF8String($i)", false) - case ArrayType(_: MapType, _) => - (".numElements()", (i: String) => s".getMap($i)", false) + case ArrayType(t, _) => + val (sqlType, primitiveElement) = t match { + case m: MapType => (m, false) + case s: StructType => (s, false) + case s: StringType => (s, false) + case udt: UserDefinedType[_] => (udt.sqlType, false) + case o => (o, true) + } + (".numElements()", itemAccessorMethod(sqlType), primitiveElement) } override def nullable: Boolean = true diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index e8301e8e06b52..c868ddec1bab2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -19,14 +19,62 @@ package org.apache.spark.sql.catalyst.encoders import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{RandomDataGenerator, Row} +import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String +@SQLUserDefinedType(udt = classOf[ExamplePointUDT]) +class ExamplePoint(val x: Double, val y: Double) extends Serializable { + override def hashCode: Int = 41 * (41 + x.toInt) + y.toInt + override def equals(that: Any): Boolean = { + if (that.isInstanceOf[ExamplePoint]) { + val e = that.asInstanceOf[ExamplePoint] + (this.x == e.x || (this.x.isNaN && e.x.isNaN) || (this.x.isInfinity && e.x.isInfinity)) && + (this.y == e.y || (this.y.isNaN && e.y.isNaN) || (this.y.isInfinity && e.y.isInfinity)) + } else { + false + } + } +} + +/** + * User-defined type for [[ExamplePoint]]. 
+ */ +class ExamplePointUDT extends UserDefinedType[ExamplePoint] { + + override def sqlType: DataType = ArrayType(DoubleType, false) + + override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT" + + override def serialize(obj: Any): GenericArrayData = { + obj match { + case p: ExamplePoint => + val output = new Array[Any](2) + output(0) = p.x + output(1) = p.y + new GenericArrayData(output) + } + } + + override def deserialize(datum: Any): ExamplePoint = { + datum match { + case values: ArrayData => + new ExamplePoint(values.getDouble(0), values.getDouble(1)) + } + } + + override def userClass: Class[ExamplePoint] = classOf[ExamplePoint] + + private[spark] override def asNullable: ExamplePointUDT = this +} + class RowEncoderSuite extends SparkFunSuite { private val structOfString = new StructType().add("str", StringType) + private val structOfUDT = new StructType().add("udt", new ExamplePointUDT, false) private val arrayOfString = ArrayType(StringType) private val mapOfString = MapType(StringType, StringType) + private val arrayOfUDT = ArrayType(new ExamplePointUDT, false) encodeDecodeTest( new StructType() @@ -41,7 +89,8 @@ class RowEncoderSuite extends SparkFunSuite { .add("string", StringType) .add("binary", BinaryType) .add("date", DateType) - .add("timestamp", TimestampType)) + .add("timestamp", TimestampType) + .add("udt", new ExamplePointUDT, false)) encodeDecodeTest( new StructType() @@ -68,7 +117,36 @@ class RowEncoderSuite extends SparkFunSuite { .add("structOfArray", new StructType().add("array", arrayOfString)) .add("structOfMap", new StructType().add("map", mapOfString)) .add("structOfArrayAndMap", - new StructType().add("array", arrayOfString).add("map", mapOfString))) + new StructType().add("array", arrayOfString).add("map", mapOfString)) + .add("structOfUDT", structOfUDT)) + + test(s"encode/decode: arrayOfUDT") { + val schema = new StructType() + .add("arrayOfUDT", arrayOfUDT) + + val encoder = RowEncoder(schema) + + val input: Row = Row(Seq(new ExamplePoint(0.1, 0.2), new ExamplePoint(0.3, 0.4))) + val row = encoder.toRow(input) + val convertedBack = encoder.fromRow(row) + assert(input.getSeq[ExamplePoint](0) == convertedBack.getSeq[ExamplePoint](0)) + } + + test(s"encode/decode: Product") { + val schema = new StructType() + .add("structAsProduct", + new StructType() + .add("int", IntegerType) + .add("string", StringType) + .add("double", DoubleType)) + + val encoder = RowEncoder(schema) + + val input: Row = Row((100, "test", 0.123)) + val row = encoder.toRow(input) + val convertedBack = encoder.fromRow(row) + assert(input.getStruct(0) == convertedBack.getStruct(0)) + } private def encodeDecodeTest(schema: StructType): Unit = { test(s"encode/decode: ${schema.simpleString}") { From de5e531d337075fd849437e88846873bca8685e6 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 16 Nov 2015 11:21:17 -0800 Subject: [PATCH 034/149] [SPARK-11731][STREAMING] Enable batching on Driver WriteAheadLog by default Using batching on the driver for the WriteAheadLog should be an improvement for all environments and use cases. Users will be able to scale to much higher number of receivers with the BatchedWriteAheadLog. Therefore we should turn it on by default, and QA it in the QA period. I've also added some tests to make sure the default configurations are correct regarding recent additions: - batching on by default - closeFileAfterWrite off by default - parallelRecovery off by default Author: Burak Yavuz Closes #9695 from brkyvz/enable-batch-wal. 
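
Illustrative only (not part of the patch): a minimal sketch of how a user interacts with the new default. Only configuration keys that appear in the tests below are used; the custom WAL class name is a placeholder.

```
import org.apache.spark.SparkConf

// With no explicit setting, driver-side WAL batching is now enabled by default.
val conf = new SparkConf()

// Opting out restores the previous, unbatched FileBasedWriteAheadLog behaviour
// (key as used in JavaWriteAheadLogSuite below).
conf.set("spark.streaming.driver.writeAheadLog.allowBatching", "false")

// A custom WAL implementation can still be supplied; it is wrapped by the
// batched log only while batching is left enabled.
conf.set("spark.streaming.driver.writeAheadLog.class",
  "org.example.MyWriteAheadLog")  // placeholder class name
```
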
--- .../streaming/util/WriteAheadLogUtils.scala | 2 +- .../streaming/JavaWriteAheadLogSuite.java | 1 + .../streaming/ReceivedBlockTrackerSuite.scala | 9 +++++-- .../streaming/util/WriteAheadLogSuite.scala | 24 ++++++++++++++++++- .../util/WriteAheadLogUtilsSuite.scala | 19 ++++++++++++--- 5 files changed, 48 insertions(+), 7 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala index 731a369fc92c0..7f9e2c9734970 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala @@ -67,7 +67,7 @@ private[streaming] object WriteAheadLogUtils extends Logging { } def isBatchingEnabled(conf: SparkConf, isDriver: Boolean): Boolean = { - isDriver && conf.getBoolean(DRIVER_WAL_BATCHING_CONF_KEY, defaultValue = false) + isDriver && conf.getBoolean(DRIVER_WAL_BATCHING_CONF_KEY, defaultValue = true) } /** diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java index 175b8a496b4e5..09b5f8ed03279 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java @@ -108,6 +108,7 @@ public void close() { public void testCustomWAL() { SparkConf conf = new SparkConf(); conf.set("spark.streaming.driver.writeAheadLog.class", JavaWriteAheadLogSuite.class.getName()); + conf.set("spark.streaming.driver.writeAheadLog.allowBatching", "false"); WriteAheadLog wal = WriteAheadLogUtils.createLogForDriver(conf, null, null); String data1 = "data1"; diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala index 7db17abb7947c..081f5a1c93e6e 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -330,8 +330,13 @@ class ReceivedBlockTrackerSuite : Seq[ReceivedBlockTrackerLogEvent] = { logFiles.flatMap { file => new FileBasedWriteAheadLogReader(file, hadoopConf).toSeq - }.map { byteBuffer => - Utils.deserialize[ReceivedBlockTrackerLogEvent](byteBuffer.array) + }.flatMap { byteBuffer => + val validBuffer = if (WriteAheadLogUtils.isBatchingEnabled(conf, isDriver = true)) { + Utils.deserialize[Array[Array[Byte]]](byteBuffer.array()).map(ByteBuffer.wrap) + } else { + Array(byteBuffer) + } + validBuffer.map(b => Utils.deserialize[ReceivedBlockTrackerLogEvent](b.array())) }.toList } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index 4273fd7dda8be..7f80d6ecdbbb5 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -20,7 +20,7 @@ import java.io._ import java.nio.ByteBuffer import java.util.{Iterator => JIterator} import java.util.concurrent.atomic.AtomicInteger -import java.util.concurrent.{TimeUnit, CountDownLatch, ThreadPoolExecutor} +import java.util.concurrent.{RejectedExecutionException, TimeUnit, CountDownLatch, ThreadPoolExecutor} 
import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer @@ -190,6 +190,28 @@ abstract class CommonWriteAheadLogTests( } assert(!nonexistentTempPath.exists(), "Directory created just by attempting to read segment") } + + test(testPrefix + "parallel recovery not enabled if closeFileAfterWrite = false") { + // write some data + val writtenData = (1 to 10).map { i => + val data = generateRandomData() + val file = testDir + s"/log-$i-$i" + writeDataManually(data, file, allowBatching) + data + }.flatten + + val wal = createWriteAheadLog(testDir, closeFileAfterWrite, allowBatching) + // create iterator but don't materialize it + val readData = wal.readAll().asScala.map(byteBufferToString) + wal.close() + if (closeFileAfterWrite) { + // the threadpool is shutdown by the wal.close call above, therefore we shouldn't be able + // to materialize the iterator with parallel recovery + intercept[RejectedExecutionException](readData.toArray) + } else { + assert(readData.toSeq === writtenData) + } + } } class FileBasedWriteAheadLogSuite diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogUtilsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogUtilsSuite.scala index 9152728191ea1..bfc5b0cf60fb1 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogUtilsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogUtilsSuite.scala @@ -56,19 +56,19 @@ class WriteAheadLogUtilsSuite extends SparkFunSuite { test("log selection and creation") { val emptyConf = new SparkConf() // no log configuration - assertDriverLogClass[FileBasedWriteAheadLog](emptyConf) + assertDriverLogClass[FileBasedWriteAheadLog](emptyConf, isBatched = true) assertReceiverLogClass[FileBasedWriteAheadLog](emptyConf) // Verify setting driver WAL class val driverWALConf = new SparkConf().set("spark.streaming.driver.writeAheadLog.class", classOf[MockWriteAheadLog0].getName()) - assertDriverLogClass[MockWriteAheadLog0](driverWALConf) + assertDriverLogClass[MockWriteAheadLog0](driverWALConf, isBatched = true) assertReceiverLogClass[FileBasedWriteAheadLog](driverWALConf) // Verify setting receiver WAL class val receiverWALConf = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class", classOf[MockWriteAheadLog0].getName()) - assertDriverLogClass[FileBasedWriteAheadLog](receiverWALConf) + assertDriverLogClass[FileBasedWriteAheadLog](receiverWALConf, isBatched = true) assertReceiverLogClass[MockWriteAheadLog0](receiverWALConf) // Verify setting receiver WAL class with 1-arg constructor @@ -104,6 +104,19 @@ class WriteAheadLogUtilsSuite extends SparkFunSuite { assertDriverLogClass[FileBasedWriteAheadLog](receiverWALConf, isBatched = true) assertReceiverLogClass[MockWriteAheadLog0](receiverWALConf) } + + test("batching is enabled by default in WriteAheadLog") { + val conf = new SparkConf() + assert(WriteAheadLogUtils.isBatchingEnabled(conf, isDriver = true)) + // batching is not valid for receiver WALs + assert(!WriteAheadLogUtils.isBatchingEnabled(conf, isDriver = false)) + } + + test("closeFileAfterWrite is disabled by default in WriteAheadLog") { + val conf = new SparkConf() + assert(!WriteAheadLogUtils.shouldCloseFileAfterWrite(conf, isDriver = true)) + assert(!WriteAheadLogUtils.shouldCloseFileAfterWrite(conf, isDriver = false)) + } } object WriteAheadLogUtilsSuite { From ace0db47141ffd457c2091751038fc291f6d5a8b Mon Sep 17 00:00:00 2001 From: Daniel Jalova Date: Mon, 16 Nov 2015 11:29:27 
-0800 Subject: [PATCH 035/149] [SPARK-6328][PYTHON] Python API for StreamingListener Author: Daniel Jalova Closes #9186 from djalova/SPARK-6328. --- python/pyspark/streaming/__init__.py | 3 +- python/pyspark/streaming/context.py | 8 ++ python/pyspark/streaming/listener.py | 75 +++++++++++ python/pyspark/streaming/tests.py | 126 +++++++++++++++++- .../api/java/JavaStreamingListener.scala | 76 +++++++++++ 5 files changed, 286 insertions(+), 2 deletions(-) create mode 100644 python/pyspark/streaming/listener.py diff --git a/python/pyspark/streaming/__init__.py b/python/pyspark/streaming/__init__.py index d2644a1d4ffab..66e8f8ef001e3 100644 --- a/python/pyspark/streaming/__init__.py +++ b/python/pyspark/streaming/__init__.py @@ -17,5 +17,6 @@ from pyspark.streaming.context import StreamingContext from pyspark.streaming.dstream import DStream +from pyspark.streaming.listener import StreamingListener -__all__ = ['StreamingContext', 'DStream'] +__all__ = ['StreamingContext', 'DStream', 'StreamingListener'] diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index 8be56c9915265..1388b6d044e04 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -363,3 +363,11 @@ def union(self, *dstreams): first = dstreams[0] jrest = [d._jdstream for d in dstreams[1:]] return DStream(self._jssc.union(first._jdstream, jrest), self, first._jrdd_deserializer) + + def addStreamingListener(self, streamingListener): + """ + Add a [[org.apache.spark.streaming.scheduler.StreamingListener]] object for + receiving system events related to streaming. + """ + self._jssc.addStreamingListener(self._jvm.JavaStreamingListenerWrapper( + self._jvm.PythonStreamingListenerWrapper(streamingListener))) diff --git a/python/pyspark/streaming/listener.py b/python/pyspark/streaming/listener.py new file mode 100644 index 0000000000000..b830797f5c0a0 --- /dev/null +++ b/python/pyspark/streaming/listener.py @@ -0,0 +1,75 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +__all__ = ["StreamingListener"] + + +class StreamingListener(object): + + def __init__(self): + pass + + def onReceiverStarted(self, receiverStarted): + """ + Called when a receiver has been started + """ + pass + + def onReceiverError(self, receiverError): + """ + Called when a receiver has reported an error + """ + pass + + def onReceiverStopped(self, receiverStopped): + """ + Called when a receiver has been stopped + """ + pass + + def onBatchSubmitted(self, batchSubmitted): + """ + Called when a batch of jobs has been submitted for processing. + """ + pass + + def onBatchStarted(self, batchStarted): + """ + Called when processing of a batch of jobs has started. 
+ """ + pass + + def onBatchCompleted(self, batchCompleted): + """ + Called when processing of a batch of jobs has completed. + """ + pass + + def onOutputOperationStarted(self, outputOperationStarted): + """ + Called when processing of a job of a batch has started. + """ + pass + + def onOutputOperationCompleted(self, outputOperationCompleted): + """ + Called when processing of a job of a batch has completed + """ + pass + + class Java: + implements = ["org.apache.spark.streaming.api.java.PythonStreamingListener"] diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 6ee864d8d3da6..2983028413bb8 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -48,6 +48,7 @@ from pyspark.streaming.flume import FlumeUtils from pyspark.streaming.mqtt import MQTTUtils from pyspark.streaming.kinesis import KinesisUtils, InitialPositionInStream +from pyspark.streaming.listener import StreamingListener class PySparkStreamingTestCase(unittest.TestCase): @@ -403,6 +404,128 @@ def func(dstream): self._test_func(input, func, expected) +class StreamingListenerTests(PySparkStreamingTestCase): + + duration = .5 + + class BatchInfoCollector(StreamingListener): + + def __init__(self): + super(StreamingListener, self).__init__() + self.batchInfosCompleted = [] + self.batchInfosStarted = [] + self.batchInfosSubmitted = [] + + def onBatchSubmitted(self, batchSubmitted): + self.batchInfosSubmitted.append(batchSubmitted.batchInfo()) + + def onBatchStarted(self, batchStarted): + self.batchInfosStarted.append(batchStarted.batchInfo()) + + def onBatchCompleted(self, batchCompleted): + self.batchInfosCompleted.append(batchCompleted.batchInfo()) + + def test_batch_info_reports(self): + batch_collector = self.BatchInfoCollector() + self.ssc.addStreamingListener(batch_collector) + input = [[1], [2], [3], [4]] + + def func(dstream): + return dstream.map(int) + expected = [[1], [2], [3], [4]] + self._test_func(input, func, expected) + + batchInfosSubmitted = batch_collector.batchInfosSubmitted + batchInfosStarted = batch_collector.batchInfosStarted + batchInfosCompleted = batch_collector.batchInfosCompleted + + self.wait_for(batchInfosCompleted, 4) + + self.assertGreaterEqual(len(batchInfosSubmitted), 4) + for info in batchInfosSubmitted: + self.assertGreaterEqual(info.batchTime().milliseconds(), 0) + self.assertGreaterEqual(info.submissionTime(), 0) + + for streamId in info.streamIdToInputInfo(): + streamInputInfo = info.streamIdToInputInfo()[streamId] + self.assertGreaterEqual(streamInputInfo.inputStreamId(), 0) + self.assertGreaterEqual(streamInputInfo.numRecords, 0) + for key in streamInputInfo.metadata(): + self.assertIsNotNone(streamInputInfo.metadata()[key]) + self.assertIsNotNone(streamInputInfo.metadataDescription()) + + for outputOpId in info.outputOperationInfos(): + outputInfo = info.outputOperationInfos()[outputOpId] + self.assertGreaterEqual(outputInfo.batchTime().milliseconds(), 0) + self.assertGreaterEqual(outputInfo.id(), 0) + self.assertIsNotNone(outputInfo.name()) + self.assertIsNotNone(outputInfo.description()) + self.assertGreaterEqual(outputInfo.startTime(), -1) + self.assertGreaterEqual(outputInfo.endTime(), -1) + self.assertIsNone(outputInfo.failureReason()) + + self.assertEqual(info.schedulingDelay(), -1) + self.assertEqual(info.processingDelay(), -1) + self.assertEqual(info.totalDelay(), -1) + self.assertEqual(info.numRecords(), 0) + + self.assertGreaterEqual(len(batchInfosStarted), 4) + for info in batchInfosStarted: + 
self.assertGreaterEqual(info.batchTime().milliseconds(), 0) + self.assertGreaterEqual(info.submissionTime(), 0) + + for streamId in info.streamIdToInputInfo(): + streamInputInfo = info.streamIdToInputInfo()[streamId] + self.assertGreaterEqual(streamInputInfo.inputStreamId(), 0) + self.assertGreaterEqual(streamInputInfo.numRecords, 0) + for key in streamInputInfo.metadata(): + self.assertIsNotNone(streamInputInfo.metadata()[key]) + self.assertIsNotNone(streamInputInfo.metadataDescription()) + + for outputOpId in info.outputOperationInfos(): + outputInfo = info.outputOperationInfos()[outputOpId] + self.assertGreaterEqual(outputInfo.batchTime().milliseconds(), 0) + self.assertGreaterEqual(outputInfo.id(), 0) + self.assertIsNotNone(outputInfo.name()) + self.assertIsNotNone(outputInfo.description()) + self.assertGreaterEqual(outputInfo.startTime(), -1) + self.assertGreaterEqual(outputInfo.endTime(), -1) + self.assertIsNone(outputInfo.failureReason()) + + self.assertGreaterEqual(info.schedulingDelay(), 0) + self.assertEqual(info.processingDelay(), -1) + self.assertEqual(info.totalDelay(), -1) + self.assertEqual(info.numRecords(), 0) + + self.assertGreaterEqual(len(batchInfosCompleted), 4) + for info in batchInfosCompleted: + self.assertGreaterEqual(info.batchTime().milliseconds(), 0) + self.assertGreaterEqual(info.submissionTime(), 0) + + for streamId in info.streamIdToInputInfo(): + streamInputInfo = info.streamIdToInputInfo()[streamId] + self.assertGreaterEqual(streamInputInfo.inputStreamId(), 0) + self.assertGreaterEqual(streamInputInfo.numRecords, 0) + for key in streamInputInfo.metadata(): + self.assertIsNotNone(streamInputInfo.metadata()[key]) + self.assertIsNotNone(streamInputInfo.metadataDescription()) + + for outputOpId in info.outputOperationInfos(): + outputInfo = info.outputOperationInfos()[outputOpId] + self.assertGreaterEqual(outputInfo.batchTime().milliseconds(), 0) + self.assertGreaterEqual(outputInfo.id(), 0) + self.assertIsNotNone(outputInfo.name()) + self.assertIsNotNone(outputInfo.description()) + self.assertGreaterEqual(outputInfo.startTime(), 0) + self.assertGreaterEqual(outputInfo.endTime(), 0) + self.assertIsNone(outputInfo.failureReason()) + + self.assertGreaterEqual(info.schedulingDelay(), 0) + self.assertGreaterEqual(info.processingDelay(), 0) + self.assertGreaterEqual(info.totalDelay(), 0) + self.assertEqual(info.numRecords(), 0) + + class WindowFunctionTests(PySparkStreamingTestCase): timeout = 15 @@ -1308,7 +1431,8 @@ def search_kinesis_asl_assembly_jar(): os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars testcases = [BasicOperationTests, WindowFunctionTests, StreamingContextTests, CheckpointTests, - KafkaStreamTests, FlumeStreamTests, FlumePollingStreamTests, MQTTStreamTests] + KafkaStreamTests, FlumeStreamTests, FlumePollingStreamTests, MQTTStreamTests, + StreamingListenerTests] if kinesis_jar_present is True: testcases.append(KinesisStreamTests) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListener.scala index 34429074fe804..7bfd6bd5af759 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListener.scala @@ -18,6 +18,82 @@ package org.apache.spark.streaming.api.java import org.apache.spark.streaming.Time +import org.apache.spark.streaming.scheduler.StreamingListener + 
+private[streaming] trait PythonStreamingListener{ + + /** Called when a receiver has been started */ + def onReceiverStarted(receiverStarted: JavaStreamingListenerReceiverStarted) { } + + /** Called when a receiver has reported an error */ + def onReceiverError(receiverError: JavaStreamingListenerReceiverError) { } + + /** Called when a receiver has been stopped */ + def onReceiverStopped(receiverStopped: JavaStreamingListenerReceiverStopped) { } + + /** Called when a batch of jobs has been submitted for processing. */ + def onBatchSubmitted(batchSubmitted: JavaStreamingListenerBatchSubmitted) { } + + /** Called when processing of a batch of jobs has started. */ + def onBatchStarted(batchStarted: JavaStreamingListenerBatchStarted) { } + + /** Called when processing of a batch of jobs has completed. */ + def onBatchCompleted(batchCompleted: JavaStreamingListenerBatchCompleted) { } + + /** Called when processing of a job of a batch has started. */ + def onOutputOperationStarted( + outputOperationStarted: JavaStreamingListenerOutputOperationStarted) { } + + /** Called when processing of a job of a batch has completed. */ + def onOutputOperationCompleted( + outputOperationCompleted: JavaStreamingListenerOutputOperationCompleted) { } +} + +private[streaming] class PythonStreamingListenerWrapper(listener: PythonStreamingListener) + extends JavaStreamingListener { + + /** Called when a receiver has been started */ + override def onReceiverStarted(receiverStarted: JavaStreamingListenerReceiverStarted): Unit = { + listener.onReceiverStarted(receiverStarted) + } + + /** Called when a receiver has reported an error */ + override def onReceiverError(receiverError: JavaStreamingListenerReceiverError): Unit = { + listener.onReceiverError(receiverError) + } + + /** Called when a receiver has been stopped */ + override def onReceiverStopped(receiverStopped: JavaStreamingListenerReceiverStopped): Unit = { + listener.onReceiverStopped(receiverStopped) + } + + /** Called when a batch of jobs has been submitted for processing. */ + override def onBatchSubmitted(batchSubmitted: JavaStreamingListenerBatchSubmitted): Unit = { + listener.onBatchSubmitted(batchSubmitted) + } + + /** Called when processing of a batch of jobs has started. */ + override def onBatchStarted(batchStarted: JavaStreamingListenerBatchStarted): Unit = { + listener.onBatchStarted(batchStarted) + } + + /** Called when processing of a batch of jobs has completed. */ + override def onBatchCompleted(batchCompleted: JavaStreamingListenerBatchCompleted): Unit = { + listener.onBatchCompleted(batchCompleted) + } + + /** Called when processing of a job of a batch has started. */ + override def onOutputOperationStarted( + outputOperationStarted: JavaStreamingListenerOutputOperationStarted): Unit = { + listener.onOutputOperationStarted(outputOperationStarted) + } + + /** Called when processing of a job of a batch has completed. */ + override def onOutputOperationCompleted( + outputOperationCompleted: JavaStreamingListenerOutputOperationCompleted): Unit = { + listener.onOutputOperationCompleted(outputOperationCompleted) + } +} /** * A listener interface for receiving information about an ongoing streaming computation. 
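For context on the new Python API above, a minimal hedged sketch (not taken from the patch): the base class and hook names are those added in pyspark/streaming/listener.py, addStreamingListener is the new StreamingContext method, and the BatchCounter class itself is hypothetical.

from pyspark.streaming import StreamingListener

class BatchCounter(StreamingListener):
    # Hypothetical listener that counts completed batches.
    def __init__(self):
        super(BatchCounter, self).__init__()
        self.completed = 0

    def onBatchCompleted(self, batchCompleted):
        # batchCompleted.batchInfo() exposes timings and record counts,
        # as exercised in the tests above.
        self.completed += 1

# Given an existing StreamingContext `ssc`:
# ssc.addStreamingListener(BatchCounter())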
From 24477d2705bcf2a851acc241deb8376c5450dc73 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Mon, 16 Nov 2015 11:43:18 -0800 Subject: [PATCH 036/149] [SPARK-11718][YARN][CORE] Fix explicitly killed executor dies silently issue Currently if dynamic allocation is enabled, explicitly killing executor will not get response, so the executor metadata is wrong in driver side. Which will make dynamic allocation on Yarn fail to work. The problem is `disableExecutor` returns false for pending killing executors when `onDisconnect` is detected, so no further implementation is done. One solution is to bypass these explicitly killed executors to use `super.onDisconnect` to remove executor. This is simple. Another solution is still querying the loss reason for these explicitly kill executors. Since executor may get killed and informed in the same AM-RM communication, so current way of adding pending loss reason request is not worked (container complete is already processed), here we should store this loss reason for later query. Here this PR chooses solution 2. Please help to review. vanzin I think this part is changed by you previously, would you please help to review? Thanks a lot. Author: jerryshao Closes #9684 from jerryshao/SPARK-11718. --- .../spark/scheduler/TaskSchedulerImpl.scala | 1 + .../CoarseGrainedSchedulerBackend.scala | 6 ++-- .../spark/deploy/yarn/YarnAllocator.scala | 30 +++++++++++++++---- 3 files changed, 29 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 43d7d80b7aae1..5f136690f456c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -473,6 +473,7 @@ private[spark] class TaskSchedulerImpl( // If the host mapping still exists, it means we don't know the loss reason for the // executor. So call removeExecutor() to update tasks running on that executor when // the real loss reason is finally known. + logError(s"Actual reason for lost executor $executorId: ${reason.message}") removeExecutor(executorId, reason) case None => diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index f71d98feac050..3373caf0d15eb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -269,7 +269,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp * Stop making resource offers for the given executor. The executor is marked as lost with * the loss reason still pending. * - * @return Whether executor was alive. + * @return Whether executor should be disabled */ protected def disableExecutor(executorId: String): Boolean = { val shouldDisable = CoarseGrainedSchedulerBackend.this.synchronized { @@ -277,7 +277,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp executorsPendingLossReason += executorId true } else { - false + // Returns true for explicitly killed executors, we also need to get pending loss reasons; + // For others return false. 
+ executorsPendingToRemove.contains(executorId) } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 4d9e777cb4134..7e39c3ea56af3 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -35,7 +35,7 @@ import org.apache.hadoop.yarn.util.RackResolver import org.apache.log4j.{Level, Logger} -import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef} import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason} @@ -96,6 +96,10 @@ private[yarn] class YarnAllocator( // was lost. private val pendingLossReasonRequests = new HashMap[String, mutable.Buffer[RpcCallContext]] + // Maintain loss reasons for already released executors, it will be added when executor loss + // reason is got from AM-RM call, and be removed after querying this loss reason. + private val releasedExecutorLossReasons = new HashMap[String, ExecutorLossReason] + // Keep track of which container is running which executor to remove the executors later // Visible for testing. private[yarn] val executorIdToContainer = new HashMap[String, Container] @@ -202,8 +206,7 @@ private[yarn] class YarnAllocator( */ def killExecutor(executorId: String): Unit = synchronized { if (executorIdToContainer.contains(executorId)) { - val container = executorIdToContainer.remove(executorId).get - containerIdToExecutorId.remove(container.getId) + val container = executorIdToContainer.get(executorId).get internalReleaseContainer(container) numExecutorsRunning -= 1 } else { @@ -514,9 +517,18 @@ private[yarn] class YarnAllocator( containerIdToExecutorId.remove(containerId).foreach { eid => executorIdToContainer.remove(eid) - pendingLossReasonRequests.remove(eid).foreach { pendingRequests => - // Notify application of executor loss reasons so it can decide whether it should abort - pendingRequests.foreach(_.reply(exitReason)) + pendingLossReasonRequests.remove(eid) match { + case Some(pendingRequests) => + // Notify application of executor loss reasons so it can decide whether it should abort + pendingRequests.foreach(_.reply(exitReason)) + + case None => + // We cannot find executor for pending reasons. This is because completed container + // is processed before querying pending result. We should store it for later query. + // This is usually happened when explicitly killing a container, the result will be + // returned in one AM-RM communication. So query RPC will be later than this completed + // container process. 
+ releasedExecutorLossReasons.put(eid, exitReason) } if (!alreadyReleased) { // The executor could have gone away (like no route to host, node failure, etc) @@ -538,8 +550,14 @@ private[yarn] class YarnAllocator( if (executorIdToContainer.contains(eid)) { pendingLossReasonRequests .getOrElseUpdate(eid, new ArrayBuffer[RpcCallContext]) += context + } else if (releasedExecutorLossReasons.contains(eid)) { + // Executor is already released explicitly before getting the loss reason, so directly send + // the pre-stored lost reason + context.reply(releasedExecutorLossReasons.remove(eid).get) } else { logWarning(s"Tried to get the loss reason for non-existent executor $eid") + context.sendFailure( + new SparkException(s"Fail to find loss reason for non-existent executor $eid")) } } From b1a9662623951079e80bd7498e064c4cae4977e9 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 16 Nov 2015 12:45:34 -0800 Subject: [PATCH 037/149] [SPARK-11754][SQL] consolidate `ExpressionEncoder.tuple` and `Encoders.tuple` These 2 are very similar, we can consolidate them into one. Also add tests for it and fix a bug. Author: Wenchen Fan Closes #9729 from cloud-fan/tuple. --- .../scala/org/apache/spark/sql/Encoder.scala | 95 ++++------------ .../catalyst/encoders/ExpressionEncoder.scala | 104 ++++++++++-------- .../encoders/ProductEncoderSuite.scala | 29 +++++ 3 files changed, 108 insertions(+), 120 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index 5f619d6c339e3..c8b017e251637 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -19,10 +19,8 @@ package org.apache.spark.sql import scala.reflect.ClassTag -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{ObjectType, StructField, StructType} -import org.apache.spark.util.Utils +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor} +import org.apache.spark.sql.types.StructType /** * Used to convert a JVM object of type `T` to and from the internal Spark SQL representation. 
@@ -49,83 +47,34 @@ object Encoders { def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder(flat = true) def STRING: Encoder[java.lang.String] = ExpressionEncoder(flat = true) - def tuple[T1, T2](enc1: Encoder[T1], enc2: Encoder[T2]): Encoder[(T1, T2)] = { - tuple(Seq(enc1, enc2).map(_.asInstanceOf[ExpressionEncoder[_]])) - .asInstanceOf[ExpressionEncoder[(T1, T2)]] + def tuple[T1, T2]( + e1: Encoder[T1], + e2: Encoder[T2]): Encoder[(T1, T2)] = { + ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2)) } def tuple[T1, T2, T3]( - enc1: Encoder[T1], - enc2: Encoder[T2], - enc3: Encoder[T3]): Encoder[(T1, T2, T3)] = { - tuple(Seq(enc1, enc2, enc3).map(_.asInstanceOf[ExpressionEncoder[_]])) - .asInstanceOf[ExpressionEncoder[(T1, T2, T3)]] + e1: Encoder[T1], + e2: Encoder[T2], + e3: Encoder[T3]): Encoder[(T1, T2, T3)] = { + ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3)) } def tuple[T1, T2, T3, T4]( - enc1: Encoder[T1], - enc2: Encoder[T2], - enc3: Encoder[T3], - enc4: Encoder[T4]): Encoder[(T1, T2, T3, T4)] = { - tuple(Seq(enc1, enc2, enc3, enc4).map(_.asInstanceOf[ExpressionEncoder[_]])) - .asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4)]] + e1: Encoder[T1], + e2: Encoder[T2], + e3: Encoder[T3], + e4: Encoder[T4]): Encoder[(T1, T2, T3, T4)] = { + ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3), encoderFor(e4)) } def tuple[T1, T2, T3, T4, T5]( - enc1: Encoder[T1], - enc2: Encoder[T2], - enc3: Encoder[T3], - enc4: Encoder[T4], - enc5: Encoder[T5]): Encoder[(T1, T2, T3, T4, T5)] = { - tuple(Seq(enc1, enc2, enc3, enc4, enc5).map(_.asInstanceOf[ExpressionEncoder[_]])) - .asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4, T5)]] - } - - private def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = { - assert(encoders.length > 1) - // make sure all encoders are resolved, i.e. `Attribute` has been resolved to `BoundReference`. 
- assert(encoders.forall(_.fromRowExpression.find(_.isInstanceOf[Attribute]).isEmpty)) - - val schema = StructType(encoders.zipWithIndex.map { - case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema) - }) - - val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}") - - val extractExpressions = encoders.map { - case e if e.flat => e.toRowExpressions.head - case other => CreateStruct(other.toRowExpressions) - }.zipWithIndex.map { case (expr, index) => - expr.transformUp { - case BoundReference(0, t: ObjectType, _) => - Invoke( - BoundReference(0, ObjectType(cls), nullable = true), - s"_${index + 1}", - t) - } - } - - val constructExpressions = encoders.zipWithIndex.map { case (enc, index) => - if (enc.flat) { - enc.fromRowExpression.transform { - case b: BoundReference => b.copy(ordinal = index) - } - } else { - enc.fromRowExpression.transformUp { - case BoundReference(ordinal, dt, _) => - GetInternalRowField(BoundReference(index, enc.schema, nullable = true), ordinal, dt) - } - } - } - - val constructExpression = - NewInstance(cls, constructExpressions, propagateNull = false, ObjectType(cls)) - - new ExpressionEncoder[Any]( - schema, - flat = false, - extractExpressions, - constructExpression, - ClassTag(cls)) + e1: Encoder[T1], + e2: Encoder[T2], + e3: Encoder[T3], + e4: Encoder[T4], + e5: Encoder[T5]): Encoder[(T1, T2, T3, T4, T5)] = { + ExpressionEncoder.tuple( + encoderFor(e1), encoderFor(e2), encoderFor(e3), encoderFor(e4), encoderFor(e5)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 0d3e4aafb0af4..9a1a8f5cbbdc3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -67,47 +67,77 @@ object ExpressionEncoder { def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = { encoders.foreach(_.assertUnresolved()) - val schema = - StructType( - encoders.zipWithIndex.map { - case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema) - }) + val schema = StructType(encoders.zipWithIndex.map { + case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema) + }) + val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}") - // Rebind the encoders to the nested schema. 
- val newConstructExpressions = encoders.zipWithIndex.map { - case (e, i) if !e.flat => e.nested(i).fromRowExpression - case (e, i) => e.shift(i).fromRowExpression + val toRowExpressions = encoders.map { + case e if e.flat => e.toRowExpressions.head + case other => CreateStruct(other.toRowExpressions) + }.zipWithIndex.map { case (expr, index) => + expr.transformUp { + case BoundReference(0, t, _) => + Invoke( + BoundReference(0, ObjectType(cls), nullable = true), + s"_${index + 1}", + t) + } } - val constructExpression = - NewInstance(cls, newConstructExpressions, false, ObjectType(cls)) - - val input = BoundReference(0, ObjectType(cls), false) - val extractExpressions = encoders.zipWithIndex.map { - case (e, i) if !e.flat => CreateStruct(e.toRowExpressions.map(_ transformUp { - case b: BoundReference => - Invoke(input, s"_${i + 1}", b.dataType, Nil) - })) - case (e, i) => e.toRowExpressions.head transformUp { - case b: BoundReference => - Invoke(input, s"_${i + 1}", b.dataType, Nil) + val fromRowExpressions = encoders.zipWithIndex.map { case (enc, index) => + if (enc.flat) { + enc.fromRowExpression.transform { + case b: BoundReference => b.copy(ordinal = index) + } + } else { + val input = BoundReference(index, enc.schema, nullable = true) + enc.fromRowExpression.transformUp { + case UnresolvedAttribute(nameParts) => + assert(nameParts.length == 1) + UnresolvedExtractValue(input, Literal(nameParts.head)) + case BoundReference(ordinal, dt, _) => GetInternalRowField(input, ordinal, dt) + } } } + val fromRowExpression = + NewInstance(cls, fromRowExpressions, propagateNull = false, ObjectType(cls)) + new ExpressionEncoder[Any]( schema, - false, - extractExpressions, - constructExpression, - ClassTag.apply(cls)) + flat = false, + toRowExpressions, + fromRowExpression, + ClassTag(cls)) } - /** A helper for producing encoders of Tuple2 from other encoders. */ def tuple[T1, T2]( e1: ExpressionEncoder[T1], e2: ExpressionEncoder[T2]): ExpressionEncoder[(T1, T2)] = - tuple(e1 :: e2 :: Nil).asInstanceOf[ExpressionEncoder[(T1, T2)]] + tuple(Seq(e1, e2)).asInstanceOf[ExpressionEncoder[(T1, T2)]] + + def tuple[T1, T2, T3]( + e1: ExpressionEncoder[T1], + e2: ExpressionEncoder[T2], + e3: ExpressionEncoder[T3]): ExpressionEncoder[(T1, T2, T3)] = + tuple(Seq(e1, e2, e3)).asInstanceOf[ExpressionEncoder[(T1, T2, T3)]] + + def tuple[T1, T2, T3, T4]( + e1: ExpressionEncoder[T1], + e2: ExpressionEncoder[T2], + e3: ExpressionEncoder[T3], + e4: ExpressionEncoder[T4]): ExpressionEncoder[(T1, T2, T3, T4)] = + tuple(Seq(e1, e2, e3, e4)).asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4)]] + + def tuple[T1, T2, T3, T4, T5]( + e1: ExpressionEncoder[T1], + e2: ExpressionEncoder[T2], + e3: ExpressionEncoder[T3], + e4: ExpressionEncoder[T4], + e5: ExpressionEncoder[T5]): ExpressionEncoder[(T1, T2, T3, T4, T5)] = + tuple(Seq(e1, e2, e3, e4, e5)).asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4, T5)]] } /** @@ -208,26 +238,6 @@ case class ExpressionEncoder[T]( }) } - /** - * Returns a copy of this encoder where the expressions used to create an object given an - * input row have been modified to pull the object out from a nested struct, instead of the - * top level fields. - */ - private def nested(i: Int): ExpressionEncoder[T] = { - // We don't always know our input type at this point since it might be unresolved. - // We fill in null and it will get unbound to the actual attribute at this position. 
- val input = BoundReference(i, NullType, nullable = true) - copy(fromRowExpression = fromRowExpression transformUp { - case u: Attribute => - UnresolvedExtractValue(input, Literal(u.name)) - case b: BoundReference => - GetStructField( - input, - StructField(s"i[${b.ordinal}]", b.dataType), - b.ordinal) - }) - } - protected val attrs = toRowExpressions.flatMap(_.collect { case _: UnresolvedAttribute => "" case a: Attribute => s"#${a.exprId}" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala index fda978e7055ea..bc539d62c537d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala @@ -117,6 +117,35 @@ class ProductEncoderSuite extends ExpressionEncoderSuite { productTest(("Seq[Seq[(Int, Int)]]", Seq(Seq((1, 2))))) + encodeDecodeTest( + 1 -> 10L, + ExpressionEncoder.tuple(FlatEncoder[Int], FlatEncoder[Long]), + "tuple with 2 flat encoders") + + encodeDecodeTest( + (PrimitiveData(1, 1, 1, 1, 1, 1, true), (3, 30L)), + ExpressionEncoder.tuple(ProductEncoder[PrimitiveData], ProductEncoder[(Int, Long)]), + "tuple with 2 product encoders") + + encodeDecodeTest( + (PrimitiveData(1, 1, 1, 1, 1, 1, true), 3), + ExpressionEncoder.tuple(ProductEncoder[PrimitiveData], FlatEncoder[Int]), + "tuple with flat encoder and product encoder") + + encodeDecodeTest( + (3, PrimitiveData(1, 1, 1, 1, 1, 1, true)), + ExpressionEncoder.tuple(FlatEncoder[Int], ProductEncoder[PrimitiveData]), + "tuple with product encoder and flat encoder") + + encodeDecodeTest( + (1, (10, 100L)), + { + val intEnc = FlatEncoder[Int] + val longEnc = FlatEncoder[Long] + ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc)) + }, + "nested tuple encoder") + private def productTest[T <: Product : TypeTag](input: T): Unit = { encodeDecodeTest(input, ProductEncoder[T], input.getClass.getSimpleName) } From 985b38dd2fa5d8f1e23f1c420ce6262e7e3ed181 Mon Sep 17 00:00:00 2001 From: Zee Chen Date: Mon, 16 Nov 2015 14:21:28 -0800 Subject: [PATCH 038/149] [SPARK-11390][SQL] Query plan with/without filterPushdown indistinguishable MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …ishable Propagate pushed filters to PhyicalRDD in DataSourceStrategy.apply Author: Zee Chen Closes #9679 from zeocio/spark-11390. 
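To make the change concrete, a hedged sketch mirroring the new PlannerSuite test (the Parquet path is hypothetical and the plan text is abbreviated):

// Assumes a SQLContext `sqlContext` and Parquet filter pushdown enabled.
val df = sqlContext.read.parquet("/tmp/testPushed")   // hypothetical location
sqlContext.registerDataFrameAsTable(df, "testPushed")

// The physical scan node now carries the predicates pushed to the data source, so plans
// with and without filter pushdown are no longer indistinguishable in explain output:
sqlContext.sql("select * from testPushed where key = 15").explain()
// the printed physical plan now contains: PushedFilter: [EqualTo(key,15)]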
--- .../apache/spark/sql/execution/ExistingRDD.scala | 6 ++++-- .../execution/datasources/DataSourceStrategy.scala | 6 ++++-- .../apache/spark/sql/execution/PlannerSuite.scala | 14 ++++++++++++++ 3 files changed, 22 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 8b41d3d3d892e..62620ec642c78 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -106,7 +106,9 @@ private[sql] object PhysicalRDD { def createFromDataSource( output: Seq[Attribute], rdd: RDD[InternalRow], - relation: BaseRelation): PhysicalRDD = { - PhysicalRDD(output, rdd, relation.toString, relation.isInstanceOf[HadoopFsRelation]) + relation: BaseRelation, + extraInformation: String = ""): PhysicalRDD = { + PhysicalRDD(output, rdd, relation.toString + extraInformation, + relation.isInstanceOf[HadoopFsRelation]) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 9bbbfa7c77cba..544d5eccec037 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -315,6 +315,8 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { // `Filter`s or cannot be handled by `relation`. val filterCondition = unhandledPredicates.reduceLeftOption(expressions.And) + val pushedFiltersString = pushedFilters.mkString(" PushedFilter: [", ",", "] ") + if (projects.map(_.toAttribute) == projects && projectSet.size == projects.size && filterSet.subsetOf(projectSet)) { @@ -332,7 +334,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { val scan = execution.PhysicalRDD.createFromDataSource( projects.map(_.toAttribute), scanBuilder(requestedColumns, candidatePredicates, pushedFilters), - relation.relation) + relation.relation, pushedFiltersString) filterCondition.map(execution.Filter(_, scan)).getOrElse(scan) } else { // Don't request columns that are only referenced by pushed filters. 
@@ -342,7 +344,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { val scan = execution.PhysicalRDD.createFromDataSource( requestedColumns, scanBuilder(requestedColumns, candidatePredicates, pushedFilters), - relation.relation) + relation.relation, pushedFiltersString) execution.Project( projects, filterCondition.map(execution.Filter(_, scan)).getOrElse(scan)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index be53ec3e271c6..dfec139985f73 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -160,6 +160,20 @@ class PlannerSuite extends SharedSQLContext { } } + test("SPARK-11390 explain should print PushedFilters of PhysicalRDD") { + withTempPath { file => + val path = file.getCanonicalPath + testData.write.parquet(path) + val df = sqlContext.read.parquet(path) + sqlContext.registerDataFrameAsTable(df, "testPushed") + + withTempTable("testPushed") { + val exp = sql("select * from testPushed where key = 15").queryExecution.executedPlan + assert(exp.toString.contains("PushedFilter: [EqualTo(key,15)]")) + } + } + } + test("efficient limit -> project -> sort") { { val query = From 3c025087b58f475a9bcb5c8f4b2b2df804915b2b Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 16 Nov 2015 14:50:38 -0800 Subject: [PATCH 039/149] Revert "[SPARK-11271][SPARK-11016][CORE] Use Spark BitSet instead of RoaringBitmap to reduce memory usage" This reverts commit e209fa271ae57dc8849f8b1241bf1ea7d6d3d62c. --- core/pom.xml | 4 ++ .../apache/spark/scheduler/MapStatus.scala | 13 ++--- .../spark/serializer/KryoSerializer.scala | 10 +++- .../apache/spark/util/collection/BitSet.scala | 28 ++--------- .../serializer/KryoSerializerSuite.scala | 6 +++ .../spark/util/collection/BitSetSuite.scala | 49 ------------------- pom.xml | 5 ++ 7 files changed, 33 insertions(+), 82 deletions(-) diff --git a/core/pom.xml b/core/pom.xml index 7e1205a076f26..37e3f168ab374 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -177,6 +177,10 @@ net.jpountz.lz4 lz4 + + org.roaringbitmap + RoaringBitmap + commons-net commons-net diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 180c8d1827e13..1efce124c0a6b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -19,8 +19,9 @@ package org.apache.spark.scheduler import java.io.{Externalizable, ObjectInput, ObjectOutput} +import org.roaringbitmap.RoaringBitmap + import org.apache.spark.storage.BlockManagerId -import org.apache.spark.util.collection.BitSet import org.apache.spark.util.Utils /** @@ -132,7 +133,7 @@ private[spark] class CompressedMapStatus( private[spark] class HighlyCompressedMapStatus private ( private[this] var loc: BlockManagerId, private[this] var numNonEmptyBlocks: Int, - private[this] var emptyBlocks: BitSet, + private[this] var emptyBlocks: RoaringBitmap, private[this] var avgSize: Long) extends MapStatus with Externalizable { @@ -145,7 +146,7 @@ private[spark] class HighlyCompressedMapStatus private ( override def location: BlockManagerId = loc override def getSizeForBlock(reduceId: Int): Long = { - if (emptyBlocks.get(reduceId)) { + if (emptyBlocks.contains(reduceId)) { 0 } else { avgSize @@ -160,7 +161,7 @@ private[spark] class 
HighlyCompressedMapStatus private ( override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { loc = BlockManagerId(in) - emptyBlocks = new BitSet + emptyBlocks = new RoaringBitmap() emptyBlocks.readExternal(in) avgSize = in.readLong() } @@ -176,15 +177,15 @@ private[spark] object HighlyCompressedMapStatus { // From a compression standpoint, it shouldn't matter whether we track empty or non-empty // blocks. From a performance standpoint, we benefit from tracking empty blocks because // we expect that there will be far fewer of them, so we will perform fewer bitmap insertions. + val emptyBlocks = new RoaringBitmap() val totalNumBlocks = uncompressedSizes.length - val emptyBlocks = new BitSet(totalNumBlocks) while (i < totalNumBlocks) { var size = uncompressedSizes(i) if (size > 0) { numNonEmptyBlocks += 1 totalSize += size } else { - emptyBlocks.set(i) + emptyBlocks.add(i) } i += 1 } diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index bc51d4f2820c8..c5195c1143a8f 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -30,6 +30,7 @@ import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer} import com.twitter.chill.{AllScalaRegistrar, EmptyScalaKryoInstantiator} import org.apache.avro.generic.{GenericData, GenericRecord} +import org.roaringbitmap.{ArrayContainer, BitmapContainer, RoaringArray, RoaringBitmap} import org.apache.spark._ import org.apache.spark.api.python.PythonBroadcast @@ -38,7 +39,7 @@ import org.apache.spark.network.util.ByteUnit import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus} import org.apache.spark.storage._ import org.apache.spark.util.{Utils, BoundedPriorityQueue, SerializableConfiguration, SerializableJobConf} -import org.apache.spark.util.collection.{BitSet, CompactBuffer} +import org.apache.spark.util.collection.CompactBuffer /** * A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]]. @@ -362,7 +363,12 @@ private[serializer] object KryoSerializer { classOf[StorageLevel], classOf[CompressedMapStatus], classOf[HighlyCompressedMapStatus], - classOf[BitSet], + classOf[RoaringBitmap], + classOf[RoaringArray], + classOf[RoaringArray.Element], + classOf[Array[RoaringArray.Element]], + classOf[ArrayContainer], + classOf[BitmapContainer], classOf[CompactBuffer[_]], classOf[BlockManagerId], classOf[Array[Byte]], diff --git a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala index 85c5bdbfcebc0..7ab67fc3a2de9 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala @@ -17,21 +17,14 @@ package org.apache.spark.util.collection -import java.io.{Externalizable, ObjectInput, ObjectOutput} - -import org.apache.spark.util.{Utils => UUtils} - - /** * A simple, fixed-size bit set implementation. This implementation is fast because it avoids * safety/bound checking. 
*/ -class BitSet(private[this] var numBits: Int) extends Externalizable { +class BitSet(numBits: Int) extends Serializable { - private var words = new Array[Long](bit2words(numBits)) - private def numWords = words.length - - def this() = this(0) + private val words = new Array[Long](bit2words(numBits)) + private val numWords = words.length /** * Compute the capacity (number of bits) that can be represented @@ -237,19 +230,4 @@ class BitSet(private[this] var numBits: Int) extends Externalizable { /** Return the number of longs it would take to hold numBits. */ private def bit2words(numBits: Int) = ((numBits - 1) >> 6) + 1 - - override def writeExternal(out: ObjectOutput): Unit = UUtils.tryOrIOException { - out.writeInt(numBits) - words.foreach(out.writeLong(_)) - } - - override def readExternal(in: ObjectInput): Unit = UUtils.tryOrIOException { - numBits = in.readInt() - words = new Array[Long](bit2words(numBits)) - var index = 0 - while (index < words.length) { - words(index) = in.readLong() - index += 1 - } - } } diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index afe2e80358ca0..e428414cf6e85 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -322,6 +322,12 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { val conf = new SparkConf(false) conf.set("spark.kryo.registrationRequired", "true") + // these cases require knowing the internals of RoaringBitmap a little. Blocks span 2^16 + // values, and they use a bitmap (dense) if they have more than 4096 values, and an + // array (sparse) if they use less. So we just create two cases, one sparse and one dense. 
+ // and we use a roaring bitmap for the empty blocks, so we trigger the dense case w/ mostly + // empty blocks + val ser = new KryoSerializer(conf).newInstance() val denseBlockSizes = new Array[Long](5000) val sparseBlockSizes = Array[Long](0L, 1L, 0L, 2L) diff --git a/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala index b0db0988eeaab..69dbfa9cd7141 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala @@ -17,10 +17,7 @@ package org.apache.spark.util.collection -import java.io.{File, FileInputStream, FileOutputStream, ObjectInputStream, ObjectOutputStream} - import org.apache.spark.SparkFunSuite -import org.apache.spark.util.{Utils => UUtils} class BitSetSuite extends SparkFunSuite { @@ -155,50 +152,4 @@ class BitSetSuite extends SparkFunSuite { assert(bitsetDiff.nextSetBit(85) === 85) assert(bitsetDiff.nextSetBit(86) === -1) } - - test("read and write externally") { - val tempDir = UUtils.createTempDir() - val outputFile = File.createTempFile("bits", null, tempDir) - - val fos = new FileOutputStream(outputFile) - val oos = new ObjectOutputStream(fos) - - // Create BitSet - val setBits = Seq(0, 9, 1, 10, 90, 96) - val bitset = new BitSet(100) - - for (i <- 0 until 100) { - assert(!bitset.get(i)) - } - - setBits.foreach(i => bitset.set(i)) - - for (i <- 0 until 100) { - if (setBits.contains(i)) { - assert(bitset.get(i)) - } else { - assert(!bitset.get(i)) - } - } - assert(bitset.cardinality() === setBits.size) - - bitset.writeExternal(oos) - oos.close() - - val fis = new FileInputStream(outputFile) - val ois = new ObjectInputStream(fis) - - // Read BitSet from the file - val bitset2 = new BitSet(0) - bitset2.readExternal(ois) - - for (i <- 0 until 100) { - if (setBits.contains(i)) { - assert(bitset2.get(i)) - } else { - assert(!bitset2.get(i)) - } - } - assert(bitset2.cardinality() === setBits.size) - } } diff --git a/pom.xml b/pom.xml index 01afa80617891..2a8a445057174 100644 --- a/pom.xml +++ b/pom.xml @@ -634,6 +634,11 @@ + + org.roaringbitmap + RoaringBitmap + 0.4.5 + commons-net commons-net From bcea0bfda66a30ee86790b048de5cb47b4d0b32f Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 16 Nov 2015 15:06:06 -0800 Subject: [PATCH 040/149] [SPARK-11742][STREAMING] Add the failure info to the batch lists screen shot 2015-11-13 at 9 57 43 pm Author: Shixiong Zhu Closes #9711 from zsxwing/failure-info. --- .../spark/streaming/ui/AllBatchesTable.scala | 61 +++++++++++++++++-- .../apache/spark/streaming/ui/BatchPage.scala | 49 ++------------- .../apache/spark/streaming/ui/UIUtils.scala | 60 ++++++++++++++++++ 3 files changed, 120 insertions(+), 50 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala index 125cafd41b8af..d33972342731d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala @@ -33,6 +33,22 @@ private[ui] abstract class BatchTableBase(tableId: String, batchInterval: Long) {SparkUIUtils.tooltip("Time taken to process all jobs of a batch", "top")} } + /** + * Return the first failure reason if finding in the batches. 
+ */ + protected def getFirstFailureReason(batches: Seq[BatchUIData]): Option[String] = { + batches.flatMap(_.outputOperations.flatMap(_._2.failureReason)).headOption + } + + protected def getFirstFailureTableCell(batch: BatchUIData): Seq[Node] = { + val firstFailureReason = batch.outputOperations.flatMap(_._2.failureReason).headOption + firstFailureReason.map { failureReason => + val failureReasonForUI = UIUtils.createOutputOperationFailureForUI(failureReason) + UIUtils.failureReasonCell( + failureReasonForUI, rowspan = 1, includeFirstLineInExpandDetails = false) + }.getOrElse(-) + } + protected def baseRow(batch: BatchUIData): Seq[Node] = { val batchTime = batch.batchTime.milliseconds val formattedBatchTime = UIUtils.formatBatchTime(batchTime, batchInterval) @@ -97,9 +113,17 @@ private[ui] class ActiveBatchTable( waitingBatches: Seq[BatchUIData], batchInterval: Long) extends BatchTableBase("active-batches-table", batchInterval) { + private val firstFailureReason = getFirstFailureReason(runningBatches) + override protected def columns: Seq[Node] = super.columns ++ { Output Ops: Succeeded/Total - Status + Status ++ { + if (firstFailureReason.nonEmpty) { + Error + } else { + Nil + } + } } override protected def renderRows: Seq[Node] = { @@ -110,20 +134,41 @@ private[ui] class ActiveBatchTable( } private def runningBatchRow(batch: BatchUIData): Seq[Node] = { - baseRow(batch) ++ createOutputOperationProgressBar(batch) ++ processing + baseRow(batch) ++ createOutputOperationProgressBar(batch) ++ processing ++ { + if (firstFailureReason.nonEmpty) { + getFirstFailureTableCell(batch) + } else { + Nil + } + } } private def waitingBatchRow(batch: BatchUIData): Seq[Node] = { - baseRow(batch) ++ createOutputOperationProgressBar(batch) ++ queued + baseRow(batch) ++ createOutputOperationProgressBar(batch) ++ queued++ { + if (firstFailureReason.nonEmpty) { + // Waiting batches have not run yet, so must have no failure reasons. 
+ - + } else { + Nil + } + } } } private[ui] class CompletedBatchTable(batches: Seq[BatchUIData], batchInterval: Long) extends BatchTableBase("completed-batches-table", batchInterval) { + private val firstFailureReason = getFirstFailureReason(batches) + override protected def columns: Seq[Node] = super.columns ++ { Total Delay {SparkUIUtils.tooltip("Total time taken to handle a batch", "top")} - Output Ops: Succeeded/Total + Output Ops: Succeeded/Total ++ { + if (firstFailureReason.nonEmpty) { + Error + } else { + Nil + } + } } override protected def renderRows: Seq[Node] = { @@ -138,6 +183,12 @@ private[ui] class CompletedBatchTable(batches: Seq[BatchUIData], batchInterval: {formattedTotalDelay} - } ++ createOutputOperationProgressBar(batch) + } ++ createOutputOperationProgressBar(batch)++ { + if (firstFailureReason.nonEmpty) { + getFirstFailureTableCell(batch) + } else { + Nil + } + } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala index 2ed925572826e..bc1711930d3ac 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala @@ -149,7 +149,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { total = sparkJob.numTasks - sparkJob.numSkippedTasks) } - {failureReasonCell(lastFailureReason, rowspan = 1)} + {UIUtils.failureReasonCell(lastFailureReason)} } @@ -245,48 +245,6 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { } - private def failureReasonCell( - failureReason: String, - rowspan: Int, - includeFirstLineInExpandDetails: Boolean = true): Seq[Node] = { - val isMultiline = failureReason.indexOf('\n') >= 0 - // Display the first line by default - val failureReasonSummary = StringEscapeUtils.escapeHtml4( - if (isMultiline) { - failureReason.substring(0, failureReason.indexOf('\n')) - } else { - failureReason - }) - val failureDetails = - if (isMultiline && !includeFirstLineInExpandDetails) { - // Skip the first line - failureReason.substring(failureReason.indexOf('\n') + 1) - } else { - failureReason - } - val details = if (isMultiline) { - // scalastyle:off - - +details - ++ - - // scalastyle:on - } else { - "" - } - - if (rowspan == 1) { - {failureReasonSummary}{details} - } else { - - {failureReasonSummary}{details} - - } - } - private def getJobData(sparkJobId: SparkJobId): Option[JobUIData] = { sparkListener.activeJobs.get(sparkJobId).orElse { sparkListener.completedJobs.find(_.jobId == sparkJobId).orElse { @@ -434,8 +392,9 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { private def outputOpStatusCell(outputOp: OutputOperationUIData, rowspan: Int): Seq[Node] = { outputOp.failureReason match { case Some(failureReason) => - val failureReasonForUI = generateOutputOperationStatusForUI(failureReason) - failureReasonCell(failureReasonForUI, rowspan, includeFirstLineInExpandDetails = false) + val failureReasonForUI = UIUtils.createOutputOperationFailureForUI(failureReason) + UIUtils.failureReasonCell( + failureReasonForUI, rowspan, includeFirstLineInExpandDetails = false) case None => if (outputOp.endTime.isEmpty) { - diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala index 86cfb1fa47370..d89f7ad3e16b7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala 
+++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala @@ -17,6 +17,10 @@ package org.apache.spark.streaming.ui +import scala.xml.Node + +import org.apache.commons.lang3.StringEscapeUtils + import java.text.SimpleDateFormat import java.util.TimeZone import java.util.concurrent.TimeUnit @@ -124,4 +128,60 @@ private[streaming] object UIUtils { } } } + + def createOutputOperationFailureForUI(failure: String): String = { + if (failure.startsWith("org.apache.spark.Spark")) { + // SparkException or SparkDriverExecutionException + "Failed due to Spark job error\n" + failure + } else { + var nextLineIndex = failure.indexOf("\n") + if (nextLineIndex < 0) { + nextLineIndex = failure.size + } + val firstLine = failure.substring(0, nextLineIndex) + s"Failed due to error: $firstLine\n$failure" + } + } + + def failureReasonCell( + failureReason: String, + rowspan: Int = 1, + includeFirstLineInExpandDetails: Boolean = true): Seq[Node] = { + val isMultiline = failureReason.indexOf('\n') >= 0 + // Display the first line by default + val failureReasonSummary = StringEscapeUtils.escapeHtml4( + if (isMultiline) { + failureReason.substring(0, failureReason.indexOf('\n')) + } else { + failureReason + }) + val failureDetails = + if (isMultiline && !includeFirstLineInExpandDetails) { + // Skip the first line + failureReason.substring(failureReason.indexOf('\n') + 1) + } else { + failureReason + } + val details = if (isMultiline) { + // scalastyle:off + + +details + ++ + + // scalastyle:on + } else { + "" + } + + if (rowspan == 1) { + {failureReasonSummary}{details} + } else { + + {failureReasonSummary}{details} + + } + } } From 31296628ac7cd7be71e0edca335dc8604f62bb47 Mon Sep 17 00:00:00 2001 From: Bartlomiej Alberski Date: Mon, 16 Nov 2015 15:14:38 -0800 Subject: [PATCH 041/149] [SPARK-11553][SQL] Primitive Row accessors should not convert null to default value Invocation of getters for type extending AnyVal returns default value (if field value is null) instead of throwing NPE. Please check comments for SPARK-11553 issue for more details. Author: Bartlomiej Alberski Closes #9642 from alberskib/bugfix/SPARK-11553. --- .../main/scala/org/apache/spark/sql/Row.scala | 32 ++++++++++++----- .../scala/org/apache/spark/sql/RowTest.scala | 20 +++++++++++ .../local/NestedLoopJoinNodeSuite.scala | 36 +++++++++++-------- 3 files changed, 65 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index 0f0f200122c34..b14c66cc5ac88 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -191,7 +191,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getBoolean(i: Int): Boolean = getAs[Boolean](i) + def getBoolean(i: Int): Boolean = getAnyValAs[Boolean](i) /** * Returns the value at position i as a primitive byte. @@ -199,7 +199,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getByte(i: Int): Byte = getAs[Byte](i) + def getByte(i: Int): Byte = getAnyValAs[Byte](i) /** * Returns the value at position i as a primitive short. @@ -207,7 +207,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. 
*/ - def getShort(i: Int): Short = getAs[Short](i) + def getShort(i: Int): Short = getAnyValAs[Short](i) /** * Returns the value at position i as a primitive int. @@ -215,7 +215,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getInt(i: Int): Int = getAs[Int](i) + def getInt(i: Int): Int = getAnyValAs[Int](i) /** * Returns the value at position i as a primitive long. @@ -223,7 +223,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getLong(i: Int): Long = getAs[Long](i) + def getLong(i: Int): Long = getAnyValAs[Long](i) /** * Returns the value at position i as a primitive float. @@ -232,7 +232,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getFloat(i: Int): Float = getAs[Float](i) + def getFloat(i: Int): Float = getAnyValAs[Float](i) /** * Returns the value at position i as a primitive double. @@ -240,13 +240,12 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getDouble(i: Int): Double = getAs[Double](i) + def getDouble(i: Int): Double = getAnyValAs[Double](i) /** * Returns the value at position i as a String object. * * @throws ClassCastException when data type does not match. - * @throws NullPointerException when value is null. */ def getString(i: Int): String = getAs[String](i) @@ -318,6 +317,8 @@ trait Row extends Serializable { /** * Returns the value at position i. + * For primitive types if value is null it returns 'zero value' specific for primitive + * ie. 0 for Int - use isNullAt to ensure that value is not null * * @throws ClassCastException when data type does not match. */ @@ -325,6 +326,8 @@ trait Row extends Serializable { /** * Returns the value of a given fieldName. + * For primitive types if value is null it returns 'zero value' specific for primitive + * ie. 0 for Int - use isNullAt to ensure that value is not null * * @throws UnsupportedOperationException when schema is not defined. * @throws IllegalArgumentException when fieldName do not exist. @@ -344,6 +347,8 @@ trait Row extends Serializable { /** * Returns a Map(name -> value) for the requested fieldNames + * For primitive types if value is null it returns 'zero value' specific for primitive + * ie. 0 for Int - use isNullAt to ensure that value is not null * * @throws UnsupportedOperationException when schema is not defined. * @throws IllegalArgumentException when fieldName do not exist. @@ -458,4 +463,15 @@ trait Row extends Serializable { * start, end, and separator strings. */ def mkString(start: String, sep: String, end: String): String = toSeq.mkString(start, sep, end) + + /** + * Returns the value of a given fieldName. + * + * @throws UnsupportedOperationException when schema is not defined. + * @throws ClassCastException when data type does not match. + * @throws NullPointerException when value is null. 
+   */
+  private def getAnyValAs[T <: AnyVal](i: Int): T =
+    if (isNullAt(i)) throw new NullPointerException(s"Value at index $i is null")
+    else getAs[T](i)
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala
index 01ff84cb56054..5c22a72192541 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala
@@ -29,8 +29,10 @@ class RowTest extends FunSpec with Matchers {
     StructField("col2", StringType) ::
     StructField("col3", IntegerType) :: Nil)
   val values = Array("value1", "value2", 1)
+  val valuesWithoutCol3 = Array[Any](null, "value2", null)

   val sampleRow: Row = new GenericRowWithSchema(values, schema)
+  val sampleRowWithoutCol3: Row = new GenericRowWithSchema(valuesWithoutCol3, schema)
   val noSchemaRow: Row = new GenericRow(values)

   describe("Row (without schema)") {
@@ -68,6 +70,24 @@ class RowTest extends FunSpec with Matchers {
       )
       sampleRow.getValuesMap(List("col1", "col2")) shouldBe expected
     }
+
+    it("getValuesMap() retrieves null value on non AnyVal Type") {
+      val expected = Map(
+        "col1" -> null,
+        "col2" -> "value2"
+      )
+      sampleRowWithoutCol3.getValuesMap[String](List("col1", "col2")) shouldBe expected
+    }
+
+    it("getAs() on type extending AnyVal throws an exception when accessing field that is null") {
+      intercept[NullPointerException] {
+        sampleRowWithoutCol3.getInt(sampleRowWithoutCol3.fieldIndex("col3"))
+      }
+    }
+
+    it("getAs() on type extending AnyVal does not throw exception when value is null") {
+      sampleRowWithoutCol3.getAs[String](sampleRowWithoutCol3.fieldIndex("col1")) shouldBe null
+    }
   }

   describe("row equals") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala
index 252f7cc8971f4..45df2ea6552d8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala
@@ -58,8 +58,14 @@ class NestedLoopJoinNodeSuite extends LocalNodeTest {
         val hashJoinNode = makeUnsafeNode(leftNode, rightNode)
         val expectedOutput = generateExpectedOutput(leftInput, rightInput, joinType)
         val actualOutput = hashJoinNode.collect().map { row =>
-          // (id, name, id, nickname)
-          (row.getInt(0), row.getString(1), row.getInt(2), row.getString(3))
+          // (
+          //   id, name,
+          //   id, nickname
+          // )
+          (
+            Option(row.get(0)).map(_.asInstanceOf[Int]), Option(row.getString(1)),
+            Option(row.get(2)).map(_.asInstanceOf[Int]), Option(row.getString(3))
+          )
         }
         assert(actualOutput.toSet === expectedOutput.toSet)
       }
@@ -95,36 +101,36 @@ class NestedLoopJoinNodeSuite extends LocalNodeTest {
   private def generateExpectedOutput(
       leftInput: Array[(Int, String)],
       rightInput: Array[(Int, String)],
-      joinType: JoinType): Array[(Int, String, Int, String)] = {
+      joinType: JoinType): Array[(Option[Int], Option[String], Option[Int], Option[String])] = {
     joinType match {
       case LeftOuter =>
         val rightInputMap = rightInput.toMap
         leftInput.map { case (k, v) =>
-          val rightKey = rightInputMap.get(k).map { _ => k }.getOrElse(0)
-          val rightValue = rightInputMap.getOrElse(k, null)
-          (k, v, rightKey, rightValue)
+          val rightKey = rightInputMap.get(k).map { _ => k }
+          val rightValue = rightInputMap.get(k)
+          (Some(k), Some(v), rightKey, rightValue)
         }

       case RightOuter =>
         val leftInputMap = leftInput.toMap
rightInput.map { case (k, v) => - val leftKey = leftInputMap.get(k).map { _ => k }.getOrElse(0) - val leftValue = leftInputMap.getOrElse(k, null) - (leftKey, leftValue, k, v) + val leftKey = leftInputMap.get(k).map { _ => k } + val leftValue = leftInputMap.get(k) + (leftKey, leftValue, Some(k), Some(v)) } case FullOuter => val leftInputMap = leftInput.toMap val rightInputMap = rightInput.toMap val leftOutput = leftInput.map { case (k, v) => - val rightKey = rightInputMap.get(k).map { _ => k }.getOrElse(0) - val rightValue = rightInputMap.getOrElse(k, null) - (k, v, rightKey, rightValue) + val rightKey = rightInputMap.get(k).map { _ => k } + val rightValue = rightInputMap.get(k) + (Some(k), Some(v), rightKey, rightValue) } val rightOutput = rightInput.map { case (k, v) => - val leftKey = leftInputMap.get(k).map { _ => k }.getOrElse(0) - val leftValue = leftInputMap.getOrElse(k, null) - (leftKey, leftValue, k, v) + val leftKey = leftInputMap.get(k).map { _ => k } + val leftValue = leftInputMap.get(k) + (leftKey, leftValue, Some(k), Some(v)) } (leftOutput ++ rightOutput).distinct From 75ee12f09c2645c1ad682764d512965f641eb5c2 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 16 Nov 2015 15:22:12 -0800 Subject: [PATCH 042/149] [SPARK-8658][SQL] AttributeReference's equals method compares all the members This fix is to change the equals method to check all of the specified fields for equality of AttributeReference. Author: gatorsmile Closes #9216 from gatorsmile/namedExpressEqual. --- .../sql/catalyst/expressions/namedExpressions.scala | 4 +++- .../sql/catalyst/plans/logical/basicOperators.scala | 10 +++++----- .../sql/catalyst/plans/physical/partitioning.scala | 12 ++++++------ 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index f80bcfcb0b0bf..e3daddace241d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -194,7 +194,9 @@ case class AttributeReference( def sameRef(other: AttributeReference): Boolean = this.exprId == other.exprId override def equals(other: Any): Boolean = other match { - case ar: AttributeReference => name == ar.name && exprId == ar.exprId && dataType == ar.dataType + case ar: AttributeReference => + name == ar.name && dataType == ar.dataType && nullable == ar.nullable && + metadata == ar.metadata && exprId == ar.exprId && qualifiers == ar.qualifiers case _ => false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index e2b97b27a6c2c..45630a591d349 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.types._ -import org.apache.spark.util.collection.OpenHashSet +import scala.collection.mutable.ArrayBuffer case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode { 
override def output: Seq[Attribute] = projectList.map(_.toAttribute) @@ -244,12 +244,12 @@ private[sql] object Expand { */ private def buildNonSelectExprSet( bitmask: Int, - exprs: Seq[Expression]): OpenHashSet[Expression] = { - val set = new OpenHashSet[Expression](2) + exprs: Seq[Expression]): ArrayBuffer[Expression] = { + val set = new ArrayBuffer[Expression](2) var bit = exprs.length - 1 while (bit >= 0) { - if (((bitmask >> bit) & 1) == 0) set.add(exprs(bit)) + if (((bitmask >> bit) & 1) == 0) set += exprs(bit) bit -= 1 } @@ -279,7 +279,7 @@ private[sql] object Expand { (child.output :+ gid).map(expr => expr transformDown { // TODO this causes a problem when a column is used both for grouping and aggregation. - case x: Expression if nonSelectedGroupExprSet.contains(x) => + case x: Expression if nonSelectedGroupExprSet.exists(_.semanticEquals(x)) => // if the input attribute in the Invalid Grouping Expression set of for this group // replace it with constant null Literal.create(null, expr.dataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 86b9417477ba3..f6fb31a2af594 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -235,17 +235,17 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) override def satisfies(required: Distribution): Boolean = required match { case UnspecifiedDistribution => true case ClusteredDistribution(requiredClustering) => - expressions.toSet.subsetOf(requiredClustering.toSet) + expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) case _ => false } override def compatibleWith(other: Partitioning): Boolean = other match { - case o: HashPartitioning => this == o + case o: HashPartitioning => this.semanticEquals(o) case _ => false } override def guarantees(other: Partitioning): Boolean = other match { - case o: HashPartitioning => this == o + case o: HashPartitioning => this.semanticEquals(o) case _ => false } @@ -276,17 +276,17 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) val minSize = Seq(requiredOrdering.size, ordering.size).min requiredOrdering.take(minSize) == ordering.take(minSize) case ClusteredDistribution(requiredClustering) => - ordering.map(_.child).toSet.subsetOf(requiredClustering.toSet) + ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) case _ => false } override def compatibleWith(other: Partitioning): Boolean = other match { - case o: RangePartitioning => this == o + case o: RangePartitioning => this.semanticEquals(o) case _ => false } override def guarantees(other: Partitioning): Boolean = other match { - case o: RangePartitioning => this == o + case o: RangePartitioning => this.semanticEquals(o) case _ => false } } From fd14936be7beff543dbbcf270f2f9749f7a803c4 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 16 Nov 2015 15:32:49 -0800 Subject: [PATCH 043/149] [SPARK-11625][SQL] add java test for typed aggregate Author: Wenchen Fan Closes #9591 from cloud-fan/agg-test. 
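For readers skimming this patch, the typed aggregation the new Java test exercises can also be sketched in Scala. The object below is an illustrative sketch written against the `Aggregator` interface shown in this patch series, not part of the commit; the name `IntSum`, the sample data, and the commented-out usage lines are assumptions.

```scala
import org.apache.spark.sql.expressions.Aggregator

// Sums the Int component of (String, Int) records, mirroring the Java IntSumOf
// aggregator added by the test in the diff below.
object IntSum extends Aggregator[(String, Int), Int, Int] {
  override def zero: Int = 0
  override def reduce(b: Int, a: (String, Int)): Int = b + a._2
  override def merge(b1: Int, b2: Int): Int = b1 + b2
  override def finish(r: Int): Int = r
}

// Rough usage, assuming a SQLContext named sqlContext is in scope:
//   import sqlContext.implicits._
//   val ds = Seq(("a", 1), ("a", 2), ("b", 3)).toDS()
//   ds.groupBy(_._1).agg(IntSum.toColumn)   // expected: ("a", 3), ("b", 3)
```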
--- .../spark/api/java/function/Function.java | 2 +- .../org/apache/spark/sql/GroupedDataset.scala | 34 ++++++++++- .../spark/sql/expressions/Aggregator.scala | 2 +- .../apache/spark/sql/JavaDatasetSuite.java | 56 +++++++++++++++++++ .../spark/sql/DatasetAggregatorSuite.scala | 7 +-- 5 files changed, 92 insertions(+), 9 deletions(-) diff --git a/core/src/main/java/org/apache/spark/api/java/function/Function.java b/core/src/main/java/org/apache/spark/api/java/function/Function.java index d00551bb0add6..b9d9777a75651 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/Function.java +++ b/core/src/main/java/org/apache/spark/api/java/function/Function.java @@ -25,5 +25,5 @@ * when mapping RDDs of other types. */ public interface Function extends Serializable { - public R call(T1 v1) throws Exception; + R call(T1 v1) throws Exception; } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index ebcf4c8bfe7e6..467cd42b9b8dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -145,9 +145,37 @@ class GroupedDataset[K, T] private[sql]( reduce(f.call _) } - // To ensure valid overloading. - protected def agg(expr: Column, exprs: Column*): DataFrame = - groupedData.agg(expr, exprs: _*) + /** + * Compute aggregates by specifying a series of aggregate columns, and return a [[DataFrame]]. + * We can call `as[T : Encoder]` to turn the returned [[DataFrame]] to [[Dataset]] again. + * + * The available aggregate methods are defined in [[org.apache.spark.sql.functions]]. + * + * {{{ + * // Selects the age of the oldest employee and the aggregate expense for each department + * + * // Scala: + * import org.apache.spark.sql.functions._ + * df.groupBy("department").agg(max("age"), sum("expense")) + * + * // Java: + * import static org.apache.spark.sql.functions.*; + * df.groupBy("department").agg(max("age"), sum("expense")); + * }}} + * + * We can also use `Aggregator.toColumn` to pass in typed aggregate functions. + * + * @since 1.6.0 + */ + @scala.annotation.varargs + def agg(expr: Column, exprs: Column*): DataFrame = + groupedData.agg(withEncoder(expr), exprs.map(withEncoder): _*) + + private def withEncoder(c: Column): Column = c match { + case tc: TypedColumn[_, _] => + tc.withInputType(resolvedTEncoder.bind(dataAttributes), dataAttributes) + case _ => c + } /** * Internal helper function for building typed aggregations that return tuples. For simplicity diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala index 360c9a5bc15e7..72610e735f782 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala @@ -47,7 +47,7 @@ import org.apache.spark.sql.{Dataset, DataFrame, TypedColumn} * @tparam B The type of the intermediate value of the reduction. * @tparam C The type of the final result. */ -abstract class Aggregator[-A, B, C] { +abstract class Aggregator[-A, B, C] extends Serializable { /** A zero value for this aggregation. 
Should satisfy the property that any b + zero = b */ def zero: B diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index eb6fa1e72e27b..d9b22506fbd3b 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -34,6 +34,7 @@ import org.apache.spark.sql.Encoders; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.GroupedDataset; +import org.apache.spark.sql.expressions.Aggregator; import org.apache.spark.sql.test.TestSQLContext; import static org.apache.spark.sql.functions.*; @@ -381,4 +382,59 @@ public void testNestedTupleEncoder() { context.createDataset(data3, encoder3); Assert.assertEquals(data3, ds3.collectAsList()); } + + @Test + public void testTypedAggregation() { + Encoder> encoder = Encoders.tuple(Encoders.STRING(), Encoders.INT()); + List> data = + Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3)); + Dataset> ds = context.createDataset(data, encoder); + + GroupedDataset> grouped = ds.groupBy( + new MapFunction, String>() { + @Override + public String call(Tuple2 value) throws Exception { + return value._1(); + } + }, + Encoders.STRING()); + + Dataset> agged = + grouped.agg(new IntSumOf().toColumn(Encoders.INT(), Encoders.INT())); + Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList()); + + Dataset> agged2 = grouped.agg( + new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()), + expr("sum(_2)"), + count("*")) + .as(Encoders.tuple(Encoders.STRING(), Encoders.INT(), Encoders.LONG(), Encoders.LONG())); + Assert.assertEquals( + Arrays.asList( + new Tuple4("a", 3, 3L, 2L), + new Tuple4("b", 3, 3L, 1L)), + agged2.collectAsList()); + } + + static class IntSumOf extends Aggregator, Integer, Integer> { + + @Override + public Integer zero() { + return 0; + } + + @Override + public Integer reduce(Integer l, Tuple2 t) { + return l + t._2(); + } + + @Override + public Integer merge(Integer b1, Integer b2) { + return b1 + b2; + } + + @Override + public Integer finish(Integer reduction) { + return reduction; + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 46f9f077fe7f2..9377589790011 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.expressions.Aggregator /** An `Aggregator` that adds up any numeric type returned by the given function. 
 */
-class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializable {
+class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] {
   val numeric = implicitly[Numeric[N]]

   override def zero: N = numeric.zero
@@ -37,7 +37,7 @@ class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializ
   override def finish(reduction: N): N = reduction
 }

-object TypedAverage extends Aggregator[(String, Int), (Long, Long), Double] with Serializable {
+object TypedAverage extends Aggregator[(String, Int), (Long, Long), Double] {
   override def zero: (Long, Long) = (0, 0)

   override def reduce(countAndSum: (Long, Long), input: (String, Int)): (Long, Long) = {
@@ -51,8 +51,7 @@ object TypedAverage extends Aggregator[(String, Int), (Long, Long), Double] with
   override def finish(countAndSum: (Long, Long)): Double = countAndSum._2 / countAndSum._1
 }

-object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)]
-  with Serializable {
+object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)] {
   override def zero: (Long, Long) = (0, 0)

From ea6f53e48a911b49dc175ccaac8c80e0a1d97a09 Mon Sep 17 00:00:00 2001
From: Shivaram Venkataraman
Date: Mon, 16 Nov 2015 16:57:50 -0800
Subject: [PATCH 044/149] [SPARKR][HOTFIX] Disable flaky SparkR package build test

See https://github.com/apache/spark/pull/9390#issuecomment-157160063 and
https://gist.github.com/shivaram/3a2fecce60768a603dac for more information

Author: Shivaram Venkataraman

Closes #9744 from shivaram/sparkr-package-test-disable.
---
 .../test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
index 42e748ec6d528..d494b0caab85f 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -369,7 +369,9 @@ class SparkSubmitSuite
     }
   }

-  test("correctly builds R packages included in a jar with --packages") {
+  // TODO(SPARK-9603): Building a package is flaky on Jenkins Maven builds.
+  // See https://gist.github.com/shivaram/3a2fecce60768a603dac for an error log
+  ignore("correctly builds R packages included in a jar with --packages") {
     assume(RUtils.isRInstalled, "R isn't installed on this machine.")
     val main = MavenCoordinate("my.great.lib", "mylib", "0.1")
     val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!"))

From 30f3cfda1cce8760f15c0a48a8c47f09a5167fca Mon Sep 17 00:00:00 2001
From: Kousuke Saruta
Date: Mon, 16 Nov 2015 16:59:16 -0800
Subject: [PATCH 045/149] [SPARK-11480][CORE][WEBUI] Wrong callsite is displayed when using AsyncRDDActions#takeAsync

When we call AsyncRDDActions#takeAsync, DAGScheduler#runJob is actually invoked from
another thread, so we cannot capture the proper call site information.
Before/after screenshots of the Spark UI are attached to the pull request.

Author: Kousuke Saruta

Closes #9437 from sarutak/SPARK-11480.
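The essence of the fix in the diff below is to capture the call site on the thread that invokes `takeAsync` and re-apply it on the thread that actually runs the job. A generic, hedged illustration of that pattern in plain Scala (not Spark's actual call-site plumbing; all names here are made up) could look like this:

```scala
import scala.concurrent.{Await, Future}
import scala.concurrent.duration._
import scala.concurrent.ExecutionContext.Implicits.global

object CallSitePropagationDemo {
  // Stand-in for the thread-local call site that SparkContext tracks internally.
  private val callSite = new ThreadLocal[String]

  def main(args: Array[String]): Unit = {
    callSite.set("takeAsync at Demo.scala:12") // recorded on the calling thread
    val captured = callSite.get()              // capture it before crossing threads

    val job = Future {
      callSite.set(captured)                   // restore it on the thread running the job
      s"job submitted from: ${callSite.get()}"
    }
    println(Await.result(job, 10.seconds))
  }
}
```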
--- core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index ca1eb1f4e4a9a..d5e853613b05b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -66,6 +66,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi */ def takeAsync(num: Int): FutureAction[Seq[T]] = self.withScope { val f = new ComplexFutureAction[Seq[T]] + val callSite = self.context.getCallSite f.run { // This is a blocking action so we should use "AsyncRDDActions.futureExecutionContext" which @@ -73,6 +74,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi val results = new ArrayBuffer[T](num) val totalParts = self.partitions.length var partsScanned = 0 + self.context.setCallSite(callSite) while (results.size < num && partsScanned < totalParts) { // The number of partitions to try in this iteration. It is ok for this number to be // greater than totalParts because we actually cap it at totalParts in runJob. From 33a0ec93771ef5c3b388165b07cfab9014918d3b Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Mon, 16 Nov 2015 17:00:18 -0800 Subject: [PATCH 046/149] [SPARK-11710] Document new memory management model Author: Andrew Or Closes #9676 from andrewor14/memory-management-docs. --- docs/configuration.md | 13 +++++++---- docs/tuning.md | 54 ++++++++++++++++++++++++++++--------------- 2 files changed, 44 insertions(+), 23 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index d961f43acf4ab..c496146e3ed63 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -722,17 +722,20 @@ Apart from these, the following properties are also available, and may be useful Fraction of the heap space used for execution and storage. The lower this is, the more frequently spills and cached data eviction occur. The purpose of this config is to set aside memory for internal metadata, user data structures, and imprecise size estimation - in the case of sparse, unusually large records. + in the case of sparse, unusually large records. Leaving this at the default value is + recommended. For more detail, see + this description. spark.memory.storageFraction 0.5 - T​he size of the storage region within the space set aside by - s​park.memory.fraction. This region is not statically reserved, but dynamically - allocated as cache requests come in. ​Cached data may be evicted only if total storage exceeds - this region. + Amount of storage memory immune to eviction, expressed as a fraction of the size of the + region set aside by s​park.memory.fraction. The higher this is, the less + working memory may be available to execution and tasks may spill to disk more often. + Leaving this at the default value is recommended. For more detail, see + this description. diff --git a/docs/tuning.md b/docs/tuning.md index 879340a01544f..a8fe7a4532798 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -88,9 +88,39 @@ than the "raw" data inside their fields. This is due to several reasons: but also pointers (typically 8 bytes each) to the next object in the list. * Collections of primitive types often store them as "boxed" objects such as `java.lang.Integer`. 
-This section will discuss how to determine the memory usage of your objects, and how to improve -it -- either by changing your data structures, or by storing data in a serialized format. -We will then cover tuning Spark's cache size and the Java garbage collector. +This section will start with an overview of memory management in Spark, then discuss specific +strategies the user can take to make more efficient use of memory in his/her application. In +particular, we will describe how to determine the memory usage of your objects, and how to +improve it -- either by changing your data structures, or by storing data in a serialized +format. We will then cover tuning Spark's cache size and the Java garbage collector. + +## Memory Management Overview + +Memory usage in Spark largely falls under one of two categories: execution and storage. +Execution memory refers to that used for computation in shuffles, joins, sorts and aggregations, +while storage memory refers to that used for caching and propagating internal data across the +cluster. In Spark, execution and storage share a unified region (M). When no execution memory is +used, storage can acquire all the available memory and vice versa. Execution may evict storage +if necessary, but only until total storage memory usage falls under a certain threshold (R). +In other words, `R` describes a subregion within `M` where cached blocks are never evicted. +Storage may not evict execution due to complexities in implementation. + +This design ensures several desirable properties. First, applications that do not use caching +can use the entire space for execution, obviating unnecessary disk spills. Second, applications +that do use caching can reserve a minimum storage space (R) where their data blocks are immune +to being evicted. Lastly, this approach provides reasonable out-of-the-box performance for a +variety of workloads without requiring user expertise of how memory is divided internally. + +Although there are two relevant configurations, the typical user should not need to adjust them +as the default values are applicable to most workloads: + +* `spark.memory.fraction` expresses the size of `M` as a fraction of the total JVM heap space +(default 0.75). The rest of the space (25%) is reserved for user data structures, internal +metadata in Spark, and safeguarding against OOM errors in the case of sparse and unusually +large records. +* `spark.memory.storageFraction` expresses the size of `R` as a fraction of `M` (default 0.5). +`R` is the storage space within `M` where cached blocks immune to being evicted by execution. + ## Determining Memory Consumption @@ -151,18 +181,6 @@ time spent GC. This can be done by adding `-verbose:gc -XX:+PrintGCDetails -XX:+ each time a garbage collection occurs. Note these logs will be on your cluster's worker nodes (in the `stdout` files in their work directories), *not* on your driver program. -**Cache Size Tuning** - -One important configuration parameter for GC is the amount of memory that should be used for caching RDDs. -By default, Spark uses 60% of the configured executor memory (`spark.executor.memory`) to -cache RDDs. This means that 40% of memory is available for any objects created during task execution. - -In case your tasks slow down and you find that your JVM is garbage-collecting frequently or running out of -memory, lowering this value will help reduce the memory consumption. To change this to, say, 50%, you can call -`conf.set("spark.storage.memoryFraction", "0.5")` on your SparkConf. 
Combined with the use of serialized caching, -using a smaller cache should be sufficient to mitigate most of the garbage collection problems. -In case you are interested in further tuning the Java GC, continue reading below. - **Advanced GC Tuning** To further tune garbage collection, we first need to understand some basic information about memory management in the JVM: @@ -183,9 +201,9 @@ temporary objects created during task execution. Some steps which may be useful * Check if there are too many garbage collections by collecting GC stats. If a full GC is invoked multiple times for before a task completes, it means that there isn't enough memory available for executing tasks. -* In the GC stats that are printed, if the OldGen is close to being full, reduce the amount of memory used for caching. - This can be done using the `spark.storage.memoryFraction` property. It is better to cache fewer objects than to slow - down task execution! +* In the GC stats that are printed, if the OldGen is close to being full, reduce the amount of + memory used for caching by lowering `spark.memory.storageFraction`; it is better to cache fewer + objects than to slow down task execution! * If there are too many minor collections but not many major GCs, allocating more memory for Eden would help. You can set the size of the Eden to be an over-estimate of how much memory each task will need. If the size of Eden From bd10eb81c98e5e9df453f721943a3e82d9f74ae4 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Mon, 16 Nov 2015 17:02:21 -0800 Subject: [PATCH 047/149] [EXAMPLE][MINOR] Add missing awaitTermination in click stream example Author: jerryshao Closes #9730 from jerryshao/clickstream-fix. --- .../spark/examples/streaming/clickstream/PageViewStream.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala index ec7d39da8b2e9..4ef238606f82e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala @@ -18,7 +18,6 @@ // scalastyle:off println package org.apache.spark.examples.streaming.clickstream -import org.apache.spark.SparkContext._ import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.examples.streaming.StreamingExamples // scalastyle:off @@ -88,7 +87,7 @@ object PageViewStream { // An external dataset we want to join to this stream val userList = ssc.sparkContext.parallelize( - Map(1 -> "Patrick Wendell", 2->"Reynold Xin", 3->"Matei Zaharia").toSeq) + Map(1 -> "Patrick Wendell", 2 -> "Reynold Xin", 3 -> "Matei Zaharia").toSeq) metric match { case "pageCounts" => pageCounts.print() @@ -106,6 +105,7 @@ object PageViewStream { } ssc.start() + ssc.awaitTermination() } } // scalastyle:on println From 1c5475f1401d2233f4c61f213d1e2c2ee9673067 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Mon, 16 Nov 2015 17:12:39 -0800 Subject: [PATCH 048/149] [SPARK-11612][ML] Pipeline and PipelineModel persistence Pipeline and PipelineModel extend Readable and Writable. Persistence succeeds only when all stages are Writable. Note: This PR reinstates tests for other read/write functionality. It should probably not get merged until [https://issues.apache.org/jira/browse/SPARK-11672] gets fixed. CC: mengxr Author: Joseph K. 
Bradley Closes #9674 from jkbradley/pipeline-io. --- .../scala/org/apache/spark/ml/Pipeline.scala | 175 +++++++++++++++++- .../org/apache/spark/ml/util/ReadWrite.scala | 4 +- .../org/apache/spark/ml/PipelineSuite.scala | 120 +++++++++++- .../spark/ml/util/DefaultReadWriteTest.scala | 25 ++- 4 files changed, 306 insertions(+), 18 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index a3e59401c5cfb..25f0c696f42be 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -22,12 +22,19 @@ import java.{util => ju} import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer -import org.apache.spark.Logging +import org.apache.hadoop.fs.Path +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.{SparkContext, Logging} import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.ml.param.{Param, ParamMap, Params} -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util.Reader +import org.apache.spark.ml.util.Writer +import org.apache.spark.ml.util._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils /** * :: DeveloperApi :: @@ -82,7 +89,7 @@ abstract class PipelineStage extends Params with Logging { * an identity transformer. */ @Experimental -class Pipeline(override val uid: String) extends Estimator[PipelineModel] { +class Pipeline(override val uid: String) extends Estimator[PipelineModel] with Writable { def this() = this(Identifiable.randomUID("pipeline")) @@ -166,6 +173,131 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] { "Cannot have duplicate components in a pipeline.") theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur)) } + + override def write: Writer = new Pipeline.PipelineWriter(this) +} + +object Pipeline extends Readable[Pipeline] { + + override def read: Reader[Pipeline] = new PipelineReader + + override def load(path: String): Pipeline = read.load(path) + + private[ml] class PipelineWriter(instance: Pipeline) extends Writer { + + SharedReadWrite.validateStages(instance.getStages) + + override protected def saveImpl(path: String): Unit = + SharedReadWrite.saveImpl(instance, instance.getStages, sc, path) + } + + private[ml] class PipelineReader extends Reader[Pipeline] { + + /** Checked against metadata when loading model */ + private val className = "org.apache.spark.ml.Pipeline" + + override def load(path: String): Pipeline = { + val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path) + new Pipeline(uid).setStages(stages) + } + } + + /** Methods for [[Reader]] and [[Writer]] shared between [[Pipeline]] and [[PipelineModel]] */ + private[ml] object SharedReadWrite { + + import org.json4s.JsonDSL._ + + /** Check that all stages are Writable */ + def validateStages(stages: Array[PipelineStage]): Unit = { + stages.foreach { + case stage: Writable => // good + case other => + throw new UnsupportedOperationException("Pipeline write will fail on this Pipeline" + + s" because it contains a stage which does not implement Writable. 
Non-Writable stage:" + + s" ${other.uid} of type ${other.getClass}") + } + } + + /** + * Save metadata and stages for a [[Pipeline]] or [[PipelineModel]] + * - save metadata to path/metadata + * - save stages to stages/IDX_UID + */ + def saveImpl( + instance: Params, + stages: Array[PipelineStage], + sc: SparkContext, + path: String): Unit = { + // Copied and edited from DefaultParamsWriter.saveMetadata + // TODO: modify DefaultParamsWriter.saveMetadata to avoid duplication + val uid = instance.uid + val cls = instance.getClass.getName + val stageUids = stages.map(_.uid) + val jsonParams = List("stageUids" -> parse(compact(render(stageUids.toSeq)))) + val metadata = ("class" -> cls) ~ + ("timestamp" -> System.currentTimeMillis()) ~ + ("sparkVersion" -> sc.version) ~ + ("uid" -> uid) ~ + ("paramMap" -> jsonParams) + val metadataPath = new Path(path, "metadata").toString + val metadataJson = compact(render(metadata)) + sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath) + + // Save stages + val stagesDir = new Path(path, "stages").toString + stages.zipWithIndex.foreach { case (stage: Writable, idx: Int) => + stage.write.save(getStagePath(stage.uid, idx, stages.length, stagesDir)) + } + } + + /** + * Load metadata and stages for a [[Pipeline]] or [[PipelineModel]] + * @return (UID, list of stages) + */ + def load( + expectedClassName: String, + sc: SparkContext, + path: String): (String, Array[PipelineStage]) = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName) + + implicit val format = DefaultFormats + val stagesDir = new Path(path, "stages").toString + val stageUids: Array[String] = metadata.params match { + case JObject(pairs) => + if (pairs.length != 1) { + // Should not happen unless file is corrupted or we have a bug. + throw new RuntimeException( + s"Pipeline read expected 1 Param (stageUids), but found ${pairs.length}.") + } + pairs.head match { + case ("stageUids", jsonValue) => + jsonValue.extract[Seq[String]].toArray + case (paramName, jsonValue) => + // Should not happen unless file is corrupted or we have a bug. + throw new RuntimeException(s"Pipeline read encountered unexpected Param $paramName" + + s" in metadata: ${metadata.metadataStr}") + } + case _ => + throw new IllegalArgumentException( + s"Cannot recognize JSON metadata: ${metadata.metadataStr}.") + } + val stages: Array[PipelineStage] = stageUids.zipWithIndex.map { case (stageUid, idx) => + val stagePath = SharedReadWrite.getStagePath(stageUid, idx, stageUids.length, stagesDir) + val stageMetadata = DefaultParamsReader.loadMetadata(stagePath, sc) + val cls = Utils.classForName(stageMetadata.className) + cls.getMethod("read").invoke(null).asInstanceOf[Reader[PipelineStage]].load(stagePath) + } + (metadata.uid, stages) + } + + /** Get path for saving the given stage. */ + def getStagePath(stageUid: String, stageIdx: Int, numStages: Int, stagesDir: String): String = { + val stageIdxDigits = numStages.toString.length + val idxFormat = s"%0${stageIdxDigits}d" + val stageDir = idxFormat.format(stageIdx) + "_" + stageUid + new Path(stagesDir, stageDir).toString + } + } } /** @@ -176,7 +308,7 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] { class PipelineModel private[ml] ( override val uid: String, val stages: Array[Transformer]) - extends Model[PipelineModel] with Logging { + extends Model[PipelineModel] with Writable with Logging { /** A Java/Python-friendly auxiliary constructor. 
*/ private[ml] def this(uid: String, stages: ju.List[Transformer]) = { @@ -200,4 +332,39 @@ class PipelineModel private[ml] ( override def copy(extra: ParamMap): PipelineModel = { new PipelineModel(uid, stages.map(_.copy(extra))).setParent(parent) } + + override def write: Writer = new PipelineModel.PipelineModelWriter(this) +} + +object PipelineModel extends Readable[PipelineModel] { + + import Pipeline.SharedReadWrite + + override def read: Reader[PipelineModel] = new PipelineModelReader + + override def load(path: String): PipelineModel = read.load(path) + + private[ml] class PipelineModelWriter(instance: PipelineModel) extends Writer { + + SharedReadWrite.validateStages(instance.stages.asInstanceOf[Array[PipelineStage]]) + + override protected def saveImpl(path: String): Unit = SharedReadWrite.saveImpl(instance, + instance.stages.asInstanceOf[Array[PipelineStage]], sc, path) + } + + private[ml] class PipelineModelReader extends Reader[PipelineModel] { + + /** Checked against metadata when loading model */ + private val className = "org.apache.spark.ml.PipelineModel" + + override def load(path: String): PipelineModel = { + val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path) + val transformers = stages map { + case stage: Transformer => stage + case other => throw new RuntimeException(s"PipelineModel.read loaded a stage but found it" + + s" was not a Transformer. Bad stage ${other.uid} of type ${other.getClass}") + } + new PipelineModel(uid, transformers) + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index ca896ed6106c4..3169c9e9af5be 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -164,6 +164,8 @@ trait Readable[T] { /** * Reads an ML instance from the input path, a shortcut of `read.load(path)`. + * + * Note: Implementing classes should override this to be Java-friendly. */ @Since("1.6.0") def load(path: String): T = read.load(path) @@ -190,7 +192,7 @@ private[ml] object DefaultParamsWriter { * - timestamp * - sparkVersion * - uid - * - paramMap + * - paramMap: These must be encodable using [[org.apache.spark.ml.param.Param.jsonEncode()]]. 
*/ def saveMetadata(instance: Params, path: String, sc: SparkContext): Unit = { val uid = instance.uid diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 1f2c9b75b617b..484026b1ba9ad 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -17,19 +17,25 @@ package org.apache.spark.ml +import java.io.File + import scala.collection.JavaConverters._ +import org.apache.hadoop.fs.{FileSystem, Path} import org.mockito.Matchers.{any, eq => meq} import org.mockito.Mockito.when import org.scalatest.mock.MockitoSugar.mock import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.Pipeline.SharedReadWrite import org.apache.spark.ml.feature.HashingTF -import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.param.{IntParam, ParamMap} +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.types.StructType -class PipelineSuite extends SparkFunSuite { +class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { abstract class MyModel extends Model[MyModel] @@ -111,4 +117,112 @@ class PipelineSuite extends SparkFunSuite { assert(pipelineModel1.uid === "pipeline1") assert(pipelineModel1.stages === stages) } + + test("Pipeline read/write") { + val writableStage = new WritableStage("writableStage").setIntParam(56) + val pipeline = new Pipeline().setStages(Array(writableStage)) + + val pipeline2 = testDefaultReadWrite(pipeline, testParams = false) + assert(pipeline2.getStages.length === 1) + assert(pipeline2.getStages(0).isInstanceOf[WritableStage]) + val writableStage2 = pipeline2.getStages(0).asInstanceOf[WritableStage] + assert(writableStage.getIntParam === writableStage2.getIntParam) + } + + test("Pipeline read/write with non-Writable stage") { + val unWritableStage = new UnWritableStage("unwritableStage") + val unWritablePipeline = new Pipeline().setStages(Array(unWritableStage)) + withClue("Pipeline.write should fail when Pipeline contains non-Writable stage") { + intercept[UnsupportedOperationException] { + unWritablePipeline.write + } + } + } + + test("PipelineModel read/write") { + val writableStage = new WritableStage("writableStage").setIntParam(56) + val pipeline = + new PipelineModel("pipeline_89329327", Array(writableStage.asInstanceOf[Transformer])) + + val pipeline2 = testDefaultReadWrite(pipeline, testParams = false) + assert(pipeline2.stages.length === 1) + assert(pipeline2.stages(0).isInstanceOf[WritableStage]) + val writableStage2 = pipeline2.stages(0).asInstanceOf[WritableStage] + assert(writableStage.getIntParam === writableStage2.getIntParam) + + val path = new File(tempDir, pipeline.uid).getPath + val stagesDir = new Path(path, "stages").toString + val expectedStagePath = SharedReadWrite.getStagePath(writableStage.uid, 0, 1, stagesDir) + assert(FileSystem.get(sc.hadoopConfiguration).exists(new Path(expectedStagePath)), + s"Expected stage 0 of 1 with uid ${writableStage.uid} in Pipeline with uid ${pipeline.uid}" + + s" to be saved to path: $expectedStagePath") + } + + test("PipelineModel read/write: getStagePath") { + val stageUid = "myStage" + val stagesDir = new Path("pipeline", "stages").toString + def testStage(stageIdx: Int, numStages: Int, expectedPrefix: String): Unit = { + val path = 
SharedReadWrite.getStagePath(stageUid, stageIdx, numStages, stagesDir) + val expected = new Path(stagesDir, expectedPrefix + "_" + stageUid).toString + assert(path === expected) + } + testStage(0, 1, "0") + testStage(0, 9, "0") + testStage(0, 10, "00") + testStage(1, 10, "01") + testStage(12, 999, "012") + } + + test("PipelineModel read/write with non-Writable stage") { + val unWritableStage = new UnWritableStage("unwritableStage") + val unWritablePipeline = + new PipelineModel("pipeline_328957", Array(unWritableStage.asInstanceOf[Transformer])) + withClue("PipelineModel.write should fail when PipelineModel contains non-Writable stage") { + intercept[UnsupportedOperationException] { + unWritablePipeline.write + } + } + } +} + + +/** Used to test [[Pipeline]] with [[Writable]] stages */ +class WritableStage(override val uid: String) extends Transformer with Writable { + + final val intParam: IntParam = new IntParam(this, "intParam", "doc") + + def getIntParam: Int = $(intParam) + + def setIntParam(value: Int): this.type = set(intParam, value) + + setDefault(intParam -> 0) + + override def copy(extra: ParamMap): WritableStage = defaultCopy(extra) + + override def write: Writer = new DefaultParamsWriter(this) + + override def transform(dataset: DataFrame): DataFrame = dataset + + override def transformSchema(schema: StructType): StructType = schema +} + +object WritableStage extends Readable[WritableStage] { + + override def read: Reader[WritableStage] = new DefaultParamsReader[WritableStage] + + override def load(path: String): WritableStage = read.load(path) +} + +/** Used to test [[Pipeline]] with non-[[Writable]] stages */ +class UnWritableStage(override val uid: String) extends Transformer { + + final val intParam: IntParam = new IntParam(this, "intParam", "doc") + + setDefault(intParam -> 0) + + override def copy(extra: ParamMap): UnWritableStage = defaultCopy(extra) + + override def transform(dataset: DataFrame): DataFrame = dataset + + override def transformSchema(schema: StructType): StructType = schema } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index cac4bd9aa3ab8..c37f0503f1332 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -30,10 +30,13 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => /** * Checks "overwrite" option and params. * @param instance ML instance to test saving/loading + * @param testParams If true, then test values of Params. Otherwise, just test overwrite option. 
* @tparam T ML instance type * @return Instance loaded from file */ - def testDefaultReadWrite[T <: Params with Writable](instance: T): T = { + def testDefaultReadWrite[T <: Params with Writable]( + instance: T, + testParams: Boolean = true): T = { val uid = instance.uid val path = new File(tempDir, uid).getPath @@ -46,16 +49,18 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => val newInstance = loader.load(path) assert(newInstance.uid === instance.uid) - instance.params.foreach { p => - if (instance.isDefined(p)) { - (instance.getOrDefault(p), newInstance.getOrDefault(p)) match { - case (Array(values), Array(newValues)) => - assert(values === newValues, s"Values do not match on param ${p.name}.") - case (value, newValue) => - assert(value === newValue, s"Values do not match on param ${p.name}.") + if (testParams) { + instance.params.foreach { p => + if (instance.isDefined(p)) { + (instance.getOrDefault(p), newInstance.getOrDefault(p)) match { + case (Array(values), Array(newValues)) => + assert(values === newValues, s"Values do not match on param ${p.name}.") + case (value, newValue) => + assert(value === newValue, s"Values do not match on param ${p.name}.") + } + } else { + assert(!newInstance.isDefined(p), s"Param ${p.name} shouldn't be defined.") } - } else { - assert(!newInstance.isDefined(p), s"Param ${p.name} shouldn't be defined.") } } From 540bf58f18328c68107d6c616ffd70f3a4640054 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 16 Nov 2015 17:28:11 -0800 Subject: [PATCH 049/149] [SPARK-11617][NETWORK] Fix leak in TransportFrameDecoder. The code was using the wrong API to add data to the internal composite buffer, causing buffers to leak in certain situations. Use the right API and enhance the tests to catch memory leaks. Also, avoid reusing the composite buffers when downstream handlers keep references to them; this seems to cause a few different issues even though the ref counting code seems to be correct, so instead pay the cost of copying a few bytes when that situation happens. Author: Marcelo Vanzin Closes #9619 from vanzin/SPARK-11617. --- .../network/util/TransportFrameDecoder.java | 47 ++++-- .../util/TransportFrameDecoderSuite.java | 145 +++++++++++++++--- 2 files changed, 151 insertions(+), 41 deletions(-) diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java b/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java index 272ea84e6180d..5889562dd9705 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java +++ b/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java @@ -56,32 +56,43 @@ public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception buffer = in.alloc().compositeBuffer(); } - buffer.writeBytes(in); + buffer.addComponent(in).writerIndex(buffer.writerIndex() + in.readableBytes()); while (buffer.isReadable()) { - feedInterceptor(); - if (interceptor != null) { - continue; - } + discardReadBytes(); + if (!feedInterceptor()) { + ByteBuf frame = decodeNext(); + if (frame == null) { + break; + } - ByteBuf frame = decodeNext(); - if (frame != null) { ctx.fireChannelRead(frame); - } else { - break; } } - // We can't discard read sub-buffers if there are other references to the buffer (e.g. - // through slices used for framing). 
This assumes that code that retains references - // will call retain() from the thread that called "fireChannelRead()" above, otherwise - // ref counting will go awry. - if (buffer != null && buffer.refCnt() == 1) { + discardReadBytes(); + } + + private void discardReadBytes() { + // If the buffer's been retained by downstream code, then make a copy of the remaining + // bytes into a new buffer. Otherwise, just discard stale components. + if (buffer.refCnt() > 1) { + CompositeByteBuf newBuffer = buffer.alloc().compositeBuffer(); + + if (buffer.readableBytes() > 0) { + ByteBuf spillBuf = buffer.alloc().buffer(buffer.readableBytes()); + spillBuf.writeBytes(buffer); + newBuffer.addComponent(spillBuf).writerIndex(spillBuf.readableBytes()); + } + + buffer.release(); + buffer = newBuffer; + } else { buffer.discardReadComponents(); } } - protected ByteBuf decodeNext() throws Exception { + private ByteBuf decodeNext() throws Exception { if (buffer.readableBytes() < LENGTH_SIZE) { return null; } @@ -127,10 +138,14 @@ public void setInterceptor(Interceptor interceptor) { this.interceptor = interceptor; } - private void feedInterceptor() throws Exception { + /** + * @return Whether the interceptor is still active after processing the data. + */ + private boolean feedInterceptor() throws Exception { if (interceptor != null && !interceptor.handle(buffer)) { interceptor = null; } + return interceptor != null; } public static interface Interceptor { diff --git a/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java b/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java index ca74f0a00cf9d..19475c21ffce9 100644 --- a/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java @@ -18,41 +18,36 @@ package org.apache.spark.network.util; import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; import java.util.Random; +import java.util.concurrent.atomic.AtomicInteger; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.ChannelHandlerContext; +import org.junit.AfterClass; import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; import static org.junit.Assert.*; import static org.mockito.Mockito.*; public class TransportFrameDecoderSuite { + private static Random RND = new Random(); + + @AfterClass + public static void cleanup() { + RND = null; + } + @Test public void testFrameDecoding() throws Exception { - Random rnd = new Random(); TransportFrameDecoder decoder = new TransportFrameDecoder(); - ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); - - final int frameCount = 100; - ByteBuf data = Unpooled.buffer(); - try { - for (int i = 0; i < frameCount; i++) { - byte[] frame = new byte[1024 * (rnd.nextInt(31) + 1)]; - data.writeLong(frame.length + 8); - data.writeBytes(frame); - } - - while (data.isReadable()) { - int size = rnd.nextInt(16 * 1024) + 256; - decoder.channelRead(ctx, data.readSlice(Math.min(data.readableBytes(), size))); - } - - verify(ctx, times(frameCount)).fireChannelRead(any(ByteBuf.class)); - } finally { - data.release(); - } + ChannelHandlerContext ctx = mockChannelHandlerContext(); + ByteBuf data = createAndFeedFrames(100, decoder, ctx); + verifyAndCloseDecoder(decoder, ctx, data); } @Test @@ -60,7 +55,7 @@ public void testInterception() throws Exception { final int 
interceptedReads = 3; TransportFrameDecoder decoder = new TransportFrameDecoder(); TransportFrameDecoder.Interceptor interceptor = spy(new MockInterceptor(interceptedReads)); - ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + ChannelHandlerContext ctx = mockChannelHandlerContext(); byte[] data = new byte[8]; ByteBuf len = Unpooled.copyLong(8 + data.length); @@ -70,16 +65,56 @@ public void testInterception() throws Exception { decoder.setInterceptor(interceptor); for (int i = 0; i < interceptedReads; i++) { decoder.channelRead(ctx, dataBuf); - dataBuf.release(); + assertEquals(0, dataBuf.refCnt()); dataBuf = Unpooled.wrappedBuffer(data); } decoder.channelRead(ctx, len); decoder.channelRead(ctx, dataBuf); verify(interceptor, times(interceptedReads)).handle(any(ByteBuf.class)); verify(ctx).fireChannelRead(any(ByteBuffer.class)); + assertEquals(0, len.refCnt()); + assertEquals(0, dataBuf.refCnt()); } finally { - len.release(); - dataBuf.release(); + release(len); + release(dataBuf); + } + } + + @Test + public void testRetainedFrames() throws Exception { + TransportFrameDecoder decoder = new TransportFrameDecoder(); + + final AtomicInteger count = new AtomicInteger(); + final List retained = new ArrayList<>(); + + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + when(ctx.fireChannelRead(any())).thenAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock in) { + // Retain a few frames but not others. + ByteBuf buf = (ByteBuf) in.getArguments()[0]; + if (count.incrementAndGet() % 2 == 0) { + retained.add(buf); + } else { + buf.release(); + } + return null; + } + }); + + ByteBuf data = createAndFeedFrames(100, decoder, ctx); + try { + // Verify all retained buffers are readable. + for (ByteBuf b : retained) { + byte[] tmp = new byte[b.readableBytes()]; + b.readBytes(tmp); + b.release(); + } + verifyAndCloseDecoder(decoder, ctx, data); + } finally { + for (ByteBuf b : retained) { + release(b); + } } } @@ -100,6 +135,47 @@ public void testLargeFrame() throws Exception { testInvalidFrame(Integer.MAX_VALUE + 9); } + /** + * Creates a number of randomly sized frames and feed them to the given decoder, verifying + * that the frames were read. 
+ */ + private ByteBuf createAndFeedFrames( + int frameCount, + TransportFrameDecoder decoder, + ChannelHandlerContext ctx) throws Exception { + ByteBuf data = Unpooled.buffer(); + for (int i = 0; i < frameCount; i++) { + byte[] frame = new byte[1024 * (RND.nextInt(31) + 1)]; + data.writeLong(frame.length + 8); + data.writeBytes(frame); + } + + try { + while (data.isReadable()) { + int size = RND.nextInt(4 * 1024) + 256; + decoder.channelRead(ctx, data.readSlice(Math.min(data.readableBytes(), size)).retain()); + } + + verify(ctx, times(frameCount)).fireChannelRead(any(ByteBuf.class)); + } catch (Exception e) { + release(data); + throw e; + } + return data; + } + + private void verifyAndCloseDecoder( + TransportFrameDecoder decoder, + ChannelHandlerContext ctx, + ByteBuf data) throws Exception { + try { + decoder.channelInactive(ctx); + assertTrue("There shouldn't be dangling references to the data.", data.release()); + } finally { + release(data); + } + } + private void testInvalidFrame(long size) throws Exception { TransportFrameDecoder decoder = new TransportFrameDecoder(); ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); @@ -111,6 +187,25 @@ private void testInvalidFrame(long size) throws Exception { } } + private ChannelHandlerContext mockChannelHandlerContext() { + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + when(ctx.fireChannelRead(any())).thenAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock in) { + ByteBuf buf = (ByteBuf) in.getArguments()[0]; + buf.release(); + return null; + } + }); + return ctx; + } + + private void release(ByteBuf buf) { + if (buf.refCnt() > 0) { + buf.release(buf.refCnt()); + } + } + private static class MockInterceptor implements TransportFrameDecoder.Interceptor { private int remainingReads; From fbad920dbfd6f389dea852cdc159cacb0f4f6997 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 16 Nov 2015 20:47:46 -0800 Subject: [PATCH 050/149] [SPARK-11768][SPARK-9196][SQL] Support now function in SQL (alias for current_timestamp). This patch adds an alias for current_timestamp (now function). Also fixes SPARK-9196 to re-enable the test case for current_timestamp. Author: Reynold Xin Closes #9753 from rxin/SPARK-11768. 
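For illustration, a minimal usage sketch of the new alias, mirroring the re-enabled DateFunctionsSuite case; it assumes a SQLContext named `sqlContext` is already in scope:

    // NOW() resolves to the same CurrentTimestamp expression as CURRENT_TIMESTAMP(),
    // so comparing the two inside a single query yields true.
    sqlContext.sql("SELECT NOW() = CURRENT_TIMESTAMP()").collect()  // Array([true])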
--- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../apache/spark/sql/DateFunctionsSuite.scala | 18 ++++++++++++------ 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index a8f4d257acd0a..f9c04d7ec0b0c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -244,6 +244,7 @@ object FunctionRegistry { expression[AddMonths]("add_months"), expression[CurrentDate]("current_date"), expression[CurrentTimestamp]("current_timestamp"), + expression[CurrentTimestamp]("now"), expression[DateDiff]("datediff"), expression[DateAdd]("date_add"), expression[DateFormatClass]("date_format"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index 1266d534cc5b3..241cbd0115070 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -38,15 +38,21 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { assert(d0 <= d1 && d1 <= d2 && d2 <= d3 && d3 - d0 <= 1) } - // This is a bad test. SPARK-9196 will fix it and re-enable it. - ignore("function current_timestamp") { + test("function current_timestamp and now") { val df1 = Seq((1, 2), (3, 1)).toDF("a", "b") checkAnswer(df1.select(countDistinct(current_timestamp())), Row(1)) + // Execution in one query should return the same value - checkAnswer(sql("""SELECT CURRENT_TIMESTAMP() = CURRENT_TIMESTAMP()"""), - Row(true)) - assert(math.abs(sql("""SELECT CURRENT_TIMESTAMP()""").collect().head.getTimestamp( - 0).getTime - System.currentTimeMillis()) < 5000) + checkAnswer(sql("""SELECT CURRENT_TIMESTAMP() = CURRENT_TIMESTAMP()"""), Row(true)) + + // Current timestamp should return the current timestamp ... + val before = System.currentTimeMillis + val got = sql("SELECT CURRENT_TIMESTAMP()").collect().head.getTimestamp(0).getTime + val after = System.currentTimeMillis + assert(got >= before && got <= after) + + // Now alias + checkAnswer(sql("""SELECT CURRENT_TIMESTAMP() = NOW()"""), Row(true)) } val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") From 75d202073143d5a7f943890d8682b5b0cf9e3092 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 17 Nov 2015 14:35:00 +0800 Subject: [PATCH 051/149] [SPARK-11694][FOLLOW-UP] Clean up imports, use a common function for metadata and add a test for FIXED_LEN_BYTE_ARRAY As discussed https://github.com/apache/spark/pull/9660 https://github.com/apache/spark/pull/9060, I cleaned up unused imports, added a test for fixed-length byte array and used a common function for writing metadata for Parquet. For the test for fixed-length byte array, I have tested and checked the encoding types with [parquet-tools](https://github.com/Parquet/parquet-mr/tree/master/parquet-tools). Author: hyukjinkwon Closes #9754 from HyukjinKwon/SPARK-11694-followup. 
--- .../test/resources/dec-in-fixed-len.parquet | Bin 0 -> 460 bytes .../datasources/parquet/ParquetIOSuite.scala | 42 +++++++----------- 2 files changed, 15 insertions(+), 27 deletions(-) create mode 100644 sql/core/src/test/resources/dec-in-fixed-len.parquet diff --git a/sql/core/src/test/resources/dec-in-fixed-len.parquet b/sql/core/src/test/resources/dec-in-fixed-len.parquet new file mode 100644 index 0000000000000000000000000000000000000000..6ad37d5639511cdb430f33fa6165eb70cd9034c0 GIT binary patch literal 460 zcmZuu%SyvQ6rI#oO3~7VwT&8yPVu5U4^0 z3hK5WJW4SNaflr_N}BXrgbo(V<<*j?`)dokQEarvSr7`N-{ZFH k+I`P - val extraMetadata = Map.empty[String, String].asJava - val fileMetadata = new FileMetaData(parquetSchema, extraMetadata, "Spark") val path = new Path(location.getCanonicalPath) - val footer = List( - new Footer(path, new ParquetMetadata(fileMetadata, Collections.emptyList())) - ).asJava - - ParquetFileWriter.writeMetadataFile(sparkContext.hadoopConfiguration, path, footer) - + val conf = sparkContext.hadoopConfiguration + writeMetadata(parquetSchema, path, conf) val errorMessage = intercept[Throwable] { sqlContext.read.parquet(path.toString).printSchema() }.toString @@ -267,20 +259,14 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { |} """.stripMargin) + val expectedSparkTypes = Seq(StringType, BinaryType) + withTempPath { location => - val extraMetadata = Map.empty[String, String].asJava - val fileMetadata = new FileMetaData(parquetSchema, extraMetadata, "Spark") val path = new Path(location.getCanonicalPath) - val footer = List( - new Footer(path, new ParquetMetadata(fileMetadata, Collections.emptyList())) - ).asJava - - ParquetFileWriter.writeMetadataFile(sparkContext.hadoopConfiguration, path, footer) - - val jsonDataType = sqlContext.read.parquet(path.toString).schema(0).dataType - assert(jsonDataType === StringType) - val bsonDataType = sqlContext.read.parquet(path.toString).schema(1).dataType - assert(bsonDataType === BinaryType) + val conf = sparkContext.hadoopConfiguration + writeMetadata(parquetSchema, path, conf) + val sparkTypes = sqlContext.read.parquet(path.toString).schema.map(_.dataType) + assert(sparkTypes === expectedSparkTypes) } } @@ -607,10 +593,12 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { sqlContext.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'i64_dec)) } - // TODO Adds test case for reading dictionary encoded decimals written as `FIXED_LEN_BYTE_ARRAY` - // The Parquet writer version Spark 1.6 and prior versions use is `PARQUET_1_0`, which doesn't - // provide dictionary encoding support for `FIXED_LEN_BYTE_ARRAY`. Should add a test here once - // we upgrade to `PARQUET_2_0`. 
+ test("read dictionary encoded decimals written as FIXED_LEN_BYTE_ARRAY") { + checkAnswer( + // Decimal column in this file is encoded using plain dictionary + readResourceParquetFile("dec-in-fixed-len.parquet"), + sqlContext.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'fixed_len_dec)) + } } class JobCommitFailureParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) From e01865af0d5ebe11033de46c388c5c583876c187 Mon Sep 17 00:00:00 2001 From: Kevin Yu Date: Mon, 16 Nov 2015 22:54:29 -0800 Subject: [PATCH 052/149] [SPARK-11447][SQL] change NullType to StringType during binaryComparison between NullType and StringType During executing PromoteStrings rule, if one side of binaryComparison is StringType and the other side is not StringType, the current code will promote(cast) the StringType to DoubleType, and if the StringType doesn't contain the numbers, it will get null value. So if it is doing <=> (NULL-safe equal) with Null, it will not filter anything, caused the problem reported by this jira. I proposal to the changes through this PR, can you review my code changes ? This problem only happen for <=>, other operators works fine. scala> val filteredDF = df.filter(df("column") > (new Column(Literal(null)))) filteredDF: org.apache.spark.sql.DataFrame = [column: string] scala> filteredDF.show +------+ |column| +------+ +------+ scala> val filteredDF = df.filter(df("column") === (new Column(Literal(null)))) filteredDF: org.apache.spark.sql.DataFrame = [column: string] scala> filteredDF.show +------+ |column| +------+ +------+ scala> df.registerTempTable("DF") scala> sqlContext.sql("select * from DF where 'column' = NULL") res27: org.apache.spark.sql.DataFrame = [column: string] scala> res27.show +------+ |column| +------+ +------+ Author: Kevin Yu Closes #9720 from kevinyu98/working_on_spark-11447. 
--- .../sql/catalyst/analysis/HiveTypeCoercion.scala | 6 ++++++ .../org/apache/spark/sql/ColumnExpressionSuite.scala | 11 +++++++++++ 2 files changed, 17 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 92188ee54fd28..f90fc3cc12189 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -281,6 +281,12 @@ object HiveTypeCoercion { case p @ BinaryComparison(left @ DateType(), right @ TimestampType()) => p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType))) + // Checking NullType + case p @ BinaryComparison(left @ StringType(), right @ NullType()) => + p.makeCopy(Array(left, Literal.create(null, StringType))) + case p @ BinaryComparison(left @ NullType(), right @ StringType()) => + p.makeCopy(Array(Literal.create(null, StringType), right)) + case p @ BinaryComparison(left @ StringType(), right) if right.dataType != StringType => p.makeCopy(Array(Cast(left, DoubleType), right)) case p @ BinaryComparison(left, right @ StringType()) if left.dataType != StringType => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 3eae3f6d85066..38c0eb589f965 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -368,6 +368,17 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { checkAnswer( nullData.filter($"a" <=> $"b"), Row(1, 1) :: Row(null, null) :: Nil) + + val nullData2 = sqlContext.createDataFrame(sparkContext.parallelize( + Row("abc") :: + Row(null) :: + Row("xyz") :: Nil), + StructType(Seq(StructField("a", StringType, true)))) + + checkAnswer( + nullData2.filter($"a" <=> null), + Row(null) :: Nil) + } test(">") { From d79d8b08ff69b30b02fe87839e695e29bfea5ace Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 16 Nov 2015 23:16:17 -0800 Subject: [PATCH 053/149] [MINOR] [SQL] Fix randomly generated ArrayData in RowEncoderSuite The randomly generated ArrayData used for the UDT `ExamplePoint` in `RowEncoderSuite` sometimes doesn't have enough elements. In this case, this test will fail. This patch is to fix it. Author: Liang-Chi Hsieh Closes #9757 from viirya/fix-randomgenerated-udt. 
--- .../spark/sql/catalyst/encoders/RowEncoderSuite.scala | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index c868ddec1bab2..46c6e0d98d349 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.encoders +import scala.util.Random + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData} @@ -59,7 +61,12 @@ class ExamplePointUDT extends UserDefinedType[ExamplePoint] { override def deserialize(datum: Any): ExamplePoint = { datum match { case values: ArrayData => - new ExamplePoint(values.getDouble(0), values.getDouble(1)) + if (values.numElements() > 1) { + new ExamplePoint(values.getDouble(0), values.getDouble(1)) + } else { + val random = new Random() + new ExamplePoint(random.nextDouble(), random.nextDouble()) + } } } From fa13301ae440c4c9594280f236bcca11b62fdd29 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 17 Nov 2015 18:11:08 +0800 Subject: [PATCH 054/149] [SPARK-11191][SQL][FOLLOW-UP] Cleans up unnecessary anonymous HiveFunctionRegistry According to discussion in PR #9664, the anonymous `HiveFunctionRegistry` in `HiveContext` can be removed now. Author: Cheng Lian Closes #9737 from liancheng/spark-11191.follow-up. --- .../scala/org/apache/spark/sql/hive/HiveContext.scala | 10 +--------- .../scala/org/apache/spark/sql/hive/hiveUDFs.scala | 7 +++++-- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 0c473799cc991..2004f24ad26c6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -454,15 +454,7 @@ class HiveContext private[hive]( // Note that HiveUDFs will be overridden by functions registered in this context. @transient override protected[sql] lazy val functionRegistry: FunctionRegistry = - new HiveFunctionRegistry(FunctionRegistry.builtin.copy(), this) { - override def lookupFunction(name: String, children: Seq[Expression]): Expression = { - // Hive Registry need current database to lookup function - // TODO: the current database of executionHive should be consistent with metadataHive - executionHive.withHiveState { - super.lookupFunction(name, children) - } - } - } + new HiveFunctionRegistry(FunctionRegistry.builtin.copy(), this.executionHive) // The Hive UDF current_database() is foldable, will be evaluated by optimizer, but the optimizer // can't access the SessionState of metadataHive. 
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index e6fe2ad5f23b6..2e8c026259efe 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -43,16 +43,19 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.hive.HiveShim._ +import org.apache.spark.sql.hive.client.ClientWrapper import org.apache.spark.sql.types._ private[hive] class HiveFunctionRegistry( underlying: analysis.FunctionRegistry, - hiveContext: HiveContext) + executionHive: ClientWrapper) extends analysis.FunctionRegistry with HiveInspectors { def getFunctionInfo(name: String): FunctionInfo = { - hiveContext.executionHive.withHiveState { + // Hive Registry need current database to lookup function + // TODO: the current database of executionHive should be consistent with metadataHive + executionHive.withHiveState { FunctionRegistry.getFunctionInfo(name) } } From 7276fa9aa9d2eccb6aebd5c690ac334699142f1e Mon Sep 17 00:00:00 2001 From: "yangping.wu" Date: Tue, 17 Nov 2015 14:11:34 +0000 Subject: [PATCH 055/149] [SPARK-11751] Doc describe error in the "Spark Streaming Programming Guide" page In the **[Task Launching Overheads](http://spark.apache.org/docs/latest/streaming-programming-guide.html#task-launching-overheads)** section, >Task Serialization: Using Kryo serialization for serializing tasks can reduce the task sizes, and therefore reduce the time taken to send them to the slaves. as we known **Task Serialization** is configuration by **spark.closure.serializer** parameter, but currently only the Java serializer is supported. If we set **spark.closure.serializer** to **org.apache.spark.serializer.KryoSerializer**, then this will throw a exception. Author: yangping.wu Closes #9734 from 397090770/397090770-patch-1. --- docs/streaming-programming-guide.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index e9a27f446a898..96b36b7a73209 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -2001,8 +2001,7 @@ If the number of tasks launched per second is high (say, 50 or more per second), of sending out tasks to the slaves may be significant and will make it hard to achieve sub-second latencies. The overhead can be reduced by the following changes: -* **Task Serialization**: Using Kryo serialization for serializing tasks can reduce the task - sizes, and therefore reduce the time taken to send them to the slaves. +* **Task Serialization**: Using Kryo serialization for serializing tasks can reduce the task sizes, and therefore reduce the time taken to send them to the slaves. This is controlled by the ```spark.closure.serializer``` property. However, at this time, Kryo serialization cannot be enabled for closure serialization. This may be resolved in a future release. * **Execution mode**: Running Spark in Standalone mode or coarse-grained Mesos mode leads to better task launch times than the fine-grained Mesos mode. 
Please refer to the From 15cc36b7786e2d9a460bf565893236edd2ad993e Mon Sep 17 00:00:00 2001 From: Philipp Hoffmann Date: Tue, 17 Nov 2015 14:13:13 +0000 Subject: [PATCH 056/149] [SPARK-11779][DOCS] Fix reference to deprecated MESOS_NATIVE_LIBRARY MESOS_NATIVE_LIBRARY was renamed in favor of MESOS_NATIVE_JAVA_LIBRARY. This commit fixes the reference in the documentation. Author: Philipp Hoffmann Closes #9768 from philipphoffmann/patch-2. --- docs/running-on-mesos.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index ec5a44d79212b..5be208cf3461e 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -278,7 +278,7 @@ See the [configuration page](configuration.html) for information on Spark config Set the name of the docker image that the Spark executors will run in. The selected image must have Spark installed, as well as a compatible version of the Mesos library. The installed path of Spark in the image can be specified with spark.mesos.executor.home; - the installed path of the Mesos library can be specified with spark.executorEnv.MESOS_NATIVE_LIBRARY. + the installed path of the Mesos library can be specified with spark.executorEnv.MESOS_NATIVE_JAVA_LIBRARY. From 6fc2740ebb59aca1aa0ee1e93658a7e4e69de33c Mon Sep 17 00:00:00 2001 From: jerryshao Date: Tue, 17 Nov 2015 10:01:33 -0800 Subject: [PATCH 057/149] [SPARK-11744][LAUNCHER] Fix print version throw exception when using pyspark shell Exception details can be seen here (https://issues.apache.org/jira/browse/SPARK-11744). Author: jerryshao Closes #9721 from jerryshao/SPARK-11744. --- .../launcher/SparkSubmitCommandBuilder.java | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index 39b46e0db8cc2..312df0b269f32 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -77,7 +77,7 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { } final List sparkArgs; - private final boolean printHelp; + private final boolean printInfo; /** * Controls whether mixing spark-submit arguments with app arguments is allowed. 
This is needed @@ -88,7 +88,7 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { SparkSubmitCommandBuilder() { this.sparkArgs = new ArrayList(); - this.printHelp = false; + this.printInfo = false; } SparkSubmitCommandBuilder(List args) { @@ -108,14 +108,14 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { OptionParser parser = new OptionParser(); parser.parse(submitArgs); - this.printHelp = parser.helpRequested; + this.printInfo = parser.infoRequested; } @Override public List buildCommand(Map env) throws IOException { - if (PYSPARK_SHELL_RESOURCE.equals(appResource) && !printHelp) { + if (PYSPARK_SHELL_RESOURCE.equals(appResource) && !printInfo) { return buildPySparkShellCommand(env); - } else if (SPARKR_SHELL_RESOURCE.equals(appResource) && !printHelp) { + } else if (SPARKR_SHELL_RESOURCE.equals(appResource) && !printInfo) { return buildSparkRCommand(env); } else { return buildSparkSubmitCommand(env); @@ -311,7 +311,7 @@ private boolean isThriftServer(String mainClass) { private class OptionParser extends SparkSubmitOptionParser { - boolean helpRequested = false; + boolean infoRequested = false; @Override protected boolean handle(String opt, String value) { @@ -344,7 +344,10 @@ protected boolean handle(String opt, String value) { appResource = specialClasses.get(value); } } else if (opt.equals(HELP) || opt.equals(USAGE_ERROR)) { - helpRequested = true; + infoRequested = true; + sparkArgs.add(opt); + } else if (opt.equals(VERSION)) { + infoRequested = true; sparkArgs.add(opt); } else { sparkArgs.add(opt); From cc567b6634c3142125526f4875795c1b1e862838 Mon Sep 17 00:00:00 2001 From: Chris Bannister Date: Tue, 17 Nov 2015 10:03:46 -0800 Subject: [PATCH 058/149] [SPARK-11695][CORE] Set s3a credentials Set s3a credentials when creating a new default hadoop configuration. Author: Chris Bannister Closes #9663 from Zariel/set-s3a-creds. 
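For illustration, a rough sketch of the effect; it assumes an active SparkContext `sc` and that both AWS_* environment variables were set before the driver started:

    // The Hadoop configuration derived by SparkHadoopUtil now carries the s3a keys
    // alongside the existing s3/s3n ones, so s3a:// paths can pick up the same credentials.
    sc.hadoopConfiguration.get("fs.s3a.access.key")  // same value as AWS_ACCESS_KEY_ID
    sc.hadoopConfiguration.get("fs.s3a.secret.key")  // same value as AWS_SECRET_ACCESS_KEY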
--- .../org/apache/spark/deploy/SparkHadoopUtil.scala | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index d606b80c03c98..59e90564b3516 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -92,10 +92,15 @@ class SparkHadoopUtil extends Logging { // Explicitly check for S3 environment variables if (System.getenv("AWS_ACCESS_KEY_ID") != null && System.getenv("AWS_SECRET_ACCESS_KEY") != null) { - hadoopConf.set("fs.s3.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID")) - hadoopConf.set("fs.s3n.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID")) - hadoopConf.set("fs.s3.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY")) - hadoopConf.set("fs.s3n.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY")) + val keyId = System.getenv("AWS_ACCESS_KEY_ID") + val accessKey = System.getenv("AWS_SECRET_ACCESS_KEY") + + hadoopConf.set("fs.s3.awsAccessKeyId", keyId) + hadoopConf.set("fs.s3n.awsAccessKeyId", keyId) + hadoopConf.set("fs.s3a.access.key", keyId) + hadoopConf.set("fs.s3.awsSecretAccessKey", accessKey) + hadoopConf.set("fs.s3n.awsSecretAccessKey", accessKey) + hadoopConf.set("fs.s3a.secret.key", accessKey) } // Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar" conf.getAll.foreach { case (key, value) => From 21fac5434174389e8b83a2f11341fa7c9e360bfd Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 17 Nov 2015 10:17:16 -0800 Subject: [PATCH 059/149] [SPARK-11766][MLLIB] add toJson/fromJson to Vector/Vectors This is to support JSON serialization of Param[Vector] in the pipeline API. It could be used for other purposes too. The schema is the same as `VectorUDT`. jkbradley Author: Xiangrui Meng Closes #9751 from mengxr/SPARK-11766. --- .../apache/spark/mllib/linalg/Vectors.scala | 45 +++++++++++++++++++ .../spark/mllib/linalg/VectorsSuite.scala | 17 +++++++ project/MimaExcludes.scala | 4 ++ 3 files changed, 66 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index bd9badc03c345..4dcf351df43fa 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -24,6 +24,9 @@ import scala.annotation.varargs import scala.collection.JavaConverters._ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} +import org.json4s.DefaultFormats +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods.{compact, render, parse => parseJson} import org.apache.spark.SparkException import org.apache.spark.annotation.{AlphaComponent, Since} @@ -171,6 +174,12 @@ sealed trait Vector extends Serializable { */ @Since("1.5.0") def argmax: Int + + /** + * Converts the vector to a JSON string. + */ + @Since("1.6.0") + def toJson: String } /** @@ -339,6 +348,27 @@ object Vectors { parseNumeric(NumericParser.parse(s)) } + /** + * Parses the JSON representation of a vector into a [[Vector]]. 
+ */ + @Since("1.6.0") + def fromJson(json: String): Vector = { + implicit val formats = DefaultFormats + val jValue = parseJson(json) + (jValue \ "type").extract[Int] match { + case 0 => // sparse + val size = (jValue \ "size").extract[Int] + val indices = (jValue \ "indices").extract[Seq[Int]].toArray + val values = (jValue \ "values").extract[Seq[Double]].toArray + sparse(size, indices, values) + case 1 => // dense + val values = (jValue \ "values").extract[Seq[Double]].toArray + dense(values) + case _ => + throw new IllegalArgumentException(s"Cannot parse $json into a vector.") + } + } + private[mllib] def parseNumeric(any: Any): Vector = { any match { case values: Array[Double] => @@ -650,6 +680,12 @@ class DenseVector @Since("1.0.0") ( maxIdx } } + + @Since("1.6.0") + override def toJson: String = { + val jValue = ("type" -> 1) ~ ("values" -> values.toSeq) + compact(render(jValue)) + } } @Since("1.3.0") @@ -837,6 +873,15 @@ class SparseVector @Since("1.0.0") ( }.unzip new SparseVector(selectedIndices.length, sliceInds.toArray, sliceVals.toArray) } + + @Since("1.6.0") + override def toJson: String = { + val jValue = ("type" -> 0) ~ + ("size" -> size) ~ + ("indices" -> indices.toSeq) ~ + ("values" -> values.toSeq) + compact(render(jValue)) + } } @Since("1.3.0") diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 6508ddeba4206..f895e2a8e4afb 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.mllib.linalg import scala.util.Random import breeze.linalg.{DenseMatrix => BDM, squaredDistance => breezeSquaredDistance} +import org.json4s.jackson.JsonMethods.{parse => parseJson} import org.apache.spark.{Logging, SparkException, SparkFunSuite} import org.apache.spark.mllib.util.TestingUtils._ @@ -374,4 +375,20 @@ class VectorsSuite extends SparkFunSuite with Logging { assert(v.slice(Array(2, 0)) === new SparseVector(2, Array(0), Array(2.2))) assert(v.slice(Array(2, 0, 3, 4)) === new SparseVector(4, Array(0, 3), Array(2.2, 4.4))) } + + test("toJson/fromJson") { + val sv0 = Vectors.sparse(0, Array.empty, Array.empty) + val sv1 = Vectors.sparse(1, Array.empty, Array.empty) + val sv2 = Vectors.sparse(2, Array(1), Array(2.0)) + val dv0 = Vectors.dense(Array.empty[Double]) + val dv1 = Vectors.dense(1.0) + val dv2 = Vectors.dense(0.0, 2.0) + for (v <- Seq(sv0, sv1, sv2, dv0, dv1, dv2)) { + val json = v.toJson + parseJson(json) // `json` should be a valid JSON string + val u = Vectors.fromJson(json) + assert(u.getClass === v.getClass, "toJson/fromJson should preserve vector types.") + assert(u === v, "toJson/fromJson should preserve vector values.") + } + } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 50220790d1f84..815951822c1ef 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -137,6 +137,10 @@ object MimaExcludes { ) ++ Seq ( ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.status.api.v1.ApplicationInfo.this") + ) ++ Seq( + // SPARK-11766 add toJson to Vector + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.toJson") ) case v if v.startsWith("1.5") => Seq( From e8833dd12c71b23a242727e86684d2d868ff84b3 Mon Sep 17 00:00:00 2001 From: mayuanwen Date: Tue, 17 Nov 2015 11:15:46 -0800 Subject: [PATCH 060/149] [SPARK-11679][SQL] Invoking method 
" apply(fields: java.util.List[StructField])" in "StructType" gets ClassCastException In the previous method, fields.toArray will cast java.util.List[StructField] into Array[Object] which can not cast into Array[StructField], thus when invoking this method will throw "java.lang.ClassCastException: [Ljava.lang.Object; cannot be cast to [Lorg.apache.spark.sql.types.StructField;" I directly cast java.util.List[StructField] into Array[StructField] in this patch. Author: mayuanwen Closes #9649 from jackieMaKing/Spark-11679. --- .../org/apache/spark/sql/types/StructType.scala | 3 ++- .../org/apache/spark/sql/JavaDataFrameSuite.java | 13 +++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 11fce4beaf55f..9778df271ddd5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -328,7 +328,8 @@ object StructType extends AbstractDataType { def apply(fields: Seq[StructField]): StructType = StructType(fields.toArray) def apply(fields: java.util.List[StructField]): StructType = { - StructType(fields.toArray.asInstanceOf[Array[StructField]]) + import scala.collection.JavaConverters._ + StructType(fields.asScala) } protected[sql] def fromAttributes(attributes: Seq[Attribute]): StructType = diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index d191b50fa802e..567bdddece80e 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -22,6 +22,7 @@ import java.util.Comparator; import java.util.List; import java.util.Map; +import java.util.ArrayList; import scala.collection.JavaConverters; import scala.collection.Seq; @@ -209,6 +210,18 @@ public void testCreateDataFromFromList() { Assert.assertEquals(1, result.length); } + @Test + public void testCreateStructTypeFromList(){ + List fields1 = new ArrayList<>(); + fields1.add(new StructField("id", DataTypes.StringType, true, Metadata.empty())); + StructType schema1 = StructType$.MODULE$.apply(fields1); + Assert.assertEquals(0, schema1.fieldIndex("id")); + + List fields2 = Arrays.asList(new StructField("id", DataTypes.StringType, true, Metadata.empty())); + StructType schema2 = StructType$.MODULE$.apply(fields2); + Assert.assertEquals(0, schema2.fieldIndex("id")); + } + private static final Comparator crosstabRowComparator = new Comparator() { @Override public int compare(Row row1, Row row2) { From 7b1407c7b95c43299a30e891748824c4bc47e43b Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 17 Nov 2015 11:17:52 -0800 Subject: [PATCH 061/149] [SPARK-11089][SQL] Adds option for disabling multi-session in Thrift server This PR adds a new option `spark.sql.hive.thriftServer.singleSession` for disabling multi-session support in the Thrift server. Note that this option is added as a Spark configuration (retrieved from `SparkConf`) rather than Spark SQL configuration (retrieved from `SQLConf`). This is because all SQL configurations are session-ized. Since multi-session support is by default on, no JDBC connection can modify global configurations like the newly added one. Author: Cheng Lian Closes #9740 from liancheng/spark-11089.single-session-option. 
--- docs/sql-programming-guide.md | 14 +++++ .../thriftserver/SparkSQLSessionManager.scala | 6 ++- .../HiveThriftServer2Suites.scala | 51 ++++++++++++++++++- .../apache/spark/sql/hive/HiveContext.scala | 3 ++ 4 files changed, 72 insertions(+), 2 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 6e02d6564b002..e347754055e79 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -2051,6 +2051,20 @@ options. # Migration Guide +## Upgrading From Spark SQL 1.5 to 1.6 + + - From Spark 1.6, by default the Thrift server runs in multi-session mode. Which means each JDBC/ODBC + connection owns a copy of their own SQL configuration and temporary function registry. Cached + tables are still shared though. If you prefer to run the Thrift server in the old single-session + mode, please set option `spark.sql.hive.thriftServer.singleSession` to `true`. You may either add + this option to `spark-defaults.conf`, or pass it to `start-thriftserver.sh` via `--conf`: + + {% highlight bash %} + ./sbin/start-thriftserver.sh \ + --conf spark.sql.hive.thriftServer.singleSession=true \ + ... + {% endhighlight %} + ## Upgrading From Spark SQL 1.4 to 1.5 - Optimized execution using manually managed memory (Tungsten) is now enabled by default, along with diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala index 33aaead3fbf96..af4fcdf021bd4 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala @@ -66,7 +66,11 @@ private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, hiveContext: val session = super.getSession(sessionHandle) HiveThriftServer2.listener.onSessionCreated( session.getIpAddress, sessionHandle.getSessionId.toString, session.getUsername) - val ctx = hiveContext.newSession() + val ctx = if (hiveContext.hiveThriftServerSingleSession) { + hiveContext + } else { + hiveContext.newSession() + } ctx.setConf("spark.sql.hive.version", HiveContext.hiveExecutionVersion) sparkSqlOperationManager.sessionToContexts += sessionHandle -> ctx sessionHandle diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index eb1895f263d70..1dd898aa38350 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -41,7 +41,6 @@ import org.apache.thrift.transport.TSocket import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.hive.HiveContext -import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer import org.apache.spark.util.Utils import org.apache.spark.{Logging, SparkFunSuite} @@ -510,6 +509,53 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { } } +class SingleSessionSuite extends HiveThriftJdbcTest { + override def mode: ServerMode.Value = ServerMode.binary + + override protected def extraConf: Seq[String] = + "--conf spark.sql.hive.thriftServer.singleSession=true" :: Nil + + 
test("test single session") { + withMultipleConnectionJdbcStatement( + { statement => + val jarPath = "../hive/src/test/resources/TestUDTF.jar" + val jarURL = s"file://${System.getProperty("user.dir")}/$jarPath" + + // Configurations and temporary functions added in this session should be visible to all + // the other sessions. + Seq( + "SET foo=bar", + s"ADD JAR $jarURL", + s"""CREATE TEMPORARY FUNCTION udtf_count2 + |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' + """.stripMargin + ).foreach(statement.execute) + }, + + { statement => + val rs1 = statement.executeQuery("SET foo") + + assert(rs1.next()) + assert(rs1.getString(1) === "foo") + assert(rs1.getString(2) === "bar") + + val rs2 = statement.executeQuery("DESCRIBE FUNCTION udtf_count2") + + assert(rs2.next()) + assert(rs2.getString(1) === "Function: udtf_count2") + + assert(rs2.next()) + assertResult("Class: org.apache.spark.sql.hive.execution.GenericUDTFCount2") { + rs2.getString(1) + } + + assert(rs2.next()) + assert(rs2.getString(1) === "Usage: To be added.") + } + ) + } +} + class HiveThriftHttpServerSuite extends HiveThriftJdbcTest { override def mode: ServerMode.Value = ServerMode.http @@ -600,6 +646,8 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl private var logTailingProcess: Process = _ private var diagnosisBuffer: ArrayBuffer[String] = ArrayBuffer.empty[String] + protected def extraConf: Seq[String] = Nil + protected def serverStartCommand(port: Int) = { val portConf = if (mode == ServerMode.binary) { ConfVars.HIVE_SERVER2_THRIFT_PORT @@ -635,6 +683,7 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl | --driver-class-path $driverClassPath | --driver-java-options -Dlog4j.debug | --conf spark.ui.enabled=false + | ${extraConf.mkString("\n")} """.stripMargin.split("\\s+").toSeq } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 2004f24ad26c6..c0bb5af7d5c85 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -190,6 +190,9 @@ class HiveContext private[hive]( */ protected[hive] def hiveThriftServerAsync: Boolean = getConf(HIVE_THRIFT_SERVER_ASYNC) + protected[hive] def hiveThriftServerSingleSession: Boolean = + sc.conf.get("spark.sql.hive.thriftServer.singleSession", "false").toBoolean + @transient protected[sql] lazy val substitutor = new VariableSubstitution() From 0158ff7737d10e68be2e289533241da96b496e89 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 17 Nov 2015 11:23:54 -0800 Subject: [PATCH 062/149] [SPARK-8658][SQL][FOLLOW-UP] AttributeReference's equals method compares all the members Based on the comment of cloud-fan in https://github.com/apache/spark/pull/9216, update the AttributeReference's hashCode function by including the hashCode of the other attributes including name, nullable and qualifiers. Here, I am not 100% sure if we should include name in the hashCode calculation, since the original hashCode calculation does not include it. marmbrus cloud-fan Please review if the changes are good. Author: gatorsmile Closes #9761 from gatorsmile/hashCodeNamedExpression. 
--- .../spark/sql/catalyst/expressions/namedExpressions.scala | 5 ++++- .../expressions/SubexpressionEliminationSuite.scala | 6 +++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index e3daddace241d..00b7970bd16c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -212,9 +212,12 @@ case class AttributeReference( override def hashCode: Int = { // See http://stackoverflow.com/questions/113511/hash-code-implementation var h = 17 - h = h * 37 + exprId.hashCode() + h = h * 37 + name.hashCode() h = h * 37 + dataType.hashCode() + h = h * 37 + nullable.hashCode() h = h * 37 + metadata.hashCode() + h = h * 37 + exprId.hashCode() + h = h * 37 + qualifiers.hashCode() h } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala index 9de066e99d637..a61297b2c0395 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala @@ -25,14 +25,18 @@ class SubexpressionEliminationSuite extends SparkFunSuite { val a: AttributeReference = AttributeReference("name", IntegerType)() val b1 = a.withName("name2").withExprId(id) val b2 = a.withExprId(id) + val b3 = a.withQualifiers("qualifierName" :: Nil) assert(b1 != b2) assert(a != b1) assert(b1.semanticEquals(b2)) assert(!b1.semanticEquals(a)) assert(a.hashCode != b1.hashCode) - assert(b1.hashCode == b2.hashCode) + assert(b1.hashCode != b2.hashCode) assert(b1.semanticHash() == b2.semanticHash()) + assert(a != b3) + assert(a.hashCode != b3.hashCode) + assert(a.semanticEquals(b3)) } test("Expression Equivalence - basic") { From d9251496640a77568a1e9ed5045ce2dfba4b437b Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 17 Nov 2015 11:29:02 -0800 Subject: [PATCH 063/149] [SPARK-10186][SQL] support postgre array type in JDBCRDD Add ARRAY support to `PostgresDialect`. Nested ARRAY is not allowed for now because it's hard to get the array dimension info. See http://stackoverflow.com/questions/16619113/how-to-get-array-base-type-in-postgres-via-jdbc Thanks for the initial work from mariusvniekerk ! Close https://github.com/apache/spark/pull/9137 Author: Wenchen Fan Closes #9662 from cloud-fan/postgre. 
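For illustration, a rough usage sketch mirroring PostgresIntegrationSuite; it assumes a reachable PostgreSQL instance behind `jdbcUrl` and the suite's `bar` table, whose c10 and c11 columns are int4[] and text[]:

    val df = sqlContext.read.jdbc(jdbcUrl, "bar", new java.util.Properties)

    // int4[] and text[] columns now map to ArrayType(IntegerType) and ArrayType(StringType)
    // instead of failing; nested arrays are still rejected.
    df.schema("c10").dataType        // ArrayType(IntegerType,true)
    df.collect()(0).getSeq[Int](10)  // Seq(1, 2)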
--- .../sql/jdbc/PostgresIntegrationSuite.scala | 44 +++++++---- .../execution/datasources/jdbc/JDBCRDD.scala | 76 +++++++++++++----- .../datasources/jdbc/JdbcUtils.scala | 77 ++++++++++--------- .../apache/spark/sql/jdbc/JdbcDialects.scala | 2 +- .../spark/sql/jdbc/PostgresDialect.scala | 43 ++++++++--- 5 files changed, 157 insertions(+), 85 deletions(-) diff --git a/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index 164a7f396280c..2e18d0a2baa1c 100644 --- a/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.jdbc import java.sql.Connection import java.util.Properties +import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.expressions.{Literal, If} import org.apache.spark.tags.DockerTest @DockerTest @@ -37,28 +39,32 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { override def dataPreparation(conn: Connection): Unit = { conn.prepareStatement("CREATE DATABASE foo").executeUpdate() conn.setCatalog("foo") - conn.prepareStatement("CREATE TABLE bar (a text, b integer, c double precision, d bigint, " - + "e bit(1), f bit(10), g bytea, h boolean, i inet, j cidr)").executeUpdate() + conn.prepareStatement("CREATE TABLE bar (c0 text, c1 integer, c2 double precision, c3 bigint, " + + "c4 bit(1), c5 bit(10), c6 bytea, c7 boolean, c8 inet, c9 cidr, " + + "c10 integer[], c11 text[])").executeUpdate() conn.prepareStatement("INSERT INTO bar VALUES ('hello', 42, 1.25, 123456789012345, B'0', " - + "B'1000100101', E'\\\\xDEADBEEF', true, '172.16.0.42', '192.168.0.0/16')").executeUpdate() + + "B'1000100101', E'\\\\xDEADBEEF', true, '172.16.0.42', '192.168.0.0/16', " + + """'{1, 2}', '{"a", null, "b"}')""").executeUpdate() } test("Type mapping for various types") { val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties) val rows = df.collect() assert(rows.length == 1) - val types = rows(0).toSeq.map(x => x.getClass.toString) - assert(types.length == 10) - assert(types(0).equals("class java.lang.String")) - assert(types(1).equals("class java.lang.Integer")) - assert(types(2).equals("class java.lang.Double")) - assert(types(3).equals("class java.lang.Long")) - assert(types(4).equals("class java.lang.Boolean")) - assert(types(5).equals("class [B")) - assert(types(6).equals("class [B")) - assert(types(7).equals("class java.lang.Boolean")) - assert(types(8).equals("class java.lang.String")) - assert(types(9).equals("class java.lang.String")) + val types = rows(0).toSeq.map(x => x.getClass) + assert(types.length == 12) + assert(classOf[String].isAssignableFrom(types(0))) + assert(classOf[java.lang.Integer].isAssignableFrom(types(1))) + assert(classOf[java.lang.Double].isAssignableFrom(types(2))) + assert(classOf[java.lang.Long].isAssignableFrom(types(3))) + assert(classOf[java.lang.Boolean].isAssignableFrom(types(4))) + assert(classOf[Array[Byte]].isAssignableFrom(types(5))) + assert(classOf[Array[Byte]].isAssignableFrom(types(6))) + assert(classOf[java.lang.Boolean].isAssignableFrom(types(7))) + assert(classOf[String].isAssignableFrom(types(8))) + assert(classOf[String].isAssignableFrom(types(9))) + assert(classOf[Seq[Int]].isAssignableFrom(types(10))) + assert(classOf[Seq[String]].isAssignableFrom(types(11))) 
assert(rows(0).getString(0).equals("hello")) assert(rows(0).getInt(1) == 42) assert(rows(0).getDouble(2) == 1.25) @@ -72,11 +78,17 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { assert(rows(0).getBoolean(7) == true) assert(rows(0).getString(8) == "172.16.0.42") assert(rows(0).getString(9) == "192.168.0.0/16") + assert(rows(0).getSeq(10) == Seq(1, 2)) + assert(rows(0).getSeq(11) == Seq("a", null, "b")) } test("Basic write test") { val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties) - df.write.jdbc(jdbcUrl, "public.barcopy", new Properties) // Test only that it doesn't crash. + df.write.jdbc(jdbcUrl, "public.barcopy", new Properties) + // Test write null values. + df.select(df.queryExecution.analyzed.output.map { a => + Column(If(Literal(true), Literal(null), a)).as(a.name) + }: _*).write.jdbc(jdbcUrl, "public.barcopy2", new Properties) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 018a009fbda6d..89c850ce238d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -25,7 +25,7 @@ import org.apache.commons.lang3.StringUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.{GenericArrayData, DateTimeUtils} import org.apache.spark.sql.jdbc.JdbcDialects import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ @@ -324,25 +324,27 @@ private[sql] class JDBCRDD( case object StringConversion extends JDBCConversion case object TimestampConversion extends JDBCConversion case object BinaryConversion extends JDBCConversion + case class ArrayConversion(elementConversion: JDBCConversion) extends JDBCConversion /** * Maps a StructType to a type tag list. 
*/ - def getConversions(schema: StructType): Array[JDBCConversion] = { - schema.fields.map(sf => sf.dataType match { - case BooleanType => BooleanConversion - case DateType => DateConversion - case DecimalType.Fixed(p, s) => DecimalConversion(p, s) - case DoubleType => DoubleConversion - case FloatType => FloatConversion - case IntegerType => IntegerConversion - case LongType => - if (sf.metadata.contains("binarylong")) BinaryLongConversion else LongConversion - case StringType => StringConversion - case TimestampType => TimestampConversion - case BinaryType => BinaryConversion - case _ => throw new IllegalArgumentException(s"Unsupported field $sf") - }).toArray + def getConversions(schema: StructType): Array[JDBCConversion] = + schema.fields.map(sf => getConversions(sf.dataType, sf.metadata)) + + private def getConversions(dt: DataType, metadata: Metadata): JDBCConversion = dt match { + case BooleanType => BooleanConversion + case DateType => DateConversion + case DecimalType.Fixed(p, s) => DecimalConversion(p, s) + case DoubleType => DoubleConversion + case FloatType => FloatConversion + case IntegerType => IntegerConversion + case LongType => if (metadata.contains("binarylong")) BinaryLongConversion else LongConversion + case StringType => StringConversion + case TimestampType => TimestampConversion + case BinaryType => BinaryConversion + case ArrayType(et, _) => ArrayConversion(getConversions(et, metadata)) + case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.simpleString}") } /** @@ -420,16 +422,44 @@ private[sql] class JDBCRDD( mutableRow.update(i, null) } case BinaryConversion => mutableRow.update(i, rs.getBytes(pos)) - case BinaryLongConversion => { + case BinaryLongConversion => val bytes = rs.getBytes(pos) var ans = 0L var j = 0 while (j < bytes.size) { ans = 256 * ans + (255 & bytes(j)) - j = j + 1; + j = j + 1 } mutableRow.setLong(i, ans) - } + case ArrayConversion(elementConversion) => + val array = rs.getArray(pos).getArray + if (array != null) { + val data = elementConversion match { + case TimestampConversion => + array.asInstanceOf[Array[java.sql.Timestamp]].map { timestamp => + nullSafeConvert(timestamp, DateTimeUtils.fromJavaTimestamp) + } + case StringConversion => + array.asInstanceOf[Array[java.lang.String]] + .map(UTF8String.fromString) + case DateConversion => + array.asInstanceOf[Array[java.sql.Date]].map { date => + nullSafeConvert(date, DateTimeUtils.fromJavaDate) + } + case DecimalConversion(p, s) => + array.asInstanceOf[Array[java.math.BigDecimal]].map { decimal => + nullSafeConvert[java.math.BigDecimal](decimal, d => Decimal(d, p, s)) + } + case BinaryLongConversion => + throw new IllegalArgumentException(s"Unsupported array element conversion $i") + case _: ArrayConversion => + throw new IllegalArgumentException("Nested arrays unsupported") + case _ => array.asInstanceOf[Array[Any]] + } + mutableRow.update(i, new GenericArrayData(data)) + } else { + mutableRow.update(i, null) + } } if (rs.wasNull) mutableRow.setNullAt(i) i = i + 1 @@ -488,4 +518,12 @@ private[sql] class JDBCRDD( nextValue } } + + private def nullSafeConvert[T](input: T, f: T => Any): Any = { + if (input == null) { + null + } else { + f(input) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index f89d55b20e212..32d28e59377a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala 
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -23,7 +23,7 @@ import java.util.Properties import scala.util.Try import org.apache.spark.Logging -import org.apache.spark.sql.jdbc.JdbcDialects +import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcType, JdbcDialects} import org.apache.spark.sql.types._ import org.apache.spark.sql.{DataFrame, Row} @@ -72,6 +72,35 @@ object JdbcUtils extends Logging { conn.prepareStatement(sql.toString()) } + /** + * Retrieve standard jdbc types. + * @param dt The datatype (e.g. [[org.apache.spark.sql.types.StringType]]) + * @return The default JdbcType for this DataType + */ + def getCommonJDBCType(dt: DataType): Option[JdbcType] = { + dt match { + case IntegerType => Option(JdbcType("INTEGER", java.sql.Types.INTEGER)) + case LongType => Option(JdbcType("BIGINT", java.sql.Types.BIGINT)) + case DoubleType => Option(JdbcType("DOUBLE PRECISION", java.sql.Types.DOUBLE)) + case FloatType => Option(JdbcType("REAL", java.sql.Types.FLOAT)) + case ShortType => Option(JdbcType("INTEGER", java.sql.Types.SMALLINT)) + case ByteType => Option(JdbcType("BYTE", java.sql.Types.TINYINT)) + case BooleanType => Option(JdbcType("BIT(1)", java.sql.Types.BIT)) + case StringType => Option(JdbcType("TEXT", java.sql.Types.CLOB)) + case BinaryType => Option(JdbcType("BLOB", java.sql.Types.BLOB)) + case TimestampType => Option(JdbcType("TIMESTAMP", java.sql.Types.TIMESTAMP)) + case DateType => Option(JdbcType("DATE", java.sql.Types.DATE)) + case t: DecimalType => Option( + JdbcType(s"DECIMAL(${t.precision},${t.scale})", java.sql.Types.DECIMAL)) + case _ => None + } + } + + private def getJdbcType(dt: DataType, dialect: JdbcDialect): JdbcType = { + dialect.getJDBCType(dt).orElse(getCommonJDBCType(dt)).getOrElse( + throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.simpleString}")) + } + /** * Saves a partition of a DataFrame to the JDBC database. 
This is done in * a single database transaction in order to avoid repeatedly inserting @@ -92,7 +121,8 @@ object JdbcUtils extends Logging { iterator: Iterator[Row], rddSchema: StructType, nullTypes: Array[Int], - batchSize: Int): Iterator[Byte] = { + batchSize: Int, + dialect: JdbcDialect): Iterator[Byte] = { val conn = getConnection() var committed = false try { @@ -121,6 +151,11 @@ object JdbcUtils extends Logging { case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i)) case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i)) case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i)) + case ArrayType(et, _) => + val array = conn.createArrayOf( + getJdbcType(et, dialect).databaseTypeDefinition.toLowerCase, + row.getSeq[AnyRef](i).toArray) + stmt.setArray(i + 1, array) case _ => throw new IllegalArgumentException( s"Can't translate non-null value for field $i") } @@ -169,23 +204,7 @@ object JdbcUtils extends Logging { val dialect = JdbcDialects.get(url) df.schema.fields foreach { field => { val name = field.name - val typ: String = - dialect.getJDBCType(field.dataType).map(_.databaseTypeDefinition).getOrElse( - field.dataType match { - case IntegerType => "INTEGER" - case LongType => "BIGINT" - case DoubleType => "DOUBLE PRECISION" - case FloatType => "REAL" - case ShortType => "INTEGER" - case ByteType => "BYTE" - case BooleanType => "BIT(1)" - case StringType => "TEXT" - case BinaryType => "BLOB" - case TimestampType => "TIMESTAMP" - case DateType => "DATE" - case t: DecimalType => s"DECIMAL(${t.precision},${t.scale})" - case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC") - }) + val typ: String = getJdbcType(field.dataType, dialect).databaseTypeDefinition val nullable = if (field.nullable) "" else "NOT NULL" sb.append(s", $name $typ $nullable") }} @@ -202,23 +221,7 @@ object JdbcUtils extends Logging { properties: Properties = new Properties()) { val dialect = JdbcDialects.get(url) val nullTypes: Array[Int] = df.schema.fields.map { field => - dialect.getJDBCType(field.dataType).map(_.jdbcNullType).getOrElse( - field.dataType match { - case IntegerType => java.sql.Types.INTEGER - case LongType => java.sql.Types.BIGINT - case DoubleType => java.sql.Types.DOUBLE - case FloatType => java.sql.Types.REAL - case ShortType => java.sql.Types.INTEGER - case ByteType => java.sql.Types.INTEGER - case BooleanType => java.sql.Types.BIT - case StringType => java.sql.Types.CLOB - case BinaryType => java.sql.Types.BLOB - case TimestampType => java.sql.Types.TIMESTAMP - case DateType => java.sql.Types.DATE - case t: DecimalType => java.sql.Types.DECIMAL - case _ => throw new IllegalArgumentException( - s"Can't translate null value for field $field") - }) + getJdbcType(field.dataType, dialect).jdbcNullType } val rddSchema = df.schema @@ -226,7 +229,7 @@ object JdbcUtils extends Logging { val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties) val batchSize = properties.getProperty("batchsize", "1000").toInt df.foreachPartition { iterator => - savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize) + savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 14bfea4e3e287..b3b2cb6178c52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -51,7 +51,7 @@ case class JdbcType(databaseTypeDefinition : String, jdbcNullType : Int) * for the given Catalyst type. */ @DeveloperApi -abstract class JdbcDialect { +abstract class JdbcDialect extends Serializable { /** * Check if this dialect instance can handle a certain jdbc url. * @param url the jdbc url. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index e701a7fcd9e16..ed3faa1268635 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.jdbc import java.sql.Types +import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils import org.apache.spark.sql.types._ @@ -29,22 +30,40 @@ private object PostgresDialect extends JdbcDialect { override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) { - Option(BinaryType) - } else if (sqlType == Types.OTHER && typeName.equals("cidr")) { - Option(StringType) - } else if (sqlType == Types.OTHER && typeName.equals("inet")) { - Option(StringType) - } else if (sqlType == Types.OTHER && typeName.equals("json")) { - Option(StringType) - } else if (sqlType == Types.OTHER && typeName.equals("jsonb")) { - Option(StringType) + Some(BinaryType) + } else if (sqlType == Types.OTHER) { + toCatalystType(typeName).filter(_ == StringType) + } else if (sqlType == Types.ARRAY && typeName.length > 1 && typeName(0) == '_') { + toCatalystType(typeName.drop(1)).map(ArrayType(_)) } else None } + // TODO: support more type names. + private def toCatalystType(typeName: String): Option[DataType] = typeName match { + case "bool" => Some(BooleanType) + case "bit" => Some(BinaryType) + case "int2" => Some(ShortType) + case "int4" => Some(IntegerType) + case "int8" | "oid" => Some(LongType) + case "float4" => Some(FloatType) + case "money" | "float8" => Some(DoubleType) + case "text" | "varchar" | "char" | "cidr" | "inet" | "json" | "jsonb" | "uuid" => + Some(StringType) + case "bytea" => Some(BinaryType) + case "timestamp" | "timestamptz" | "time" | "timetz" => Some(TimestampType) + case "date" => Some(DateType) + case "numeric" => Some(DecimalType.SYSTEM_DEFAULT) + case _ => None + } + override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { - case StringType => Some(JdbcType("TEXT", java.sql.Types.CHAR)) - case BinaryType => Some(JdbcType("BYTEA", java.sql.Types.BINARY)) - case BooleanType => Some(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN)) + case StringType => Some(JdbcType("TEXT", Types.CHAR)) + case BinaryType => Some(JdbcType("BYTEA", Types.BINARY)) + case BooleanType => Some(JdbcType("BOOLEAN", Types.BOOLEAN)) + case ArrayType(et, _) if et.isInstanceOf[AtomicType] => + getJDBCType(et).map(_.databaseTypeDefinition) + .orElse(JdbcUtils.getCommonJDBCType(et).map(_.databaseTypeDefinition)) + .map(typeName => JdbcType(s"$typeName[]", java.sql.Types.ARRAY)) case _ => None } From d98d1cb000c8c4e391d73ae86efd09f15e5d165c Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 17 Nov 2015 12:43:56 -0800 Subject: [PATCH 064/149] [SPARK-11769][ML] Add save, load to all basic Transformers This excludes Estimators and ones which include Vector and other non-basic types for Params or data. 
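As a usage sketch of the read/write support added here (the column names, splits, and save path below are illustrative, not taken from this patch), a persisted transformer round-trips like this:

  import org.apache.spark.ml.feature.Bucketizer

  val bucketizer = new Bucketizer()
    .setInputCol("feature")
    .setOutputCol("bucketed")
    .setSplits(Array(Double.NegativeInfinity, 0.0, 10.0, Double.PositiveInfinity))

  // Params are written out under the given path ...
  bucketizer.write.save("/tmp/spark-ml/bucketizer")

  // ... and an identically configured instance is reconstructed on load.
  val restored = Bucketizer.load("/tmp/spark-ml/bucketizer")
  assert(restored.getSplits.sameElements(bucketizer.getSplits))

Only the params are persisted, since these transformers carry no model data.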
This adds: * Bucketizer * DCT * HashingTF * Interaction * NGram * Normalizer * OneHotEncoder * PolynomialExpansion * QuantileDiscretizer * RFormula * SQLTransformer * StopWordsRemover * StringIndexer * Tokenizer * VectorAssembler * VectorSlicer CC: mengxr Author: Joseph K. Bradley Closes #9755 from jkbradley/transformer-io. --- .../apache/spark/ml/feature/Binarizer.scala | 8 ++++- .../apache/spark/ml/feature/Bucketizer.scala | 22 ++++++++---- .../org/apache/spark/ml/feature/DCT.scala | 19 ++++++++-- .../apache/spark/ml/feature/HashingTF.scala | 20 +++++++++-- .../apache/spark/ml/feature/Interaction.scala | 29 ++++++++++++--- .../org/apache/spark/ml/feature/NGram.scala | 19 ++++++++-- .../apache/spark/ml/feature/Normalizer.scala | 20 +++++++++-- .../spark/ml/feature/OneHotEncoder.scala | 19 ++++++++-- .../ml/feature/PolynomialExpansion.scala | 20 ++++++++--- .../ml/feature/QuantileDiscretizer.scala | 22 ++++++++---- .../spark/ml/feature/SQLTransformer.scala | 27 ++++++++++++-- .../spark/ml/feature/StopWordsRemover.scala | 19 ++++++++-- .../spark/ml/feature/StringIndexer.scala | 22 +++++++++--- .../apache/spark/ml/feature/Tokenizer.scala | 35 ++++++++++++++++--- .../spark/ml/feature/VectorAssembler.scala | 18 +++++++--- .../spark/ml/feature/VectorSlicer.scala | 22 ++++++++---- .../spark/ml/feature/BinarizerSuite.scala | 8 ++--- .../spark/ml/feature/BucketizerSuite.scala | 12 +++++-- .../apache/spark/ml/feature/DCTSuite.scala | 11 +++++- .../spark/ml/feature/HashingTFSuite.scala | 11 +++++- .../spark/ml/feature/InteractionSuite.scala | 10 +++++- .../apache/spark/ml/feature/NGramSuite.scala | 11 +++++- .../spark/ml/feature/NormalizerSuite.scala | 11 +++++- .../spark/ml/feature/OneHotEncoderSuite.scala | 12 ++++++- .../ml/feature/PolynomialExpansionSuite.scala | 12 ++++++- .../ml/feature/QuantileDiscretizerSuite.scala | 13 ++++++- .../ml/feature/SQLTransformerSuite.scala | 10 +++++- .../ml/feature/StopWordsRemoverSuite.scala | 14 +++++++- .../spark/ml/feature/StringIndexerSuite.scala | 13 +++++-- .../spark/ml/feature/TokenizerSuite.scala | 25 +++++++++++-- .../ml/feature/VectorAssemblerSuite.scala | 11 +++++- .../spark/ml/feature/VectorSlicerSuite.scala | 12 ++++++- 32 files changed, 453 insertions(+), 84 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala index e5c25574d4b11..e2be6547d8f00 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.BinaryAttribute import org.apache.spark.ml.param._ @@ -87,10 +87,16 @@ final class Binarizer(override val uid: String) override def copy(extra: ParamMap): Binarizer = defaultCopy(extra) + @Since("1.6.0") override def write: Writer = new DefaultParamsWriter(this) } +@Since("1.6.0") object Binarizer extends Readable[Binarizer] { + @Since("1.6.0") override def read: Reader[Binarizer] = new DefaultParamsReader[Binarizer] + + @Since("1.6.0") + override def load(path: String): Binarizer = read.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index 6fdf25b015b0b..7095fbd70aa07 100644 --- 
a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -20,12 +20,12 @@ package org.apache.spark.ml.feature import java.{util => ju} import org.apache.spark.SparkException -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.Model import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.sql._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructField, StructType} @@ -36,7 +36,7 @@ import org.apache.spark.sql.types.{DoubleType, StructField, StructType} */ @Experimental final class Bucketizer(override val uid: String) - extends Model[Bucketizer] with HasInputCol with HasOutputCol { + extends Model[Bucketizer] with HasInputCol with HasOutputCol with Writable { def this() = this(Identifiable.randomUID("bucketizer")) @@ -93,11 +93,15 @@ final class Bucketizer(override val uid: String) override def copy(extra: ParamMap): Bucketizer = { defaultCopy[Bucketizer](extra).setParent(parent) } + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) } -private[feature] object Bucketizer { +object Bucketizer extends Readable[Bucketizer] { + /** We require splits to be of length >= 3 and to be in strictly increasing order. */ - def checkSplits(splits: Array[Double]): Boolean = { + private[feature] def checkSplits(splits: Array[Double]): Boolean = { if (splits.length < 3) { false } else { @@ -115,7 +119,7 @@ private[feature] object Bucketizer { * Binary searching in several buckets to place each data point. 
* @throws SparkException if a feature is < splits.head or > splits.last */ - def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = { + private[feature] def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = { if (feature == splits.last) { splits.length - 2 } else { @@ -134,4 +138,10 @@ private[feature] object Bucketizer { } } } + + @Since("1.6.0") + override def read: Reader[Bucketizer] = new DefaultParamsReader[Bucketizer] + + @Since("1.6.0") + override def load(path: String): Bucketizer = read.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala index 228347635c92b..6ea5a616173ee 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala @@ -19,10 +19,10 @@ package org.apache.spark.ml.feature import edu.emory.mathcs.jtransforms.dct._ -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param.BooleanParam -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} import org.apache.spark.sql.types.DataType @@ -37,7 +37,7 @@ import org.apache.spark.sql.types.DataType */ @Experimental class DCT(override val uid: String) - extends UnaryTransformer[Vector, Vector, DCT] { + extends UnaryTransformer[Vector, Vector, DCT] with Writable { def this() = this(Identifiable.randomUID("dct")) @@ -69,4 +69,17 @@ class DCT(override val uid: String) } override protected def outputDataType: DataType = new VectorUDT + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object DCT extends Readable[DCT] { + + @Since("1.6.0") + override def read: Reader[DCT] = new DefaultParamsReader[DCT] + + @Since("1.6.0") + override def load(path: String): DCT = read.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index 319d23e46cef4..6d2ea675f5617 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -17,12 +17,12 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions.{col, udf} @@ -33,7 +33,8 @@ import org.apache.spark.sql.types.{ArrayType, StructType} * Maps a sequence of terms to their term frequencies using the hashing trick. 
*/ @Experimental -class HashingTF(override val uid: String) extends Transformer with HasInputCol with HasOutputCol { +class HashingTF(override val uid: String) + extends Transformer with HasInputCol with HasOutputCol with Writable { def this() = this(Identifiable.randomUID("hashingTF")) @@ -76,4 +77,17 @@ class HashingTF(override val uid: String) extends Transformer with HasInputCol w } override def copy(extra: ParamMap): HashingTF = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object HashingTF extends Readable[HashingTF] { + + @Since("1.6.0") + override def read: Reader[HashingTF] = new DefaultParamsReader[HashingTF] + + @Since("1.6.0") + override def load(path: String): HashingTF = read.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala index 37f7862476cfe..9df6b311cc9da 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -20,11 +20,11 @@ package org.apache.spark.ml.feature import scala.collection.mutable.ArrayBuilder import org.apache.spark.SparkException -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.ml.Transformer import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} import org.apache.spark.sql.{DataFrame, Row} @@ -42,24 +42,30 @@ import org.apache.spark.sql.types._ * `Vector(6, 8)` if all input features were numeric. If the first feature was instead nominal * with four categories, the output would then be `Vector(0, 0, 0, 0, 3, 4, 0, 0)`. 
*/ +@Since("1.6.0") @Experimental -class Interaction(override val uid: String) extends Transformer - with HasInputCols with HasOutputCol { +class Interaction @Since("1.6.0") (override val uid: String) extends Transformer + with HasInputCols with HasOutputCol with Writable { + @Since("1.6.0") def this() = this(Identifiable.randomUID("interaction")) /** @group setParam */ + @Since("1.6.0") def setInputCols(values: Array[String]): this.type = set(inputCols, values) /** @group setParam */ + @Since("1.6.0") def setOutputCol(value: String): this.type = set(outputCol, value) // optimistic schema; does not contain any ML attributes + @Since("1.6.0") override def transformSchema(schema: StructType): StructType = { validateParams() StructType(schema.fields :+ StructField($(outputCol), new VectorUDT, false)) } + @Since("1.6.0") override def transform(dataset: DataFrame): DataFrame = { validateParams() val inputFeatures = $(inputCols).map(c => dataset.schema(c)) @@ -208,14 +214,29 @@ class Interaction(override val uid: String) extends Transformer } } + @Since("1.6.0") override def copy(extra: ParamMap): Interaction = defaultCopy(extra) + @Since("1.6.0") override def validateParams(): Unit = { require(get(inputCols).isDefined, "Input cols must be defined first.") require(get(outputCol).isDefined, "Output col must be defined first.") require($(inputCols).length > 0, "Input cols must have non-zero length.") require($(inputCols).distinct.length == $(inputCols).length, "Input cols must be distinct.") } + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object Interaction extends Readable[Interaction] { + + @Since("1.6.0") + override def read: Reader[Interaction] = new DefaultParamsReader[Interaction] + + @Since("1.6.0") + override def load(path: String): Interaction = read.load(path) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala index 8de10eb51f923..4a17acd95199f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala @@ -17,10 +17,10 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.sql.types.{ArrayType, DataType, StringType} /** @@ -36,7 +36,7 @@ import org.apache.spark.sql.types.{ArrayType, DataType, StringType} */ @Experimental class NGram(override val uid: String) - extends UnaryTransformer[Seq[String], Seq[String], NGram] { + extends UnaryTransformer[Seq[String], Seq[String], NGram] with Writable { def this() = this(Identifiable.randomUID("ngram")) @@ -66,4 +66,17 @@ class NGram(override val uid: String) } override protected def outputDataType: DataType = new ArrayType(StringType, false) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object NGram extends Readable[NGram] { + + @Since("1.6.0") + override def read: Reader[NGram] = new DefaultParamsReader[NGram] + + @Since("1.6.0") + override def load(path: String): NGram = read.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala index 8282e5ffa17f7..9df6a091d5058 100644 --- 
a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala @@ -17,10 +17,10 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param.{DoubleParam, ParamValidators} -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.types.DataType @@ -30,7 +30,8 @@ import org.apache.spark.sql.types.DataType * Normalize a vector to have unit norm using the given p-norm. */ @Experimental -class Normalizer(override val uid: String) extends UnaryTransformer[Vector, Vector, Normalizer] { +class Normalizer(override val uid: String) + extends UnaryTransformer[Vector, Vector, Normalizer] with Writable { def this() = this(Identifiable.randomUID("normalizer")) @@ -55,4 +56,17 @@ class Normalizer(override val uid: String) extends UnaryTransformer[Vector, Vect } override protected def outputDataType: DataType = new VectorUDT() + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object Normalizer extends Readable[Normalizer] { + + @Since("1.6.0") + override def read: Reader[Normalizer] = new DefaultParamsReader[Normalizer] + + @Since("1.6.0") + override def load(path: String): Normalizer = read.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index 9c60d4084ec46..4e2adfaafa21e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -17,12 +17,12 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions.{col, udf} @@ -44,7 +44,7 @@ import org.apache.spark.sql.types.{DoubleType, StructType} */ @Experimental class OneHotEncoder(override val uid: String) extends Transformer - with HasInputCol with HasOutputCol { + with HasInputCol with HasOutputCol with Writable { def this() = this(Identifiable.randomUID("oneHot")) @@ -165,4 +165,17 @@ class OneHotEncoder(override val uid: String) extends Transformer } override def copy(extra: ParamMap): OneHotEncoder = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object OneHotEncoder extends Readable[OneHotEncoder] { + + @Since("1.6.0") + override def read: Reader[OneHotEncoder] = new DefaultParamsReader[OneHotEncoder] + + @Since("1.6.0") + override def load(path: String): OneHotEncoder = read.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala index d85e468562d4a..49415398325fd 100644 --- 
a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala @@ -19,10 +19,10 @@ package org.apache.spark.ml.feature import scala.collection.mutable -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param.{ParamMap, IntParam, ParamValidators} -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg._ import org.apache.spark.sql.types.DataType @@ -36,7 +36,7 @@ import org.apache.spark.sql.types.DataType */ @Experimental class PolynomialExpansion(override val uid: String) - extends UnaryTransformer[Vector, Vector, PolynomialExpansion] { + extends UnaryTransformer[Vector, Vector, PolynomialExpansion] with Writable { def this() = this(Identifiable.randomUID("poly")) @@ -63,6 +63,9 @@ class PolynomialExpansion(override val uid: String) override protected def outputDataType: DataType = new VectorUDT() override def copy(extra: ParamMap): PolynomialExpansion = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) } /** @@ -77,7 +80,8 @@ class PolynomialExpansion(override val uid: String) * To handle sparsity, if c is zero, we can skip all monomials that contain it. We remember the * current index and increment it properly for sparse input. */ -private[feature] object PolynomialExpansion { +@Since("1.6.0") +object PolynomialExpansion extends Readable[PolynomialExpansion] { private def choose(n: Int, k: Int): Int = { Range(n, n - k, -1).product / Range(k, 1, -1).product @@ -169,11 +173,17 @@ private[feature] object PolynomialExpansion { new SparseVector(polySize - 1, polyIndices.result(), polyValues.result()) } - def expand(v: Vector, degree: Int): Vector = { + private[feature] def expand(v: Vector, degree: Int): Vector = { v match { case dv: DenseVector => expand(dv, degree) case sv: SparseVector => expand(sv, degree) case _ => throw new IllegalArgumentException } } + + @Since("1.6.0") + override def read: Reader[PolynomialExpansion] = new DefaultParamsReader[PolynomialExpansion] + + @Since("1.6.0") + override def load(path: String): PolynomialExpansion = read.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 46b836da9cfde..2da5c966d2967 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.feature import scala.collection.mutable import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml._ import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} @@ -60,7 +60,7 @@ private[feature] trait QuantileDiscretizerBase extends Params with HasInputCol w */ @Experimental final class QuantileDiscretizer(override val uid: String) - extends Estimator[Bucketizer] with QuantileDiscretizerBase { + extends Estimator[Bucketizer] with QuantileDiscretizerBase with Writable { def this() = this(Identifiable.randomUID("quantileDiscretizer")) @@ -93,13 +93,17 @@ final class QuantileDiscretizer(override val uid: String) } override def 
copy(extra: ParamMap): QuantileDiscretizer = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) } -private[feature] object QuantileDiscretizer extends Logging { +@Since("1.6.0") +object QuantileDiscretizer extends Readable[QuantileDiscretizer] with Logging { /** * Sampling from the given dataset to collect quantile statistics. */ - def getSampledInput(dataset: DataFrame, numBins: Int): Array[Row] = { + private[feature] def getSampledInput(dataset: DataFrame, numBins: Int): Array[Row] = { val totalSamples = dataset.count() require(totalSamples > 0, "QuantileDiscretizer requires non-empty input dataset but was given an empty input.") @@ -111,6 +115,7 @@ private[feature] object QuantileDiscretizer extends Logging { /** * Compute split points with respect to the sample distribution. */ + private[feature] def findSplitCandidates(samples: Array[Double], numSplits: Int): Array[Double] = { val valueCountMap = samples.foldLeft(Map.empty[Double, Int]) { (m, x) => m + ((x, m.getOrElse(x, 0) + 1)) @@ -150,7 +155,7 @@ private[feature] object QuantileDiscretizer extends Logging { * Adjust split candidates to proper splits by: adding positive/negative infinity to both sides as * needed, and adding a default split value of 0 if no good candidates are found. */ - def getSplits(candidates: Array[Double]): Array[Double] = { + private[feature] def getSplits(candidates: Array[Double]): Array[Double] = { val effectiveValues = if (candidates.size != 0) { if (candidates.head == Double.NegativeInfinity && candidates.last == Double.PositiveInfinity) { @@ -172,5 +177,10 @@ private[feature] object QuantileDiscretizer extends Logging { Array(Double.NegativeInfinity) ++ effectiveValues ++ Array(Double.PositiveInfinity) } } -} + @Since("1.6.0") + override def read: Reader[QuantileDiscretizer] = new DefaultParamsReader[QuantileDiscretizer] + + @Since("1.6.0") + override def load(path: String): QuantileDiscretizer = read.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala index 95e4305638730..c115064ff301a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala @@ -18,10 +18,10 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkContext -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.param.{ParamMap, Param} import org.apache.spark.ml.Transformer -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.sql.{SQLContext, DataFrame, Row} import org.apache.spark.sql.types.StructType @@ -32,24 +32,30 @@ import org.apache.spark.sql.types.StructType * where '__THIS__' represents the underlying table of the input dataset. */ @Experimental -class SQLTransformer (override val uid: String) extends Transformer { +@Since("1.6.0") +class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transformer with Writable { + @Since("1.6.0") def this() = this(Identifiable.randomUID("sql")) /** * SQL statement parameter. The statement is provided in string form. 
* @group param */ + @Since("1.6.0") final val statement: Param[String] = new Param[String](this, "statement", "SQL statement") /** @group setParam */ + @Since("1.6.0") def setStatement(value: String): this.type = set(statement, value) /** @group getParam */ + @Since("1.6.0") def getStatement: String = $(statement) private val tableIdentifier: String = "__THIS__" + @Since("1.6.0") override def transform(dataset: DataFrame): DataFrame = { val tableName = Identifiable.randomUID(uid) dataset.registerTempTable(tableName) @@ -58,6 +64,7 @@ class SQLTransformer (override val uid: String) extends Transformer { outputDF } + @Since("1.6.0") override def transformSchema(schema: StructType): StructType = { val sc = SparkContext.getOrCreate() val sqlContext = SQLContext.getOrCreate(sc) @@ -68,5 +75,19 @@ class SQLTransformer (override val uid: String) extends Transformer { outputSchema } + @Since("1.6.0") override def copy(extra: ParamMap): SQLTransformer = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object SQLTransformer extends Readable[SQLTransformer] { + + @Since("1.6.0") + override def read: Reader[SQLTransformer] = new DefaultParamsReader[SQLTransformer] + + @Since("1.6.0") + override def load(path: String): SQLTransformer = read.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index 2a79582625e9a..f1146988dcc7c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -17,11 +17,11 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.Transformer import org.apache.spark.ml.param.{BooleanParam, ParamMap, StringArrayParam} import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType} @@ -86,7 +86,7 @@ private[spark] object StopWords { */ @Experimental class StopWordsRemover(override val uid: String) - extends Transformer with HasInputCol with HasOutputCol { + extends Transformer with HasInputCol with HasOutputCol with Writable { def this() = this(Identifiable.randomUID("stopWords")) @@ -154,4 +154,17 @@ class StopWordsRemover(override val uid: String) } override def copy(extra: ParamMap): StopWordsRemover = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object StopWordsRemover extends Readable[StopWordsRemover] { + + @Since("1.6.0") + override def read: Reader[StopWordsRemover] = new DefaultParamsReader[StopWordsRemover] + + @Since("1.6.0") + override def load(path: String): StopWordsRemover = read.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 486274cd75a14..f782a272d11db 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -18,13 +18,13 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkException -import 
org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.Transformer -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -188,9 +188,8 @@ class StringIndexerModel ( * @see [[StringIndexer]] for converting strings into indices */ @Experimental -class IndexToString private[ml] ( - override val uid: String) extends Transformer - with HasInputCol with HasOutputCol { +class IndexToString private[ml] (override val uid: String) + extends Transformer with HasInputCol with HasOutputCol with Writable { def this() = this(Identifiable.randomUID("idxToStr")) @@ -257,4 +256,17 @@ class IndexToString private[ml] ( override def copy(extra: ParamMap): IndexToString = { defaultCopy(extra) } + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object IndexToString extends Readable[IndexToString] { + + @Since("1.6.0") + override def read: Reader[IndexToString] = new DefaultParamsReader[IndexToString] + + @Since("1.6.0") + override def load(path: String): IndexToString = read.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index 1b82b40caac18..0e4445d1e2fa7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -17,10 +17,10 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.sql.types.{ArrayType, DataType, StringType} /** @@ -30,7 +30,8 @@ import org.apache.spark.sql.types.{ArrayType, DataType, StringType} * @see [[RegexTokenizer]] */ @Experimental -class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[String], Tokenizer] { +class Tokenizer(override val uid: String) + extends UnaryTransformer[String, Seq[String], Tokenizer] with Writable { def this() = this(Identifiable.randomUID("tok")) @@ -45,6 +46,19 @@ class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[S override protected def outputDataType: DataType = new ArrayType(StringType, true) override def copy(extra: ParamMap): Tokenizer = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object Tokenizer extends Readable[Tokenizer] { + + @Since("1.6.0") + override def read: Reader[Tokenizer] = new DefaultParamsReader[Tokenizer] + + @Since("1.6.0") + override def load(path: String): Tokenizer = read.load(path) } /** @@ -56,7 +70,7 @@ class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[S */ @Experimental class RegexTokenizer(override val uid: String) - extends UnaryTransformer[String, Seq[String], RegexTokenizer] { + extends UnaryTransformer[String, Seq[String], RegexTokenizer] with Writable { def this() = this(Identifiable.randomUID("regexTok")) @@ -131,4 +145,17 @@ class 
RegexTokenizer(override val uid: String) override protected def outputDataType: DataType = new ArrayType(StringType, true) override def copy(extra: ParamMap): RegexTokenizer = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object RegexTokenizer extends Readable[RegexTokenizer] { + + @Since("1.6.0") + override def read: Reader[RegexTokenizer] = new DefaultParamsReader[RegexTokenizer] + + @Since("1.6.0") + override def load(path: String): RegexTokenizer = read.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 086917fa680f8..7e54205292ca2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -20,12 +20,12 @@ package org.apache.spark.ml.feature import scala.collection.mutable.ArrayBuilder import org.apache.spark.SparkException -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute, UnresolvedAttribute} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ @@ -37,7 +37,7 @@ import org.apache.spark.sql.types._ */ @Experimental class VectorAssembler(override val uid: String) - extends Transformer with HasInputCols with HasOutputCol { + extends Transformer with HasInputCols with HasOutputCol with Writable { def this() = this(Identifiable.randomUID("vecAssembler")) @@ -120,9 +120,19 @@ class VectorAssembler(override val uid: String) } override def copy(extra: ParamMap): VectorAssembler = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) } -private object VectorAssembler { +@Since("1.6.0") +object VectorAssembler extends Readable[VectorAssembler] { + + @Since("1.6.0") + override def read: Reader[VectorAssembler] = new DefaultParamsReader[VectorAssembler] + + @Since("1.6.0") + override def load(path: String): VectorAssembler = read.load(path) private[feature] def assemble(vv: Any*): Vector = { val indices = ArrayBuilder.make[Int] diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala index fb3387d4aa9be..911582b55b574 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala @@ -17,12 +17,12 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.{Attribute, AttributeGroup} import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.param.{IntArrayParam, ParamMap, StringArrayParam} -import org.apache.spark.ml.util.{Identifiable, MetadataUtils, SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ @@ -42,7 +42,7 @@ import 
org.apache.spark.sql.types.StructType */ @Experimental final class VectorSlicer(override val uid: String) - extends Transformer with HasInputCol with HasOutputCol { + extends Transformer with HasInputCol with HasOutputCol with Writable { def this() = this(Identifiable.randomUID("vectorSlicer")) @@ -151,12 +151,16 @@ final class VectorSlicer(override val uid: String) } override def copy(extra: ParamMap): VectorSlicer = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) } -private[feature] object VectorSlicer { +@Since("1.6.0") +object VectorSlicer extends Readable[VectorSlicer] { /** Return true if given feature indices are valid */ - def validIndices(indices: Array[Int]): Boolean = { + private[feature] def validIndices(indices: Array[Int]): Boolean = { if (indices.isEmpty) { true } else { @@ -165,7 +169,13 @@ private[feature] object VectorSlicer { } /** Return true if given feature names are valid */ - def validNames(names: Array[String]): Boolean = { + private[feature] def validNames(names: Array[String]): Boolean = { names.forall(_.nonEmpty) && names.length == names.distinct.length } + + @Since("1.6.0") + override def read: Reader[VectorSlicer] = new DefaultParamsReader[VectorSlicer] + + @Since("1.6.0") + override def load(path: String): VectorSlicer = read.load(path) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala index 9dfa1439cc303..6d2d8fe714444 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala @@ -69,10 +69,10 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau } test("read/write") { - val binarizer = new Binarizer() - .setInputCol("feature") - .setOutputCol("binarized_feature") + val t = new Binarizer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") .setThreshold(0.1) - testDefaultReadWrite(binarizer) + testDefaultReadWrite(t) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index 0eba34fda6228..9ea7d431763a1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -21,13 +21,13 @@ import scala.util.Random import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} -class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext { +class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new Bucketizer) @@ -112,6 +112,14 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext { val lsResult = Vectors.dense(data.map(x => BucketizerSuite.linearSearchForBuckets(splits, x))) assert(bsResult ~== lsResult absTol 1e-5) } + + test("read/write") { + val t = new Bucketizer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setSplits(Array(0.1, 0.8, 0.9)) + testDefaultReadWrite(t) + } } private object 
BucketizerSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala index 37ed2367c33f7..0f2aafebafe67 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala @@ -22,6 +22,7 @@ import scala.beans.BeanInfo import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} @@ -29,7 +30,7 @@ import org.apache.spark.sql.{DataFrame, Row} @BeanInfo case class DCTTestData(vec: Vector, wantedVec: Vector) -class DCTSuite extends SparkFunSuite with MLlibTestSparkContext { +class DCTSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("forward transform of discrete cosine matches jTransforms result") { val data = Vectors.dense((0 until 128).map(_ => 2D * math.random - 1D).toArray) @@ -45,6 +46,14 @@ class DCTSuite extends SparkFunSuite with MLlibTestSparkContext { testDCT(data, inverse) } + test("read/write") { + val t = new DCT() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setInverse(true) + testDefaultReadWrite(t) + } + private def testDCT(data: Vector, inverse: Boolean): Unit = { val expectedResultBuffer = data.toArray.clone() if (inverse) { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala index 4157b84b29d01..0dcd0f49465ed 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala @@ -20,12 +20,13 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils -class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext { +class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new HashingTF) @@ -50,4 +51,12 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext { Seq((idx("a"), 2.0), (idx("b"), 2.0), (idx("c"), 1.0), (idx("d"), 1.0))) assert(features ~== expected absTol 1e-14) } + + test("read/write") { + val t = new HashingTF() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setNumFeatures(10) + testDefaultReadWrite(t) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala index 2beb62ca08233..932d331b472b9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.feature import scala.collection.mutable.ArrayBuilder +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute._ import 
org.apache.spark.ml.param.ParamsSuite @@ -26,7 +27,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.functions.col -class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext { +class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new Interaction()) } @@ -162,4 +163,11 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext { new NumericAttribute(Some("a_2:b_1:c"), Some(9)))) assert(attrs === expectedAttrs) } + + test("read/write") { + val t = new Interaction() + .setInputCols(Array("myInputCol", "myInputCol2")) + .setOutputCol("myOutputCol") + testDefaultReadWrite(t) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala index ab97e3dbc6ee0..58fda29aa1e69 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala @@ -20,13 +20,14 @@ package org.apache.spark.ml.feature import scala.beans.BeanInfo import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} @BeanInfo case class NGramTestData(inputTokens: Array[String], wantedNGrams: Array[String]) -class NGramSuite extends SparkFunSuite with MLlibTestSparkContext { +class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { import org.apache.spark.ml.feature.NGramSuite._ test("default behavior yields bigram features") { @@ -79,6 +80,14 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext { ))) testNGram(nGram, dataset) } + + test("read/write") { + val t = new NGram() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setN(3) + testDefaultReadWrite(t) + } } object NGramSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala index 9f03470b7f328..de3d438ce83be 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala @@ -18,13 +18,14 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row, SQLContext} -class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext { +class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @transient var data: Array[Vector] = _ @transient var dataFrame: DataFrame = _ @@ -104,6 +105,14 @@ class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext { assertValues(result, l1Normalized) } + + test("read/write") { + val t = new Normalizer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setP(3.0) + testDefaultReadWrite(t) + } } private object NormalizerSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala index 321eeb843941c..76d12050f9677 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala @@ -20,12 +20,14 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute} import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions.col -class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext { +class OneHotEncoderSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { def stringIndexed(): DataFrame = { val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) @@ -101,4 +103,12 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext { assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0)) assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1)) } + + test("read/write") { + val t = new OneHotEncoder() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setDropLast(false) + testDefaultReadWrite(t) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala index 29eebd8960ebc..70892dc57170a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala @@ -21,12 +21,14 @@ import org.apache.spark.ml.param.ParamsSuite import org.scalatest.exceptions.TestFailedException import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.Row -class PolynomialExpansionSuite extends SparkFunSuite with MLlibTestSparkContext { +class PolynomialExpansionSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new PolynomialExpansion) @@ -98,5 +100,13 @@ class PolynomialExpansionSuite extends SparkFunSuite with MLlibTestSparkContext throw new TestFailedException("Unmatched data types after polynomial expansion", 0) } } + + test("read/write") { + val t = new PolynomialExpansion() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setDegree(3) + testDefaultReadWrite(t) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index b2bdd8935f903..3a4f6d235aa6c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -18,11 +18,14 @@ package org.apache.spark.ml.feature import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext import 
org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.{SparkContext, SparkFunSuite} -class QuantileDiscretizerSuite extends SparkFunSuite with MLlibTestSparkContext { +class QuantileDiscretizerSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import org.apache.spark.ml.feature.QuantileDiscretizerSuite._ test("Test quantile discretizer") { @@ -67,6 +70,14 @@ class QuantileDiscretizerSuite extends SparkFunSuite with MLlibTestSparkContext assert(QuantileDiscretizer.getSplits(ori) === res, "Returned splits are invalid.") } } + + test("read/write") { + val t = new QuantileDiscretizer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setNumBuckets(6) + testDefaultReadWrite(t) + } } private object QuantileDiscretizerSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala index d19052881ae45..553e0b870216c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala @@ -19,9 +19,11 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext -class SQLTransformerSuite extends SparkFunSuite with MLlibTestSparkContext { +class SQLTransformerSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new SQLTransformer()) @@ -41,4 +43,10 @@ class SQLTransformerSuite extends SparkFunSuite with MLlibTestSparkContext { assert(resultSchema == expected.schema) assert(result.collect().toSeq == expected.collect().toSeq) } + + test("read/write") { + val t = new SQLTransformer() + .setStatement("select * from __THIS__") + testDefaultReadWrite(t) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala index e0d433f566c25..fb217e0c1de93 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} @@ -32,7 +33,9 @@ object StopWordsRemoverSuite extends SparkFunSuite { } } -class StopWordsRemoverSuite extends SparkFunSuite with MLlibTestSparkContext { +class StopWordsRemoverSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import StopWordsRemoverSuite._ test("StopWordsRemover default") { @@ -77,4 +80,13 @@ class StopWordsRemoverSuite extends SparkFunSuite with MLlibTestSparkContext { testStopWordsRemover(remover, dataSet) } + + test("read/write") { + val t = new StopWordsRemover() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setStopWords(Array("the", "a")) + .setCaseSensitive(true) + testDefaultReadWrite(t) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index ddcdb5f4212be..be37bfb438833 100644 --- 
a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -21,12 +21,13 @@ import org.apache.spark.sql.types.{StringType, StructType, StructField, DoubleTy import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row import org.apache.spark.sql.functions.col -class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { +class StringIndexerSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new StringIndexer) @@ -173,4 +174,12 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { val outSchema = idxToStr.transformSchema(inSchema) assert(outSchema("output").dataType === StringType) } + + test("read/write") { + val t = new IndexToString() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setLabels(Array("a", "b", "c")) + testDefaultReadWrite(t) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala index a02992a2407b3..36e8e5d868389 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala @@ -21,20 +21,30 @@ import scala.beans.BeanInfo import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} @BeanInfo case class TokenizerTestData(rawText: String, wantedTokens: Array[String]) -class TokenizerSuite extends SparkFunSuite { +class TokenizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new Tokenizer) } + + test("read/write") { + val t = new Tokenizer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + testDefaultReadWrite(t) + } } -class RegexTokenizerSuite extends SparkFunSuite with MLlibTestSparkContext { +class RegexTokenizerSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import org.apache.spark.ml.feature.RegexTokenizerSuite._ test("params") { @@ -81,6 +91,17 @@ class RegexTokenizerSuite extends SparkFunSuite with MLlibTestSparkContext { )) testRegexTokenizer(tokenizer, dataset) } + + test("read/write") { + val t = new RegexTokenizer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMinTokenLength(2) + .setGaps(false) + .setPattern("hi") + .setToLowercase(false) + testDefaultReadWrite(t) + } } object RegexTokenizerSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index bb4d5b983e0d4..fb21ab6b9bf2c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.feature +import org.apache.spark.ml.util.DefaultReadWriteTest import 
org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute} import org.apache.spark.ml.param.ParamsSuite @@ -25,7 +26,8 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row import org.apache.spark.sql.functions.col -class VectorAssemblerSuite extends SparkFunSuite with MLlibTestSparkContext { +class VectorAssemblerSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new VectorAssembler) @@ -101,4 +103,11 @@ class VectorAssemblerSuite extends SparkFunSuite with MLlibTestSparkContext { assert(features.getAttr(5) === NumericAttribute.defaultAttr.withIndex(5)) assert(features.getAttr(6) === NumericAttribute.defaultAttr.withIndex(6)) } + + test("read/write") { + val t = new VectorAssembler() + .setInputCols(Array("myInputCol", "myInputCol2")) + .setOutputCol("myOutputCol") + testDefaultReadWrite(t) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala index a6c2fba8360dd..74706a23e0936 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala @@ -20,12 +20,13 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute} import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{DataFrame, Row, SQLContext} -class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext { +class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { val slicer = new VectorSlicer @@ -106,4 +107,13 @@ class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext { vectorSlicer.setIndices(Array.empty).setNames(Array("f1", "f4")) validateResults(vectorSlicer.transform(df)) } + + test("read/write") { + val t = new VectorSlicer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setIndices(Array(1, 3)) + .setNames(Array("a", "d")) + testDefaultReadWrite(t) + } } From 5aca6ad00c9d7fa43c725b8da4a10114a3a77421 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 17 Nov 2015 12:50:01 -0800 Subject: [PATCH 065/149] [SPARK-11767] [SQL] limit the size of cached batch Currently the size of a cached batch is only controlled by `batchSize` (default value is 10000), which does not work well with the size of serialized columns (for example, complex types). The memory used to build the batch is not accounted for, so it's easy to OOM (especially after unified memory management). This PR introduces a hard limit of 4M for the total columns (up to 50 columns of uncompressed primitive columns). This also changes the way the buffer grows: it is doubled each time, then trimmed once finished. cc liancheng Author: Davies Liu Closes #9760 from davies/cache_limit.
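The grow-and-trim strategy described above can be sketched as follows. This is an illustrative plain-Scala sketch, not the exact change (the real code lives in ColumnBuilder.ensureFreeSpace and BasicColumnBuilder.build in the hunks below); the doubling step, the roughly 10% trim slack, and the 4MB cap mirror the values this patch introduces, while the `ColumnBufferSketch` object and the standalone `trim` method are made up for the illustration.

import java.nio.{ByteBuffer, ByteOrder}

object ColumnBufferSketch {
  // Hard cap on the bytes accumulated for one cached batch (4MB, as in the patch).
  val MAX_BATCH_SIZE_IN_BYTE: Long = 4 * 1024 * 1024L

  // Grow by at least the current capacity (i.e. roughly double each time) instead of
  // growing in small increments, so repeated appends stay cheap on average.
  def ensureFreeSpace(orig: ByteBuffer, size: Int): ByteBuffer = {
    if (orig.remaining() >= size) {
      orig
    } else {
      val capacity = orig.capacity()
      val newSize = capacity + size.max(capacity)
      val pos = orig.position()
      ByteBuffer.allocate(newSize).order(ByteOrder.nativeOrder()).put(orig.array(), 0, pos)
    }
  }

  // Once a column is finished, copy it into a right-sized buffer if more than about
  // 10% of the capacity would otherwise be wasted.
  def trim(buffer: ByteBuffer): ByteBuffer = {
    if (buffer.capacity() > buffer.position() * 1.1) {
      ByteBuffer.allocate(buffer.position())
        .order(ByteOrder.nativeOrder())
        .put(buffer.array(), 0, buffer.position())
    } else {
      buffer
    }
  }
}

In the actual change the 4MB cap is enforced while rows are appended in InMemoryColumnarTableScan, so a batch is closed as soon as either the row count reaches `batchSize` or the accumulated column size crosses the limit.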
--- .../apache/spark/sql/columnar/ColumnBuilder.scala | 12 ++++++++++-- .../org/apache/spark/sql/columnar/ColumnStats.scala | 2 +- .../sql/columnar/InMemoryColumnarTableScan.scala | 6 +++++- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala index 7a7345a7e004b..599f30f2d73b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala @@ -73,6 +73,13 @@ private[sql] class BasicColumnBuilder[JvmType]( } override def build(): ByteBuffer = { + if (buffer.capacity() > buffer.position() * 1.1) { + // trim the buffer + buffer = ByteBuffer + .allocate(buffer.position()) + .order(ByteOrder.nativeOrder()) + .put(buffer.array(), 0, buffer.position()) + } buffer.flip().asInstanceOf[ByteBuffer] } } @@ -129,7 +136,8 @@ private[sql] class MapColumnBuilder(dataType: MapType) extends ComplexColumnBuilder(new ObjectColumnStats(dataType), MAP(dataType)) private[sql] object ColumnBuilder { - val DEFAULT_INITIAL_BUFFER_SIZE = 1024 * 1024 + val DEFAULT_INITIAL_BUFFER_SIZE = 128 * 1024 + val MAX_BATCH_SIZE_IN_BYTE = 4 * 1024 * 1024L private[columnar] def ensureFreeSpace(orig: ByteBuffer, size: Int) = { if (orig.remaining >= size) { @@ -137,7 +145,7 @@ private[sql] object ColumnBuilder { } else { // grow in steps of initial size val capacity = orig.capacity() - val newSize = capacity + size.max(capacity / 8 + 1) + val newSize = capacity + size.max(capacity) val pos = orig.position() ByteBuffer diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index ba61003ba41c6..91a05650585cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -48,7 +48,7 @@ private[sql] class PartitionStatistics(tableSchema: Seq[Attribute]) extends Seri private[sql] sealed trait ColumnStats extends Serializable { protected var count = 0 protected var nullCount = 0 - protected var sizeInBytes = 0L + private[sql] var sizeInBytes = 0L /** * Gathers statistics information from `row(ordinal)`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 2cface61e59c4..ae77298e6da2f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -133,7 +133,9 @@ private[sql] case class InMemoryRelation( }.toArray var rowCount = 0 - while (rowIterator.hasNext && rowCount < batchSize) { + var totalSize = 0L + while (rowIterator.hasNext && rowCount < batchSize + && totalSize < ColumnBuilder.MAX_BATCH_SIZE_IN_BYTE) { val row = rowIterator.next() // Added for SPARK-6082. 
This assertion can be useful for scenarios when something @@ -147,8 +149,10 @@ private[sql] case class InMemoryRelation( s"\nRow content: $row") var i = 0 + totalSize = 0 while (i < row.numFields) { columnBuilders(i).appendFrom(row, i) + totalSize += columnBuilders(i).columnStats.sizeInBytes i += 1 } rowCount += 1 From fa603e08de641df16d066302be5d5f92a60a923e Mon Sep 17 00:00:00 2001 From: Timothy Hunter Date: Tue, 17 Nov 2015 20:51:20 +0000 Subject: [PATCH 066/149] [SPARK-11732] Removes some MiMa false positives This adds an extra filter for private or protected classes. We only filter for package private right now. Author: Timothy Hunter Closes #9697 from thunterdb/spark-11732. --- project/MimaExcludes.scala | 7 +------ .../scala/org/apache/spark/tools/GenerateMIMAIgnore.scala | 4 +++- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 815951822c1ef..8b3bc96801e20 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -54,12 +54,7 @@ object MimaExcludes { MimaBuild.excludeSparkClass("streaming.flume.FlumeTestUtils") ++ MimaBuild.excludeSparkClass("streaming.flume.PollingFlumeTestUtils") ++ Seq( - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.ml.classification.LogisticCostFun.this"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.ml.classification.LogisticAggregator.add"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.ml.classification.LogisticAggregator.count"), + // MiMa does not deal properly with sealed traits ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.ml.classification.LogisticRegressionSummary.featuresCol") ) ++ Seq( diff --git a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala index a0524cabff2d4..5155daa6d17bf 100644 --- a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala +++ b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala @@ -72,7 +72,9 @@ object GenerateMIMAIgnore { val classSymbol = mirror.classSymbol(Class.forName(className, false, classLoader)) val moduleSymbol = mirror.staticModule(className) val directlyPrivateSpark = - isPackagePrivate(classSymbol) || isPackagePrivateModule(moduleSymbol) + isPackagePrivate(classSymbol) || + isPackagePrivateModule(moduleSymbol) || + classSymbol.isPrivate val developerApi = isDeveloperApi(classSymbol) || isDeveloperApi(moduleSymbol) val experimental = isExperimental(classSymbol) || isExperimental(moduleSymbol) /* Inner classes defined within a private[spark] class or object are effectively From 328eb49e6222271337e09188853b29c8f32fb157 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Tue, 17 Nov 2015 13:59:59 -0800 Subject: [PATCH 067/149] [SPARK-11729] Replace example code in ml-linear-methods.md using include_example JIRA link: https://issues.apache.org/jira/browse/SPARK-11729 Author: Xusen Yin Closes #9713 from yinxusen/SPARK-11729. 
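For context on the `include_example` mechanism used below: the Markdown pages now reference the example source files with a Jekyll tag such as `{% include_example scala/org/apache/spark/examples/ml/LogisticRegressionWithElasticNetExample.scala %}`, and, as the new example files in this patch suggest, only the regions delimited by the `// $example on$` / `// $example off$` markers are pulled into the rendered page. A minimal sketch of the shape of such a file (the object name `IncludeExampleSketch` is invented; the real files added by this patch follow in the diff):

// scalastyle:off println
package org.apache.spark.examples.ml

// $example on$
import org.apache.spark.ml.classification.LogisticRegression
// $example off$
import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkConf, SparkContext}

object IncludeExampleSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("IncludeExampleSketch"))
    val sqlContext = new SQLContext(sc)

    // $example on$
    // Only the lines between the on/off markers show up on the rendered docs page.
    val training = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
    val lrModel = new LogisticRegression().setMaxIter(10).setRegParam(0.3).fit(training)
    println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")
    // $example off$

    sc.stop()
  }
}
// scalastyle:on println

Keeping the boilerplate (context creation, scalastyle directives) outside the markers is what lets the docs show only the part a reader needs to copy.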
--- docs/ml-linear-methods.md | 218 +----------------- ...LinearRegressionWithElasticNetExample.java | 65 ++++++ .../JavaLogisticRegressionSummaryExample.java | 84 +++++++ ...gisticRegressionWithElasticNetExample.java | 55 +++++ .../ml/linear_regression_with_elastic_net.py | 44 ++++ .../logistic_regression_with_elastic_net.py | 44 ++++ ...inearRegressionWithElasticNetExample.scala | 61 +++++ .../ml/LogisticRegressionSummaryExample.scala | 77 +++++++ ...isticRegressionWithElasticNetExample.scala | 53 +++++ 9 files changed, 491 insertions(+), 210 deletions(-) create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java create mode 100644 examples/src/main/python/ml/linear_regression_with_elastic_net.py create mode 100644 examples/src/main/python/ml/logistic_regression_with_elastic_net.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionWithElasticNetExample.scala diff --git a/docs/ml-linear-methods.md b/docs/ml-linear-methods.md index 85edfd373465f..0c13d7d0c82b3 100644 --- a/docs/ml-linear-methods.md +++ b/docs/ml-linear-methods.md @@ -57,77 +57,15 @@ $\alpha$ and `regParam` corresponds to $\lambda$.
    -{% highlight scala %} -import org.apache.spark.ml.classification.LogisticRegression - -// Load training data -val training = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - -val lr = new LogisticRegression() - .setMaxIter(10) - .setRegParam(0.3) - .setElasticNetParam(0.8) - -// Fit the model -val lrModel = lr.fit(training) - -// Print the coefficients and intercept for logistic regression -println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}") -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/LogisticRegressionWithElasticNetExample.scala %}
    -{% highlight java %} -import org.apache.spark.ml.classification.LogisticRegression; -import org.apache.spark.ml.classification.LogisticRegressionModel; -import org.apache.spark.SparkConf; -import org.apache.spark.SparkContext; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.SQLContext; - -public class LogisticRegressionWithElasticNetExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf() - .setAppName("Logistic Regression with Elastic Net Example"); - - SparkContext sc = new SparkContext(conf); - SQLContext sql = new SQLContext(sc); - String path = "data/mllib/sample_libsvm_data.txt"; - - // Load training data - DataFrame training = sqlContext.read().format("libsvm").load(path); - - LogisticRegression lr = new LogisticRegression() - .setMaxIter(10) - .setRegParam(0.3) - .setElasticNetParam(0.8); - - // Fit the model - LogisticRegressionModel lrModel = lr.fit(training); - - // Print the coefficients and intercept for logistic regression - System.out.println("Coefficients: " + lrModel.coefficients() + " Intercept: " + lrModel.intercept()); - } -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java %}
    -{% highlight python %} -from pyspark.ml.classification import LogisticRegression - -# Load training data -training = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - -lr = LogisticRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8) - -# Fit the model -lrModel = lr.fit(training) - -# Print the coefficients and intercept for logistic regression -print("Coefficients: " + str(lrModel.coefficients)) -print("Intercept: " + str(lrModel.intercept)) -{% endhighlight %} +{% include_example python/ml/logistic_regression_with_elastic_net.py %}
    @@ -152,33 +90,7 @@ This will likely change when multiclass classification is supported. Continuing the earlier example: -{% highlight scala %} -import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary - -// Extract the summary from the returned LogisticRegressionModel instance trained in the earlier example -val trainingSummary = lrModel.summary - -// Obtain the objective per iteration. -val objectiveHistory = trainingSummary.objectiveHistory -objectiveHistory.foreach(loss => println(loss)) - -// Obtain the metrics useful to judge performance on test data. -// We cast the summary to a BinaryLogisticRegressionSummary since the problem is a -// binary classification problem. -val binarySummary = trainingSummary.asInstanceOf[BinaryLogisticRegressionSummary] - -// Obtain the receiver-operating characteristic as a dataframe and areaUnderROC. -val roc = binarySummary.roc -roc.show() -println(binarySummary.areaUnderROC) - -// Set the model threshold to maximize F-Measure -val fMeasure = binarySummary.fMeasureByThreshold -val maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0) -val bestThreshold = fMeasure.where($"F-Measure" === maxFMeasure). - select("threshold").head().getDouble(0) -lrModel.setThreshold(bestThreshold) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala %}
    @@ -192,39 +104,7 @@ This will likely change when multiclass classification is supported. Continuing the earlier example: -{% highlight java %} -import org.apache.spark.ml.classification.LogisticRegressionTrainingSummary; -import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary; -import org.apache.spark.sql.functions; - -// Extract the summary from the returned LogisticRegressionModel instance trained in the earlier example -LogisticRegressionTrainingSummary trainingSummary = lrModel.summary(); - -// Obtain the loss per iteration. -double[] objectiveHistory = trainingSummary.objectiveHistory(); -for (double lossPerIteration : objectiveHistory) { - System.out.println(lossPerIteration); -} - -// Obtain the metrics useful to judge performance on test data. -// We cast the summary to a BinaryLogisticRegressionSummary since the problem is a -// binary classification problem. -BinaryLogisticRegressionSummary binarySummary = (BinaryLogisticRegressionSummary) trainingSummary; - -// Obtain the receiver-operating characteristic as a dataframe and areaUnderROC. -DataFrame roc = binarySummary.roc(); -roc.show(); -roc.select("FPR").show(); -System.out.println(binarySummary.areaUnderROC()); - -// Get the threshold corresponding to the maximum F-Measure and rerun LogisticRegression with -// this selected threshold. -DataFrame fMeasure = binarySummary.fMeasureByThreshold(); -double maxFMeasure = fMeasure.select(functions.max("F-Measure")).head().getDouble(0); -double bestThreshold = fMeasure.where(fMeasure.col("F-Measure").equalTo(maxFMeasure)). - select("threshold").head().getDouble(0); -lrModel.setThreshold(bestThreshold); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java %}
    @@ -244,98 +124,16 @@ regression model and extracting model summary statistics.
    -{% highlight scala %} -import org.apache.spark.ml.regression.LinearRegression - -// Load training data -val training = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - -val lr = new LinearRegression() - .setMaxIter(10) - .setRegParam(0.3) - .setElasticNetParam(0.8) - -// Fit the model -val lrModel = lr.fit(training) - -// Print the coefficients and intercept for linear regression -println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}") - -// Summarize the model over the training set and print out some metrics -val trainingSummary = lrModel.summary -println(s"numIterations: ${trainingSummary.totalIterations}") -println(s"objectiveHistory: ${trainingSummary.objectiveHistory.toList}") -trainingSummary.residuals.show() -println(s"RMSE: ${trainingSummary.rootMeanSquaredError}") -println(s"r2: ${trainingSummary.r2}") -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala %}
    -{% highlight java %} -import org.apache.spark.ml.regression.LinearRegression; -import org.apache.spark.ml.regression.LinearRegressionModel; -import org.apache.spark.ml.regression.LinearRegressionTrainingSummary; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.SparkConf; -import org.apache.spark.SparkContext; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.SQLContext; - -public class LinearRegressionWithElasticNetExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf() - .setAppName("Linear Regression with Elastic Net Example"); - - SparkContext sc = new SparkContext(conf); - SQLContext sql = new SQLContext(sc); - String path = "data/mllib/sample_libsvm_data.txt"; - - // Load training data - DataFrame training = sqlContext.read().format("libsvm").load(path); - - LinearRegression lr = new LinearRegression() - .setMaxIter(10) - .setRegParam(0.3) - .setElasticNetParam(0.8); - - // Fit the model - LinearRegressionModel lrModel = lr.fit(training); - - // Print the coefficients and intercept for linear regression - System.out.println("Coefficients: " + lrModel.coefficients() + " Intercept: " + lrModel.intercept()); - - // Summarize the model over the training set and print out some metrics - LinearRegressionTrainingSummary trainingSummary = lrModel.summary(); - System.out.println("numIterations: " + trainingSummary.totalIterations()); - System.out.println("objectiveHistory: " + Vectors.dense(trainingSummary.objectiveHistory())); - trainingSummary.residuals().show(); - System.out.println("RMSE: " + trainingSummary.rootMeanSquaredError()); - System.out.println("r2: " + trainingSummary.r2()); - } -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java %}
    -{% highlight python %} -from pyspark.ml.regression import LinearRegression - -# Load training data -training = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - -lr = LinearRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8) - -# Fit the model -lrModel = lr.fit(training) - -# Print the coefficients and intercept for linear regression -print("Coefficients: " + str(lrModel.coefficients)) -print("Intercept: " + str(lrModel.intercept)) - -# Linear regression model summary is not yet supported in Python. -{% endhighlight %} +{% include_example python/ml/linear_regression_with_elastic_net.py %}
    diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java new file mode 100644 index 0000000000000..593f8fb3e9fe9 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +// $example on$ +import org.apache.spark.ml.regression.LinearRegression; +import org.apache.spark.ml.regression.LinearRegressionModel; +import org.apache.spark.ml.regression.LinearRegressionTrainingSummary; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; +// $example off$ + +public class JavaLinearRegressionWithElasticNetExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaLinearRegressionWithElasticNetExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + // Load training data + DataFrame training = sqlContext.read().format("libsvm") + .load("data/mllib/sample_libsvm_data.txt"); + + LinearRegression lr = new LinearRegression() + .setMaxIter(10) + .setRegParam(0.3) + .setElasticNetParam(0.8); + + // Fit the model + LinearRegressionModel lrModel = lr.fit(training); + + // Print the coefficients and intercept for linear regression + System.out.println("Coefficients: " + + lrModel.coefficients() + " Intercept: " + lrModel.intercept()); + + // Summarize the model over the training set and print out some metrics + LinearRegressionTrainingSummary trainingSummary = lrModel.summary(); + System.out.println("numIterations: " + trainingSummary.totalIterations()); + System.out.println("objectiveHistory: " + Vectors.dense(trainingSummary.objectiveHistory())); + trainingSummary.residuals().show(); + System.out.println("RMSE: " + trainingSummary.rootMeanSquaredError()); + System.out.println("r2: " + trainingSummary.r2()); + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java new file mode 100644 index 0000000000000..986f3b3b28d77 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +// $example on$ +import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary; +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.classification.LogisticRegressionModel; +import org.apache.spark.ml.classification.LogisticRegressionTrainingSummary; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.functions; +// $example off$ + +public class JavaLogisticRegressionSummaryExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaLogisticRegressionSummaryExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // Load training data + DataFrame training = sqlContext.read().format("libsvm") + .load("data/mllib/sample_libsvm_data.txt"); + + LogisticRegression lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.3) + .setElasticNetParam(0.8); + + // Fit the model + LogisticRegressionModel lrModel = lr.fit(training); + + // $example on$ + // Extract the summary from the returned LogisticRegressionModel instance trained in the earlier + // example + LogisticRegressionTrainingSummary trainingSummary = lrModel.summary(); + + // Obtain the loss per iteration. + double[] objectiveHistory = trainingSummary.objectiveHistory(); + for (double lossPerIteration : objectiveHistory) { + System.out.println(lossPerIteration); + } + + // Obtain the metrics useful to judge performance on test data. + // We cast the summary to a BinaryLogisticRegressionSummary since the problem is a binary + // classification problem. + BinaryLogisticRegressionSummary binarySummary = + (BinaryLogisticRegressionSummary) trainingSummary; + + // Obtain the receiver-operating characteristic as a dataframe and areaUnderROC. + DataFrame roc = binarySummary.roc(); + roc.show(); + roc.select("FPR").show(); + System.out.println(binarySummary.areaUnderROC()); + + // Get the threshold corresponding to the maximum F-Measure and rerun LogisticRegression with + // this selected threshold. 
+ DataFrame fMeasure = binarySummary.fMeasureByThreshold(); + double maxFMeasure = fMeasure.select(functions.max("F-Measure")).head().getDouble(0); + double bestThreshold = fMeasure.where(fMeasure.col("F-Measure").equalTo(maxFMeasure)) + .select("threshold").head().getDouble(0); + lrModel.setThreshold(bestThreshold); + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java new file mode 100644 index 0000000000000..1d28279d72a0a --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +// $example on$ +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.classification.LogisticRegressionModel; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; +// $example off$ + +public class JavaLogisticRegressionWithElasticNetExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaLogisticRegressionWithElasticNetExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + // Load training data + DataFrame training = sqlContext.read().format("libsvm") + .load("data/mllib/sample_libsvm_data.txt"); + + LogisticRegression lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.3) + .setElasticNetParam(0.8); + + // Fit the model + LogisticRegressionModel lrModel = lr.fit(training); + + // Print the coefficients and intercept for logistic regression + System.out.println("Coefficients: " + + lrModel.coefficients() + " Intercept: " + lrModel.intercept()); + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/python/ml/linear_regression_with_elastic_net.py b/examples/src/main/python/ml/linear_regression_with_elastic_net.py new file mode 100644 index 0000000000000..b0278276330c3 --- /dev/null +++ b/examples/src/main/python/ml/linear_regression_with_elastic_net.py @@ -0,0 +1,44 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.regression import LinearRegression +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="LinearRegressionWithElasticNet") + sqlContext = SQLContext(sc) + + # $example on$ + # Load training data + training = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + lr = LinearRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8) + + # Fit the model + lrModel = lr.fit(training) + + # Print the coefficients and intercept for linear regression + print("Coefficients: " + str(lrModel.coefficients)) + print("Intercept: " + str(lrModel.intercept)) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/logistic_regression_with_elastic_net.py b/examples/src/main/python/ml/logistic_regression_with_elastic_net.py new file mode 100644 index 0000000000000..b0b1d27e13bb0 --- /dev/null +++ b/examples/src/main/python/ml/logistic_regression_with_elastic_net.py @@ -0,0 +1,44 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.classification import LogisticRegression +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="LogisticRegressionWithElasticNet") + sqlContext = SQLContext(sc) + + # $example on$ + # Load training data + training = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + lr = LogisticRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8) + + # Fit the model + lrModel = lr.fit(training) + + # Print the coefficients and intercept for logistic regression + print("Coefficients: " + str(lrModel.coefficients)) + print("Intercept: " + str(lrModel.intercept)) + # $example off$ + + sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala new file mode 100644 index 0000000000000..5a51ece6f9ba7 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.regression.LinearRegression +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object LinearRegressionWithElasticNetExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("LinearRegressionWithElasticNetExample") + val sc = new SparkContext(conf) + val sqlCtx = new SQLContext(sc) + + // $example on$ + // Load training data + val training = sqlCtx.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + val lr = new LinearRegression() + .setMaxIter(10) + .setRegParam(0.3) + .setElasticNetParam(0.8) + + // Fit the model + val lrModel = lr.fit(training) + + // Print the coefficients and intercept for linear regression + println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}") + + // Summarize the model over the training set and print out some metrics + val trainingSummary = lrModel.summary + println(s"numIterations: ${trainingSummary.totalIterations}") + println(s"objectiveHistory: ${trainingSummary.objectiveHistory.toList}") + trainingSummary.residuals.show() + println(s"RMSE: ${trainingSummary.rootMeanSquaredError}") + println(s"r2: ${trainingSummary.r2}") + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala new file mode 100644 index 0000000000000..4c420421b670e --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression} +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.functions.max +import org.apache.spark.{SparkConf, SparkContext} + +object LogisticRegressionSummaryExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("LogisticRegressionSummaryExample") + val sc = new SparkContext(conf) + val sqlCtx = new SQLContext(sc) + import sqlCtx.implicits._ + + // Load training data + val training = sqlCtx.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + val lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.3) + .setElasticNetParam(0.8) + + // Fit the model + val lrModel = lr.fit(training) + + // $example on$ + // Extract the summary from the returned LogisticRegressionModel instance trained in the earlier + // example + val trainingSummary = lrModel.summary + + // Obtain the objective per iteration. + val objectiveHistory = trainingSummary.objectiveHistory + objectiveHistory.foreach(loss => println(loss)) + + // Obtain the metrics useful to judge performance on test data. + // We cast the summary to a BinaryLogisticRegressionSummary since the problem is a + // binary classification problem. + val binarySummary = trainingSummary.asInstanceOf[BinaryLogisticRegressionSummary] + + // Obtain the receiver-operating characteristic as a dataframe and areaUnderROC. + val roc = binarySummary.roc + roc.show() + println(binarySummary.areaUnderROC) + + // Set the model threshold to maximize F-Measure + val fMeasure = binarySummary.fMeasureByThreshold + val maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0) + val bestThreshold = fMeasure.where($"F-Measure" === maxFMeasure) + .select("threshold").head().getDouble(0) + lrModel.setThreshold(bestThreshold) + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionWithElasticNetExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionWithElasticNetExample.scala new file mode 100644 index 0000000000000..9ee995b52c90b --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionWithElasticNetExample.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.classification.LogisticRegression +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object LogisticRegressionWithElasticNetExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("LogisticRegressionWithElasticNetExample") + val sc = new SparkContext(conf) + val sqlCtx = new SQLContext(sc) + + // $example on$ + // Load training data + val training = sqlCtx.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + val lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.3) + .setElasticNetParam(0.8) + + // Fit the model + val lrModel = lr.fit(training) + + // Print the coefficients and intercept for logistic regression + println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}") + // $example off$ + + sc.stop() + } +} +// scalastyle:on println From 6eb7008b7f33a36b06d0615b68cc21ed90ad1d8a Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 17 Nov 2015 14:03:49 -0800 Subject: [PATCH 068/149] [SPARK-11763][ML] Add save,load to LogisticRegression Estimator Add save/load to LogisticRegression Estimator, and refactor tests a little to make it easier to add similar support to other Estimator, Model pairs. Moved LogisticRegressionReader/Writer to within LogisticRegressionModel CC: mengxr Author: Joseph K. Bradley Closes #9749 from jkbradley/lr-io-2. --- .../classification/LogisticRegression.scala | 91 ++++++++++--------- .../org/apache/spark/ml/util/ReadWrite.scala | 1 + .../org/apache/spark/ml/PipelineSuite.scala | 7 -- .../ml/classification/ClassifierSuite.scala | 32 +++++++ .../LogisticRegressionSuite.scala | 37 ++++++-- .../ProbabilisticClassifierSuite.scala | 14 +++ .../spark/ml/util/DefaultReadWriteTest.scala | 50 +++++++++- 7 files changed, 173 insertions(+), 59 deletions(-) create mode 100644 mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index a88f52674102c..71c2533bcbf47 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -157,7 +157,7 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas @Experimental class LogisticRegression(override val uid: String) extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel] - with LogisticRegressionParams with Logging { + with LogisticRegressionParams with Writable with Logging { def this() = this(Identifiable.randomUID("logreg")) @@ -385,6 +385,12 @@ class LogisticRegression(override val uid: String) } override def copy(extra: ParamMap): LogisticRegression = defaultCopy(extra) + + override def write: Writer = new DefaultParamsWriter(this) +} + +object LogisticRegression extends Readable[LogisticRegression] { + override def read: Reader[LogisticRegression] = new DefaultParamsReader[LogisticRegression] } /** @@ -517,61 +523,62 @@ class LogisticRegressionModel private[ml] ( * * For [[LogisticRegressionModel]], this does NOT currently save the training [[summary]]. * An option to save [[summary]] may be added in the future. + * + * This also does not save the [[parent]] currently. 
*/ - override def write: Writer = new LogisticRegressionWriter(this) -} - - -/** [[Writer]] instance for [[LogisticRegressionModel]] */ -private[classification] class LogisticRegressionWriter(instance: LogisticRegressionModel) - extends Writer with Logging { - - private case class Data( - numClasses: Int, - numFeatures: Int, - intercept: Double, - coefficients: Vector) - - override protected def saveImpl(path: String): Unit = { - // Save metadata and Params - DefaultParamsWriter.saveMetadata(instance, path, sc) - // Save model data: numClasses, numFeatures, intercept, coefficients - val data = Data(instance.numClasses, instance.numFeatures, instance.intercept, - instance.coefficients) - val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).write.format("parquet").save(dataPath) - } + override def write: Writer = new LogisticRegressionModel.LogisticRegressionModelWriter(this) } object LogisticRegressionModel extends Readable[LogisticRegressionModel] { - override def read: Reader[LogisticRegressionModel] = new LogisticRegressionReader + override def read: Reader[LogisticRegressionModel] = new LogisticRegressionModelReader override def load(path: String): LogisticRegressionModel = read.load(path) -} + /** [[Writer]] instance for [[LogisticRegressionModel]] */ + private[classification] class LogisticRegressionModelWriter(instance: LogisticRegressionModel) + extends Writer with Logging { + + private case class Data( + numClasses: Int, + numFeatures: Int, + intercept: Double, + coefficients: Vector) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: numClasses, numFeatures, intercept, coefficients + val data = Data(instance.numClasses, instance.numFeatures, instance.intercept, + instance.coefficients) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).write.format("parquet").save(dataPath) + } + } -private[classification] class LogisticRegressionReader extends Reader[LogisticRegressionModel] { + private[classification] class LogisticRegressionModelReader + extends Reader[LogisticRegressionModel] { - /** Checked against metadata when loading model */ - private val className = "org.apache.spark.ml.classification.LogisticRegressionModel" + /** Checked against metadata when loading model */ + private val className = "org.apache.spark.ml.classification.LogisticRegressionModel" - override def load(path: String): LogisticRegressionModel = { - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + override def load(path: String): LogisticRegressionModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) - val dataPath = new Path(path, "data").toString - val data = sqlContext.read.format("parquet").load(dataPath) - .select("numClasses", "numFeatures", "intercept", "coefficients").head() - // We will need numClasses, numFeatures in the future for multinomial logreg support. - // val numClasses = data.getInt(0) - // val numFeatures = data.getInt(1) - val intercept = data.getDouble(2) - val coefficients = data.getAs[Vector](3) - val model = new LogisticRegressionModel(metadata.uid, coefficients, intercept) + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.format("parquet").load(dataPath) + .select("numClasses", "numFeatures", "intercept", "coefficients").head() + // We will need numClasses, numFeatures in the future for multinomial logreg support. 
+ // val numClasses = data.getInt(0) + // val numFeatures = data.getInt(1) + val intercept = data.getDouble(2) + val coefficients = data.getAs[Vector](3) + val model = new LogisticRegressionModel(metadata.uid, coefficients, intercept) - DefaultParamsReader.getAndSetParams(model, metadata) - model + DefaultParamsReader.getAndSetParams(model, metadata) + model + } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 3169c9e9af5be..dddb72af5ba78 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -217,6 +217,7 @@ private[ml] object DefaultParamsWriter { * (json4s-serializable) params and no data. This will not handle more complex params or types with * data (e.g., models with coefficients). * @tparam T ML instance type + * TODO: Consider adding check for correct class name. */ private[ml] class DefaultParamsReader[T] extends Reader[T] { diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 484026b1ba9ad..7f5c3895acb0c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -149,13 +149,6 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul assert(pipeline2.stages(0).isInstanceOf[WritableStage]) val writableStage2 = pipeline2.stages(0).asInstanceOf[WritableStage] assert(writableStage.getIntParam === writableStage2.getIntParam) - - val path = new File(tempDir, pipeline.uid).getPath - val stagesDir = new Path(path, "stages").toString - val expectedStagePath = SharedReadWrite.getStagePath(writableStage.uid, 0, 1, stagesDir) - assert(FileSystem.get(sc.hadoopConfiguration).exists(new Path(expectedStagePath)), - s"Expected stage 0 of 1 with uid ${writableStage.uid} in Pipeline with uid ${pipeline.uid}" + - s" to be saved to path: $expectedStagePath") } test("PipelineModel read/write: getStagePath") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala new file mode 100644 index 0000000000000..d0e3fe7ad14b6 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +object ClassifierSuite { + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. 
+ * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "predictionCol" -> "myPrediction", + "rawPredictionCol" -> "myRawPrediction" + ) + +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 51b06b7eb6d53..48ce1bb630685 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -873,15 +873,34 @@ class LogisticRegressionSuite } test("read/write") { - // Set some Params to make sure set Params are serialized. + def checkModelData(model: LogisticRegressionModel, model2: LogisticRegressionModel): Unit = { + assert(model.intercept === model2.intercept) + assert(model.coefficients.toArray === model2.coefficients.toArray) + assert(model.numClasses === model2.numClasses) + assert(model.numFeatures === model2.numFeatures) + } val lr = new LogisticRegression() - .setElasticNetParam(0.1) - .setMaxIter(2) - .fit(dataset) - val lr2 = testDefaultReadWrite(lr) - assert(lr.intercept === lr2.intercept) - assert(lr.coefficients.toArray === lr2.coefficients.toArray) - assert(lr.numClasses === lr2.numClasses) - assert(lr.numFeatures === lr2.numFeatures) + testEstimatorAndModelReadWrite(lr, dataset, LogisticRegressionSuite.allParamSettings, + checkModelData) } } + +object LogisticRegressionSuite { + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = ProbabilisticClassifierSuite.allParamSettings ++ Map( + "probabilityCol" -> "myProbability", + "thresholds" -> Array(0.4, 0.6), + "regParam" -> 0.01, + "elasticNetParam" -> 0.1, + "maxIter" -> 2, // intentionally small + "fitIntercept" -> false, + "tol" -> 0.8, + "standardization" -> false, + "threshold" -> 0.6 + ) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala index fb5f00e0646c6..cfa75ecf387cd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala @@ -57,3 +57,17 @@ class ProbabilisticClassifierSuite extends SparkFunSuite { assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 2.0))) === 1.0) } } + +object ProbabilisticClassifierSuite { + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. 
+ */ + val allParamSettings: Map[String, Any] = ClassifierSuite.allParamSettings ++ Map( + "probabilityCol" -> "myProbability", + "thresholds" -> Array(0.4, 0.6) + ) + +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index c37f0503f1332..dd1e8acce9418 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -22,13 +22,17 @@ import java.io.{File, IOException} import org.scalatest.Suite import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.{Model, Estimator} import org.apache.spark.ml.param._ import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.DataFrame trait DefaultReadWriteTest extends TempDirectory { self: Suite => /** * Checks "overwrite" option and params. + * This saves to and loads from [[tempDir]], but creates a subdirectory with a random name + * in order to avoid conflicts from multiple calls to this method. * @param instance ML instance to test saving/loading * @param testParams If true, then test values of Params. Otherwise, just test overwrite option. * @tparam T ML instance type @@ -38,7 +42,10 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => instance: T, testParams: Boolean = true): T = { val uid = instance.uid - val path = new File(tempDir, uid).getPath + val subdirName = Identifiable.randomUID("test") + + val subdir = new File(tempDir, subdirName) + val path = new File(subdir, uid).getPath instance.save(path) intercept[IOException] { @@ -69,6 +76,47 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => assert(another.uid === instance.uid) another } + + /** + * Default test for Estimator, Model pairs: + * - Explicitly set Params, and train model + * - Test save/load using [[testDefaultReadWrite()]] on Estimator and Model + * - Check Params on Estimator and Model + * + * This requires that the [[Estimator]] and [[Model]] share the same set of [[Param]]s. + * @param estimator Estimator to test + * @param dataset Dataset to pass to [[Estimator.fit()]] + * @param testParams Set of [[Param]] values to set in estimator + * @param checkModelData Method which takes the original and loaded [[Model]] and compares their + * data. This method does not need to check [[Param]] values. + * @tparam E Type of [[Estimator]] + * @tparam M Type of [[Model]] produced by estimator + */ + def testEstimatorAndModelReadWrite[E <: Estimator[M] with Writable, M <: Model[M] with Writable]( + estimator: E, + dataset: DataFrame, + testParams: Map[String, Any], + checkModelData: (M, M) => Unit): Unit = { + // Set some Params to make sure set Params are serialized. 
+ testParams.foreach { case (p, v) => + estimator.set(estimator.getParam(p), v) + } + val model = estimator.fit(dataset) + + // Test Estimator save/load + val estimator2 = testDefaultReadWrite(estimator) + testParams.foreach { case (p, v) => + val param = estimator.getParam(p) + assert(estimator.get(param).get === estimator2.get(param).get) + } + + // Test Model save/load + val model2 = testDefaultReadWrite(model) + testParams.foreach { case (p, v) => + val param = model.getParam(p) + assert(model.get(param).get === model2.get(param).get) + } + } } class MyParams(override val uid: String) extends Params with Writable { From 3e9e6380236985ec5b51b459f8c61f964a76ff8b Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 17 Nov 2015 14:04:49 -0800 Subject: [PATCH 069/149] [SPARK-11764][ML] make Param.jsonEncode/jsonDecode support Vector This PR makes the default read/write work with simple transformers/estimators that have params of type `Param[Vector]`. jkbradley Author: Xiangrui Meng Closes #9776 from mengxr/SPARK-11764. --- .../org/apache/spark/ml/param/params.scala | 12 ++++++++-- .../apache/spark/ml/param/ParamsSuite.scala | 22 +++++++++++++++---- 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index c9325709187c5..d182b0a98896c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -29,6 +29,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.ml.util.Identifiable +import org.apache.spark.mllib.linalg.{Vector, Vectors} /** * :: DeveloperApi :: @@ -88,9 +89,11 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali value match { case x: String => compact(render(JString(x))) + case v: Vector => + v.toJson case _ => throw new NotImplementedError( - "The default jsonEncode only supports string. " + + "The default jsonEncode only supports string and vector. " + s"${this.getClass.getName} must override jsonEncode for ${value.getClass.getName}.") } } @@ -100,9 +103,14 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali parse(json) match { case JString(x) => x.asInstanceOf[T] + case JObject(v) => + val keys = v.map(_._1) + assert(keys.contains("type") && keys.contains("values"), + s"Expect a JSON serialized vector but cannot find fields 'type' and 'values' in $json.") + Vectors.fromJson(json).asInstanceOf[T] case _ => throw new NotImplementedError( - "The default jsonDecode only supports string. " + + "The default jsonDecode only supports string and vector. 
" + s"${this.getClass.getName} must override jsonDecode to support its value type.") } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index eeb03dba2f825..a1878be747ceb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.param import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.{Vector, Vectors} class ParamsSuite extends SparkFunSuite { @@ -80,7 +81,7 @@ class ParamsSuite extends SparkFunSuite { } } - { // StringParam + { // Param[String] val param = new Param[String](dummy, "name", "doc") // Currently we do not support null. for (value <- Seq("", "1", "abc", "quote\"", "newline\n")) { @@ -89,6 +90,19 @@ class ParamsSuite extends SparkFunSuite { } } + { // Param[Vector] + val param = new Param[Vector](dummy, "name", "doc") + val values = Seq( + Vectors.dense(Array.empty[Double]), + Vectors.dense(0.0, 2.0), + Vectors.sparse(0, Array.empty, Array.empty), + Vectors.sparse(2, Array(1), Array(2.0))) + for (value <- values) { + val json = param.jsonEncode(value) + assert(param.jsonDecode(json) === value) + } + } + { // IntArrayParam val param = new IntArrayParam(dummy, "name", "doc") val values: Seq[Array[Int]] = Seq( @@ -138,7 +152,7 @@ class ParamsSuite extends SparkFunSuite { test("param") { val solver = new TestParams() val uid = solver.uid - import solver.{maxIter, inputCol} + import solver.{inputCol, maxIter} assert(maxIter.name === "maxIter") assert(maxIter.doc === "maximum number of iterations (>= 0)") @@ -181,7 +195,7 @@ class ParamsSuite extends SparkFunSuite { test("param map") { val solver = new TestParams() - import solver.{maxIter, inputCol} + import solver.{inputCol, maxIter} val map0 = ParamMap.empty @@ -220,7 +234,7 @@ class ParamsSuite extends SparkFunSuite { test("params") { val solver = new TestParams() - import solver.{handleInvalid, maxIter, inputCol} + import solver.{handleInvalid, inputCol, maxIter} val params = solver.params assert(params.length === 3) From 936bc0bcbf957fa1d7cb5cfe88d628c830df5981 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 17 Nov 2015 14:23:28 -0800 Subject: [PATCH 070/149] [SPARK-11786][CORE] Tone down messages from akka error monitor. There events happen normally during the app's lifecycle, so printing out ERROR logs all the time is misleading, and can actually affect usability of interactive shells. Author: Marcelo Vanzin Closes #9772 from vanzin/SPARK-11786. 
--- core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 3fad595a0d0b0..059a7e10ec12f 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -263,7 +263,7 @@ private[akka] class ErrorMonitor extends Actor with ActorLogReceive with Logging } override def receiveWithLogging: Actor.Receive = { - case Error(cause: Throwable, _, _, message: String) => logError(message, cause) + case Error(cause: Throwable, _, _, message: String) => logDebug(message, cause) } } From 928d631625297857fb6998fbeb0696917fbfd60f Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Tue, 17 Nov 2015 14:48:29 -0800 Subject: [PATCH 071/149] [SPARK-11740][STREAMING] Fix the race condition of two checkpoints in a batch We will do checkpoint when generating a batch and completing a batch. When the processing time of a batch is greater than the batch interval, checkpointing for completing an old batch may run after checkpointing for generating a new batch. If this happens, checkpoint of an old batch actually has the latest information, so we want to recovery from it. This PR will use the latest checkpoint time as the file name, so that we can always recovery from the latest checkpoint file. Author: Shixiong Zhu Closes #9707 from zsxwing/fix-checkpoint. --- python/pyspark/streaming/tests.py | 9 +++---- .../apache/spark/streaming/Checkpoint.scala | 18 +++++++++++-- .../spark/streaming/CheckpointSuite.scala | 27 +++++++++++++++++-- 3 files changed, 45 insertions(+), 9 deletions(-) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 2983028413bb8..ff95639146e59 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -753,7 +753,6 @@ def tearDown(self): if self.cpd is not None: shutil.rmtree(self.cpd) - @unittest.skip("Enable it when we fix the checkpoint bug") def test_get_or_create_and_get_active_or_create(self): inputd = tempfile.mkdtemp() outputd = tempfile.mkdtemp() + "/" @@ -822,11 +821,11 @@ def check_output(n): # Verify that getOrCreate() uses existing SparkContext self.ssc.stop(True, True) time.sleep(1) - sc = SparkContext(SparkConf()) + self.sc = SparkContext(conf=SparkConf()) self.setupCalled = False self.ssc = StreamingContext.getOrCreate(self.cpd, setup) self.assertFalse(self.setupCalled) - self.assertTrue(self.ssc.sparkContext == sc) + self.assertTrue(self.ssc.sparkContext == self.sc) # Verify the getActiveOrCreate() recovers from checkpoint files self.ssc.stop(True, True) @@ -845,11 +844,11 @@ def check_output(n): # Verify that getActiveOrCreate() uses existing SparkContext self.ssc.stop(True, True) time.sleep(1) - self.sc = SparkContext(SparkConf()) + self.sc = SparkContext(conf=SparkConf()) self.setupCalled = False self.ssc = StreamingContext.getActiveOrCreate(self.cpd, setup) self.assertFalse(self.setupCalled) - self.assertTrue(self.ssc.sparkContext == sc) + self.assertTrue(self.ssc.sparkContext == self.sc) # Verify that getActiveOrCreate() calls setup() in absence of checkpoint files self.ssc.stop(True, True) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 0cd55d9aec2cd..fd0e8d5d690b6 100644 --- 
a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -187,16 +187,30 @@ class CheckpointWriter( private var stopped = false private var fs_ : FileSystem = _ + @volatile private var latestCheckpointTime: Time = null + class CheckpointWriteHandler( checkpointTime: Time, bytes: Array[Byte], clearCheckpointDataLater: Boolean) extends Runnable { def run() { + if (latestCheckpointTime == null || latestCheckpointTime < checkpointTime) { + latestCheckpointTime = checkpointTime + } var attempts = 0 val startTime = System.currentTimeMillis() val tempFile = new Path(checkpointDir, "temp") - val checkpointFile = Checkpoint.checkpointFile(checkpointDir, checkpointTime) - val backupFile = Checkpoint.checkpointBackupFile(checkpointDir, checkpointTime) + // We will do checkpoint when generating a batch and completing a batch. When the processing + // time of a batch is greater than the batch interval, checkpointing for completing an old + // batch may run after checkpointing of a new batch. If this happens, checkpoint of an old + // batch actually has the latest information, so we want to recovery from it. Therefore, we + // also use the latest checkpoint time as the file name, so that we can recovery from the + // latest checkpoint file. + // + // Note: there is only one thread writting the checkpoint files, so we don't need to worry + // about thread-safety. + val checkpointFile = Checkpoint.checkpointFile(checkpointDir, latestCheckpointTime) + val backupFile = Checkpoint.checkpointBackupFile(checkpointDir, latestCheckpointTime) while (attempts < MAX_ATTEMPTS && !stopped) { attempts += 1 diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 84f5294aa39cc..b1cbc7163bee3 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.streaming import java.io.{ObjectOutputStream, ByteArrayOutputStream, ByteArrayInputStream, File} -import org.apache.spark.TestUtils import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.reflect.ClassTag @@ -30,11 +29,13 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.io.{IntWritable, Text} import org.apache.hadoop.mapred.TextOutputFormat import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} +import org.mockito.Mockito.mock import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ +import org.apache.spark.TestUtils import org.apache.spark.streaming.dstream.{DStream, FileInputDStream} -import org.apache.spark.streaming.scheduler.{ConstantEstimator, RateTestInputDStream, RateTestReceiver} +import org.apache.spark.streaming.scheduler._ import org.apache.spark.util.{MutableURLClassLoader, Clock, ManualClock, Utils} /** @@ -611,6 +612,28 @@ class CheckpointSuite extends TestSuiteBase { assert(ois.readObject().asInstanceOf[Class[_]].getName == "[LtestClz;") } + test("SPARK-11267: the race condition of two checkpoints in a batch") { + val jobGenerator = mock(classOf[JobGenerator]) + val checkpointDir = Utils.createTempDir().toString + val checkpointWriter = + new CheckpointWriter(jobGenerator, conf, checkpointDir, new Configuration()) + val bytes1 = Array.fill[Byte](10)(1) + new checkpointWriter.CheckpointWriteHandler( + 
Time(2000), bytes1, clearCheckpointDataLater = false).run()
+ val bytes2 = Array.fill[Byte](10)(2)
+ new checkpointWriter.CheckpointWriteHandler(
+ Time(1000), bytes2, clearCheckpointDataLater = true).run()
+ val checkpointFiles = Checkpoint.getCheckpointFiles(checkpointDir).reverse.map { path =>
+ new File(path.toUri)
+ }
+ assert(checkpointFiles.size === 2)
+ // Although bytes2 was written with an old time, it contains the latest status, so we should
+ // try to read from it at first.
+ assert(Files.toByteArray(checkpointFiles(0)) === bytes2)
+ assert(Files.toByteArray(checkpointFiles(1)) === bytes1)
+ checkpointWriter.stop()
+ }
+
 /**
 * Tests a streaming operation under checkpointing, by restarting the operation
 * from checkpoint file and verifying whether the final output is correct.

From 965245d087c18edc6c3d5baddeaf83163e32e330 Mon Sep 17 00:00:00 2001
From: Grace
Date: Tue, 17 Nov 2015 15:43:35 -0800
Subject: [PATCH 072/149] [SPARK-9552] Add force control for killExecutors to avoid false killing for those busy executors

With dynamic allocation enabled, busy executors are sometimes killed by mistake. An executor that already has task assignments can be killed for having been idle long enough (say 60 seconds), because the task-launch listener event is asynchronous: an executor may be in the middle of being assigned tasks, with the listener notification not yet sent, when the dynamic allocation idle timeout (e.g., 60 seconds) expires and triggers a killExecutor call.

1. The timer expires before the listener event arrives.
2. The task then runs on top of the killed/killing executor, which finally leads to a task failure.

The proposed fix is to add a force flag to killExecutor. If force is not set (i.e., false), we first check whether the executor being killed is idle or busy; if it still has assignments, we do not kill it and return false to indicate that the kill failed. Dynamic allocation should turn force killing off (force = false): an attempt to kill a busy executor then fails, so the executor's idle timer stays valid, and once the task assignment event arrives the idle timer is removed accordingly. This avoids falsely killing busy executors under dynamic allocation. For all other usages, end users can decide for themselves whether to force the kill; with that option turned on, killExecutor acts without any status checking.

Author: Grace
Author: Andrew Or
Author: Jie Huang

Closes #7888 from GraceH/forcekill.
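To make the proposal concrete before the diff, a minimal sketch of the selection logic (the standalone helper and its names are illustrative only; the real change lives in CoarseGrainedSchedulerBackend.killExecutors, shown below):

object KillFilterSketch {
  // Decide which of the requested executors should actually be killed.
  def selectExecutorsToKill(
      requested: Seq[String],
      pendingToRemove: Set[String],
      isBusy: String => Boolean,
      force: Boolean): Seq[String] = {
    requested
      .filter(id => !pendingToRemove.contains(id)) // skip executors already being removed (SPARK-9795)
      .filter(id => force || !isBusy(id))          // spare busy executors unless the caller forces the kill (SPARK-9552)
  }
}

Dynamic allocation keeps force = false, so a kill request for a busy executor simply fails and its idle timer stays valid, while the explicit SparkContext.killExecutors path passes force = true, as the diff below shows.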
--- .../spark/ExecutorAllocationManager.scala | 1 + .../scala/org/apache/spark/SparkContext.scala | 4 +- .../spark/scheduler/TaskSchedulerImpl.scala | 27 +++++++--- .../CoarseGrainedSchedulerBackend.scala | 13 +++-- .../StandaloneDynamicAllocationSuite.scala | 52 ++++++++++++++++++- 5 files changed, 82 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index b93536e6536e2..6419218f47c85 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -509,6 +509,7 @@ private[spark] class ExecutorAllocationManager( private def onExecutorBusy(executorId: String): Unit = synchronized { logDebug(s"Clearing idle timer for $executorId because it is now running a task") removeTimes.remove(executorId) + executorsPendingToRemove.remove(executorId) } /** diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 4bbd0b038c00f..b5645b08f92d4 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1461,7 +1461,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli override def killExecutors(executorIds: Seq[String]): Boolean = { schedulerBackend match { case b: CoarseGrainedSchedulerBackend => - b.killExecutors(executorIds) + b.killExecutors(executorIds, replace = false, force = true) case _ => logWarning("Killing executors is only supported in coarse-grained mode") false @@ -1499,7 +1499,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli private[spark] def killAndReplaceExecutor(executorId: String): Boolean = { schedulerBackend match { case b: CoarseGrainedSchedulerBackend => - b.killExecutors(Seq(executorId), replace = true) + b.killExecutors(Seq(executorId), replace = true, force = true) case _ => logWarning("Killing executors is only supported in coarse-grained mode") false diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 5f136690f456c..bf0419db1f75e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -87,8 +87,8 @@ private[spark] class TaskSchedulerImpl( // Incrementing task IDs val nextTaskId = new AtomicLong(0) - // Which executor IDs we have executors on - val activeExecutorIds = new HashSet[String] + // Number of tasks running on each executor + private val executorIdToTaskCount = new HashMap[String, Int] // The set of executors we have on each host; this is used to compute hostsAlive, which // in turn is used to decide when we can attain data locality on a given host @@ -254,6 +254,7 @@ private[spark] class TaskSchedulerImpl( val tid = task.taskId taskIdToTaskSetManager(tid) = taskSet taskIdToExecutorId(tid) = execId + executorIdToTaskCount(execId) += 1 executorsByHost(host) += execId availableCpus(i) -= CPUS_PER_TASK assert(availableCpus(i) >= 0) @@ -282,7 +283,7 @@ private[spark] class TaskSchedulerImpl( var newExecAvail = false for (o <- offers) { executorIdToHost(o.executorId) = o.host - activeExecutorIds += o.executorId + executorIdToTaskCount.getOrElseUpdate(o.executorId, 0) if (!executorsByHost.contains(o.host)) { 
executorsByHost(o.host) = new HashSet[String]() executorAdded(o.executorId, o.host) @@ -331,7 +332,8 @@ private[spark] class TaskSchedulerImpl( if (state == TaskState.LOST && taskIdToExecutorId.contains(tid)) { // We lost this entire executor, so remember that it's gone val execId = taskIdToExecutorId(tid) - if (activeExecutorIds.contains(execId)) { + + if (executorIdToTaskCount.contains(execId)) { removeExecutor(execId, SlaveLost(s"Task $tid was lost, so marking the executor as lost as well.")) failedExecutor = Some(execId) @@ -341,7 +343,11 @@ private[spark] class TaskSchedulerImpl( case Some(taskSet) => if (TaskState.isFinished(state)) { taskIdToTaskSetManager.remove(tid) - taskIdToExecutorId.remove(tid) + taskIdToExecutorId.remove(tid).foreach { execId => + if (executorIdToTaskCount.contains(execId)) { + executorIdToTaskCount(execId) -= 1 + } + } } if (state == TaskState.FINISHED) { taskSet.removeRunningTask(tid) @@ -462,7 +468,7 @@ private[spark] class TaskSchedulerImpl( var failedExecutor: Option[String] = None synchronized { - if (activeExecutorIds.contains(executorId)) { + if (executorIdToTaskCount.contains(executorId)) { val hostPort = executorIdToHost(executorId) logError("Lost executor %s on %s: %s".format(executorId, hostPort, reason)) removeExecutor(executorId, reason) @@ -498,7 +504,8 @@ private[spark] class TaskSchedulerImpl( * of any running tasks, since the loss reason defines whether we'll fail those tasks. */ private def removeExecutor(executorId: String, reason: ExecutorLossReason) { - activeExecutorIds -= executorId + executorIdToTaskCount -= executorId + val host = executorIdToHost(executorId) val execs = executorsByHost.getOrElse(host, new HashSet) execs -= executorId @@ -535,7 +542,11 @@ private[spark] class TaskSchedulerImpl( } def isExecutorAlive(execId: String): Boolean = synchronized { - activeExecutorIds.contains(execId) + executorIdToTaskCount.contains(execId) + } + + def isExecutorBusy(execId: String): Boolean = synchronized { + executorIdToTaskCount.getOrElse(execId, -1) > 0 } // By default, rack is unknown diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 3373caf0d15eb..6f0c910c009a5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -453,7 +453,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp * @return whether the kill request is acknowledged. */ final override def killExecutors(executorIds: Seq[String]): Boolean = synchronized { - killExecutors(executorIds, replace = false) + killExecutors(executorIds, replace = false, force = false) } /** @@ -461,9 +461,13 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp * * @param executorIds identifiers of executors to kill * @param replace whether to replace the killed executors with new ones + * @param force whether to force kill busy executors * @return whether the kill request is acknowledged. 
*/ - final def killExecutors(executorIds: Seq[String], replace: Boolean): Boolean = synchronized { + final def killExecutors( + executorIds: Seq[String], + replace: Boolean, + force: Boolean): Boolean = synchronized { logInfo(s"Requesting to kill executor(s) ${executorIds.mkString(", ")}") val (knownExecutors, unknownExecutors) = executorIds.partition(executorDataMap.contains) unknownExecutors.foreach { id => @@ -471,7 +475,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } // If an executor is already pending to be removed, do not kill it again (SPARK-9795) - val executorsToKill = knownExecutors.filter { id => !executorsPendingToRemove.contains(id) } + // If this executor is busy, do not kill it unless we are told to force kill it (SPARK-9552) + val executorsToKill = knownExecutors + .filter { id => !executorsPendingToRemove.contains(id) } + .filter { id => force || !scheduler.isExecutorBusy(id) } executorsPendingToRemove ++= executorsToKill // If we do not wish to replace the executors we kill, sync the target number of executors diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala index d145e78834b1b..2fa795f846667 100644 --- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.deploy +import scala.collection.mutable import scala.concurrent.duration._ import org.mockito.Mockito.{mock, when} -import org.scalatest.BeforeAndAfterAll +import org.scalatest.{BeforeAndAfterAll, PrivateMethodTester} import org.scalatest.concurrent.Eventually._ import org.apache.spark._ @@ -29,6 +30,7 @@ import org.apache.spark.deploy.master.ApplicationInfo import org.apache.spark.deploy.master.Master import org.apache.spark.deploy.worker.Worker import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv} +import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.scheduler.cluster._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RegisterExecutor @@ -38,7 +40,8 @@ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RegisterE class StandaloneDynamicAllocationSuite extends SparkFunSuite with LocalSparkContext - with BeforeAndAfterAll { + with BeforeAndAfterAll + with PrivateMethodTester { private val numWorkers = 2 private val conf = new SparkConf() @@ -404,6 +407,41 @@ class StandaloneDynamicAllocationSuite assert(apps.head.getExecutorLimit === 1) } + test("disable force kill for busy executors (SPARK-9552)") { + sc = new SparkContext(appConf) + val appId = sc.applicationId + eventually(timeout(10.seconds), interval(10.millis)) { + val apps = getApplications() + assert(apps.size === 1) + assert(apps.head.id === appId) + assert(apps.head.executors.size === 2) + assert(apps.head.getExecutorLimit === Int.MaxValue) + } + var apps = getApplications() + // sync executors between the Master and the driver, needed because + // the driver refuses to kill executors it does not know about + syncExecutors(sc) + val executors = getExecutorIds(sc) + assert(executors.size === 2) + + // simulate running a task on the executor + val getMap = PrivateMethod[mutable.HashMap[String, Int]]('executorIdToTaskCount) + val taskScheduler = sc.taskScheduler.asInstanceOf[TaskSchedulerImpl] + val executorIdToTaskCount = taskScheduler invokePrivate 
getMap() + executorIdToTaskCount(executors.head) = 1 + // kill the busy executor without force; this should fail + assert(killExecutor(sc, executors.head, force = false)) + apps = getApplications() + assert(apps.head.executors.size === 2) + + // force kill busy executor + assert(killExecutor(sc, executors.head, force = true)) + apps = getApplications() + // kill executor successfully + assert(apps.head.executors.size === 1) + + } + // =============================== // | Utility methods for testing | // =============================== @@ -455,6 +493,16 @@ class StandaloneDynamicAllocationSuite sc.killExecutors(getExecutorIds(sc).take(n)) } + /** Kill the given executor, specifying whether to force kill it. */ + private def killExecutor(sc: SparkContext, executorId: String, force: Boolean): Boolean = { + syncExecutors(sc) + sc.schedulerBackend match { + case b: CoarseGrainedSchedulerBackend => + b.killExecutors(Seq(executorId), replace = false, force) + case _ => fail("expected coarse grained scheduler") + } + } + /** * Return a list of executor IDs belonging to this application. * From e29656f8e7fa19686b448292e20d8bbf07ab9f11 Mon Sep 17 00:00:00 2001 From: Rohan Bhanderi Date: Tue, 17 Nov 2015 15:45:39 -0800 Subject: [PATCH 073/149] [MINOR] Correct comments in JavaDirectKafkaWordCount Author: Rohan Bhanderi Closes #9781 from RohanBhanderi/patch-3. --- .../examples/streaming/JavaDirectKafkaWordCount.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java index bab9f2478e779..f9a5e7f69ffe1 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java @@ -35,12 +35,12 @@ /** * Consumes messages from one or more topics in Kafka and does wordcount. - * Usage: DirectKafkaWordCount + * Usage: JavaDirectKafkaWordCount * is a list of one or more Kafka brokers * is a list of one or more kafka topics to consume from * * Example: - * $ bin/run-example streaming.KafkaWordCount broker1-host:port,broker2-host:port topic1,topic2 + * $ bin/run-example streaming.JavaDirectKafkaWordCount broker1-host:port,broker2-host:port topic1,topic2 */ public final class JavaDirectKafkaWordCount { @@ -48,7 +48,7 @@ public final class JavaDirectKafkaWordCount { public static void main(String[] args) { if (args.length < 2) { - System.err.println("Usage: DirectKafkaWordCount \n" + + System.err.println("Usage: JavaDirectKafkaWordCount \n" + " is a list of one or more Kafka brokers\n" + " is a list of one or more kafka topics to consume from\n\n"); System.exit(1); @@ -59,7 +59,7 @@ public static void main(String[] args) { String brokers = args[0]; String topics = args[1]; - // Create context with 2 second batch interval + // Create context with a 2 seconds batch interval SparkConf sparkConf = new SparkConf().setAppName("JavaDirectKafkaWordCount"); JavaStreamingContext jssc = new JavaStreamingContext(sparkConf, Durations.seconds(2)); From 3720b1480c7d050ca20f98d65762224ae5639607 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Tue, 17 Nov 2015 15:47:39 -0800 Subject: [PATCH 074/149] [SPARK-11790][STREAMING][TESTS] Increase the connection timeout Sometimes, EmbeddedZookeeper may need more than 6 seconds to setup up in a slow Jenkins worker. 
So just increase the timeout; it won't increase the test time if the test passes.

Author: Shixiong Zhu

Closes #9778 from zsxwing/SPARK-11790.
---
 .../scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala
index c9fd715d3d554..86394ea8a685e 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala
@@ -52,7 +52,7 @@ private[kafka] class KafkaTestUtils extends Logging {
 // Zookeeper related configurations
 private val zkHost = "localhost"
 private var zkPort: Int = 0
- private val zkConnectionTimeout = 6000
+ private val zkConnectionTimeout = 60000
 private val zkSessionTimeout = 6000
 private var zookeeper: EmbeddedZookeeper = _

From 52c734b589277267be07e245c959199db92aa189 Mon Sep 17 00:00:00 2001
From: Holden Karau
Date: Tue, 17 Nov 2015 15:51:03 -0800
Subject: [PATCH 075/149] [SPARK-11771][YARN][TRIVIAL] maximum memory in yarn is controlled by two params have both in error msg

When we exceed the max memory, tell users to increase both params instead of just one.

Author: Holden Karau

Closes #9758 from holdenk/SPARK-11771-maximum-memory-in-yarn-is-controlled-by-two-params-have-both-in-error-msg.
---
 yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index a3f33d80184a3..ba799884f5689 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -258,7 +258,8 @@ private[spark] class Client(
 if (executorMem > maxMem) {
 throw new IllegalArgumentException(s"Required executor memory (${args.executorMemory}" +
 s"+$executorMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster! " +
- "Please increase the value of 'yarn.scheduler.maximum-allocation-mb'.")
+ "Please check the values of 'yarn.scheduler.maximum-allocation-mb' and/or " +
+ "'yarn.nodemanager.resource.memory-mb'.")
 }
 val amMem = args.amMemory + amMemoryOverhead
 if (amMem > maxMem) {

From b362d50fca30693f97bd859984157bb8a76d48a1 Mon Sep 17 00:00:00 2001
From: Jacek Lewandowski
Date: Tue, 17 Nov 2015 15:57:43 -0800
Subject: [PATCH 076/149] [SPARK-11726] Throw exception on timeout when waiting for REST server response

Author: Jacek Lewandowski

Closes #9692 from jacek-lewandowski/SPARK-11726.
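The shape of the fix, as a minimal hedged sketch (the helper object and the generic RuntimeException are stand-ins; the real change wraps the HttpURLConnection handling inside RestSubmissionClient.readResponse and rethrows a SubmitRestConnectionException, as the diff below shows):

import java.util.concurrent.TimeoutException

import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._

object ResponseWithTimeout {
  // Run a blocking read off the calling thread and give up after 10 seconds instead of hanging forever.
  def readWithTimeout[T](blockingRead: => T): T = {
    val responseFuture = Future(blockingRead)
    try {
      Await.result(responseFuture, 10.seconds)
    } catch {
      case timeout: TimeoutException =>
        // The patch rethrows this as SubmitRestConnectionException("No response from server", timeout).
        throw new RuntimeException("No response from server", timeout)
    }
  }
}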
--- .../spark/deploy/rest/RestSubmissionClient.scala | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala index 957a928bc402b..f0dd667ea1b26 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala @@ -19,16 +19,19 @@ package org.apache.spark.deploy.rest import java.io.{DataOutputStream, FileNotFoundException} import java.net.{ConnectException, HttpURLConnection, SocketException, URL} +import java.util.concurrent.TimeoutException import javax.servlet.http.HttpServletResponse import scala.collection.mutable +import scala.concurrent.duration._ +import scala.concurrent.{Await, Future} import scala.io.Source import com.fasterxml.jackson.core.JsonProcessingException import com.google.common.base.Charsets -import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion} import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SPARK_VERSION => sparkVersion, SparkConf} /** * A client that submits applications to a [[RestSubmissionServer]]. @@ -225,7 +228,8 @@ private[spark] class RestSubmissionClient(master: String) extends Logging { * Exposed for testing. */ private[rest] def readResponse(connection: HttpURLConnection): SubmitRestProtocolResponse = { - try { + import scala.concurrent.ExecutionContext.Implicits.global + val responseFuture = Future { val dataStream = if (connection.getResponseCode == HttpServletResponse.SC_OK) { connection.getInputStream @@ -251,11 +255,15 @@ private[spark] class RestSubmissionClient(master: String) extends Logging { throw new SubmitRestProtocolException( s"Message received from server was not a response:\n${unexpected.toJson}") } - } catch { + } + + try { Await.result(responseFuture, 10.seconds) } catch { case unreachable @ (_: FileNotFoundException | _: SocketException) => throw new SubmitRestConnectionException("Unable to connect to server", unreachable) case malformed @ (_: JsonProcessingException | _: SubmitRestProtocolException) => throw new SubmitRestProtocolException("Malformed response received from server", malformed) + case timeout: TimeoutException => + throw new SubmitRestConnectionException("No response from server", timeout) } } From 75a292291062783129d02607302f91c85655975e Mon Sep 17 00:00:00 2001 From: jerryshao Date: Tue, 17 Nov 2015 16:57:52 -0800 Subject: [PATCH 077/149] [SPARK-9065][STREAMING][PYSPARK] Add MessageHandler for Kafka Python API Fixed the merge conflicts in #7410 Closes #7410 Author: Shixiong Zhu Author: jerryshao Author: jerryshao Closes #9742 from zsxwing/pr7410. 
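For reference, the Scala API already exposes the message-handler hook that this patch now surfaces in Python, and the new Python helper ultimately delegates to it. A hedged Scala sketch of that usage (the master, broker address, topic and starting offsets are placeholders, not taken from the patch):

import kafka.common.TopicAndPartition
import kafka.message.MessageAndMetadata
import kafka.serializer.StringDecoder

import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.kafka.KafkaUtils

object DirectStreamWithMetadata {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("demo").setMaster("local[2]") // placeholder master
    val ssc = new StreamingContext(conf, Seconds(2))
    val kafkaParams = Map("metadata.broker.list" -> "localhost:9092")   // placeholder broker
    val fromOffsets = Map(TopicAndPartition("demo-topic", 0) -> 0L)     // placeholder topic/offset
    // The handler sees topic, partition and offset in addition to the key/value payload.
    val handler = (mmd: MessageAndMetadata[String, String]) =>
      (mmd.topic, mmd.partition, mmd.offset, mmd.message())
    val stream = KafkaUtils.createDirectStream[
        String, String, StringDecoder, StringDecoder, (String, Int, Long, String)](
      ssc, kafkaParams, fromOffsets, handler)
    stream.print()
    ssc.start()
    ssc.awaitTermination()
  }
}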
--- .../spark/streaming/kafka/KafkaUtils.scala | 245 ++++++++++++------ project/MimaExcludes.scala | 6 + python/pyspark/streaming/kafka.py | 111 +++++++- python/pyspark/streaming/tests.py | 35 +++ 4 files changed, 299 insertions(+), 98 deletions(-) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index 3128222077537..ad2fb8aa5f24c 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -17,25 +17,29 @@ package org.apache.spark.streaming.kafka +import java.io.OutputStream import java.lang.{Integer => JInt, Long => JLong} import java.util.{List => JList, Map => JMap, Set => JSet} import scala.collection.JavaConverters._ import scala.reflect.ClassTag +import com.google.common.base.Charsets.UTF_8 import kafka.common.TopicAndPartition import kafka.message.MessageAndMetadata -import kafka.serializer.{Decoder, DefaultDecoder, StringDecoder} +import kafka.serializer.{DefaultDecoder, Decoder, StringDecoder} +import net.razorvine.pickle.{Opcodes, Pickler, IObjectPickler} import org.apache.spark.api.java.function.{Function => JFunction} -import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} +import org.apache.spark.streaming.util.WriteAheadLogUtils +import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} +import org.apache.spark.api.python.SerDeUtil import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.api.java.{JavaInputDStream, JavaPairInputDStream, JavaPairReceiverInputDStream, JavaStreamingContext} -import org.apache.spark.streaming.dstream.{InputDStream, ReceiverInputDStream} -import org.apache.spark.streaming.util.WriteAheadLogUtils -import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.streaming.api.java._ +import org.apache.spark.streaming.dstream.{DStream, InputDStream, ReceiverInputDStream} object KafkaUtils { /** @@ -184,6 +188,27 @@ object KafkaUtils { } } + private[kafka] def getFromOffsets( + kc: KafkaCluster, + kafkaParams: Map[String, String], + topics: Set[String] + ): Map[TopicAndPartition, Long] = { + val reset = kafkaParams.get("auto.offset.reset").map(_.toLowerCase) + val result = for { + topicPartitions <- kc.getPartitions(topics).right + leaderOffsets <- (if (reset == Some("smallest")) { + kc.getEarliestLeaderOffsets(topicPartitions) + } else { + kc.getLatestLeaderOffsets(topicPartitions) + }).right + } yield { + leaderOffsets.map { case (tp, lo) => + (tp, lo.offset) + } + } + KafkaCluster.checkErrors(result) + } + /** * Create a RDD from Kafka using offset ranges for each topic and partition. 
* @@ -246,7 +271,7 @@ object KafkaUtils { // This could be avoided by refactoring KafkaRDD.leaders and KafkaCluster to use Broker leaders.map { case (tp: TopicAndPartition, Broker(host, port)) => (tp, (host, port)) - }.toMap + } } val cleanedHandler = sc.clean(messageHandler) checkOffsets(kc, offsetRanges) @@ -406,23 +431,9 @@ object KafkaUtils { ): InputDStream[(K, V)] = { val messageHandler = (mmd: MessageAndMetadata[K, V]) => (mmd.key, mmd.message) val kc = new KafkaCluster(kafkaParams) - val reset = kafkaParams.get("auto.offset.reset").map(_.toLowerCase) - - val result = for { - topicPartitions <- kc.getPartitions(topics).right - leaderOffsets <- (if (reset == Some("smallest")) { - kc.getEarliestLeaderOffsets(topicPartitions) - } else { - kc.getLatestLeaderOffsets(topicPartitions) - }).right - } yield { - val fromOffsets = leaderOffsets.map { case (tp, lo) => - (tp, lo.offset) - } - new DirectKafkaInputDStream[K, V, KD, VD, (K, V)]( - ssc, kafkaParams, fromOffsets, messageHandler) - } - KafkaCluster.checkErrors(result) + val fromOffsets = getFromOffsets(kc, kafkaParams, topics) + new DirectKafkaInputDStream[K, V, KD, VD, (K, V)]( + ssc, kafkaParams, fromOffsets, messageHandler) } /** @@ -550,6 +561,8 @@ object KafkaUtils { * takes care of known parameters instead of passing them from Python */ private[kafka] class KafkaUtilsPythonHelper { + import KafkaUtilsPythonHelper._ + def createStream( jssc: JavaStreamingContext, kafkaParams: JMap[String, String], @@ -566,86 +579,92 @@ private[kafka] class KafkaUtilsPythonHelper { storageLevel) } - def createRDD( + def createRDDWithoutMessageHandler( jsc: JavaSparkContext, kafkaParams: JMap[String, String], offsetRanges: JList[OffsetRange], - leaders: JMap[TopicAndPartition, Broker]): JavaPairRDD[Array[Byte], Array[Byte]] = { - val messageHandler = new JFunction[MessageAndMetadata[Array[Byte], Array[Byte]], - (Array[Byte], Array[Byte])] { - def call(t1: MessageAndMetadata[Array[Byte], Array[Byte]]): (Array[Byte], Array[Byte]) = - (t1.key(), t1.message()) - } + leaders: JMap[TopicAndPartition, Broker]): JavaRDD[(Array[Byte], Array[Byte])] = { + val messageHandler = + (mmd: MessageAndMetadata[Array[Byte], Array[Byte]]) => (mmd.key, mmd.message) + new JavaRDD(createRDD(jsc, kafkaParams, offsetRanges, leaders, messageHandler)) + } - val jrdd = KafkaUtils.createRDD[ - Array[Byte], - Array[Byte], - DefaultDecoder, - DefaultDecoder, - (Array[Byte], Array[Byte])]( - jsc, - classOf[Array[Byte]], - classOf[Array[Byte]], - classOf[DefaultDecoder], - classOf[DefaultDecoder], - classOf[(Array[Byte], Array[Byte])], - kafkaParams, - offsetRanges.toArray(new Array[OffsetRange](offsetRanges.size())), - leaders, - messageHandler - ) - new JavaPairRDD(jrdd.rdd) + def createRDDWithMessageHandler( + jsc: JavaSparkContext, + kafkaParams: JMap[String, String], + offsetRanges: JList[OffsetRange], + leaders: JMap[TopicAndPartition, Broker]): JavaRDD[Array[Byte]] = { + val messageHandler = (mmd: MessageAndMetadata[Array[Byte], Array[Byte]]) => + new PythonMessageAndMetadata( + mmd.topic, mmd.partition, mmd.offset, mmd.key(), mmd.message()) + val rdd = createRDD(jsc, kafkaParams, offsetRanges, leaders, messageHandler). 
+ mapPartitions(picklerIterator) + new JavaRDD(rdd) } - def createDirectStream( + private def createRDD[V: ClassTag]( + jsc: JavaSparkContext, + kafkaParams: JMap[String, String], + offsetRanges: JList[OffsetRange], + leaders: JMap[TopicAndPartition, Broker], + messageHandler: MessageAndMetadata[Array[Byte], Array[Byte]] => V): RDD[V] = { + KafkaUtils.createRDD[Array[Byte], Array[Byte], DefaultDecoder, DefaultDecoder, V]( + jsc.sc, + kafkaParams.asScala.toMap, + offsetRanges.toArray(new Array[OffsetRange](offsetRanges.size())), + leaders.asScala.toMap, + messageHandler + ) + } + + def createDirectStreamWithoutMessageHandler( + jssc: JavaStreamingContext, + kafkaParams: JMap[String, String], + topics: JSet[String], + fromOffsets: JMap[TopicAndPartition, JLong]): JavaDStream[(Array[Byte], Array[Byte])] = { + val messageHandler = + (mmd: MessageAndMetadata[Array[Byte], Array[Byte]]) => (mmd.key, mmd.message) + new JavaDStream(createDirectStream(jssc, kafkaParams, topics, fromOffsets, messageHandler)) + } + + def createDirectStreamWithMessageHandler( jssc: JavaStreamingContext, kafkaParams: JMap[String, String], topics: JSet[String], - fromOffsets: JMap[TopicAndPartition, JLong] - ): JavaPairInputDStream[Array[Byte], Array[Byte]] = { + fromOffsets: JMap[TopicAndPartition, JLong]): JavaDStream[Array[Byte]] = { + val messageHandler = (mmd: MessageAndMetadata[Array[Byte], Array[Byte]]) => + new PythonMessageAndMetadata(mmd.topic, mmd.partition, mmd.offset, mmd.key(), mmd.message()) + val stream = createDirectStream(jssc, kafkaParams, topics, fromOffsets, messageHandler). + mapPartitions(picklerIterator) + new JavaDStream(stream) + } - if (!fromOffsets.isEmpty) { + private def createDirectStream[V: ClassTag]( + jssc: JavaStreamingContext, + kafkaParams: JMap[String, String], + topics: JSet[String], + fromOffsets: JMap[TopicAndPartition, JLong], + messageHandler: MessageAndMetadata[Array[Byte], Array[Byte]] => V): DStream[V] = { + + val currentFromOffsets = if (!fromOffsets.isEmpty) { val topicsFromOffsets = fromOffsets.keySet().asScala.map(_.topic) if (topicsFromOffsets != topics.asScala.toSet) { throw new IllegalStateException( s"The specified topics: ${topics.asScala.toSet.mkString(" ")} " + s"do not equal to the topic from offsets: ${topicsFromOffsets.mkString(" ")}") } - } - - if (fromOffsets.isEmpty) { - KafkaUtils.createDirectStream[Array[Byte], Array[Byte], DefaultDecoder, DefaultDecoder]( - jssc, - classOf[Array[Byte]], - classOf[Array[Byte]], - classOf[DefaultDecoder], - classOf[DefaultDecoder], - kafkaParams, - topics) + Map(fromOffsets.asScala.mapValues { _.longValue() }.toSeq: _*) } else { - val messageHandler = new JFunction[MessageAndMetadata[Array[Byte], Array[Byte]], - (Array[Byte], Array[Byte])] { - def call(t1: MessageAndMetadata[Array[Byte], Array[Byte]]): (Array[Byte], Array[Byte]) = - (t1.key(), t1.message()) - } - - val jstream = KafkaUtils.createDirectStream[ - Array[Byte], - Array[Byte], - DefaultDecoder, - DefaultDecoder, - (Array[Byte], Array[Byte])]( - jssc, - classOf[Array[Byte]], - classOf[Array[Byte]], - classOf[DefaultDecoder], - classOf[DefaultDecoder], - classOf[(Array[Byte], Array[Byte])], - kafkaParams, - fromOffsets, - messageHandler) - new JavaPairInputDStream(jstream.inputDStream) + val kc = new KafkaCluster(Map(kafkaParams.asScala.toSeq: _*)) + KafkaUtils.getFromOffsets( + kc, Map(kafkaParams.asScala.toSeq: _*), Set(topics.asScala.toSeq: _*)) } + + KafkaUtils.createDirectStream[Array[Byte], Array[Byte], DefaultDecoder, DefaultDecoder, V]( + jssc.ssc, + 
Map(kafkaParams.asScala.toSeq: _*), + Map(currentFromOffsets.toSeq: _*), + messageHandler) } def createOffsetRange(topic: String, partition: JInt, fromOffset: JLong, untilOffset: JLong @@ -669,3 +688,57 @@ private[kafka] class KafkaUtilsPythonHelper { kafkaRDD.offsetRanges.toSeq.asJava } } + +private object KafkaUtilsPythonHelper { + private var initialized = false + + def initialize(): Unit = { + SerDeUtil.initialize() + synchronized { + if (!initialized) { + new PythonMessageAndMetadataPickler().register() + initialized = true + } + } + } + + initialize() + + def picklerIterator(iter: Iterator[Any]): Iterator[Array[Byte]] = { + new SerDeUtil.AutoBatchedPickler(iter) + } + + case class PythonMessageAndMetadata( + topic: String, + partition: JInt, + offset: JLong, + key: Array[Byte], + message: Array[Byte]) + + class PythonMessageAndMetadataPickler extends IObjectPickler { + private val module = "pyspark.streaming.kafka" + + def register(): Unit = { + Pickler.registerCustomPickler(classOf[PythonMessageAndMetadata], this) + Pickler.registerCustomPickler(this.getClass, this) + } + + def pickle(obj: Object, out: OutputStream, pickler: Pickler) { + if (obj == this) { + out.write(Opcodes.GLOBAL) + out.write(s"$module\nKafkaMessageAndMetadata\n".getBytes(UTF_8)) + } else { + pickler.save(this) + val msgAndMetaData = obj.asInstanceOf[PythonMessageAndMetadata] + out.write(Opcodes.MARK) + pickler.save(msgAndMetaData.topic) + pickler.save(msgAndMetaData.partition) + pickler.save(msgAndMetaData.offset) + pickler.save(msgAndMetaData.key) + pickler.save(msgAndMetaData.message) + out.write(Opcodes.TUPLE) + out.write(Opcodes.REDUCE) + } + } + } +} diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 8b3bc96801e20..eb70d27c34c20 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -136,6 +136,12 @@ object MimaExcludes { // SPARK-11766 add toJson to Vector ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.mllib.linalg.Vector.toJson") + ) ++ Seq( + // SPARK-9065 Support message handler in Kafka Python API + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper.createDirectStream"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper.createRDD") ) case v if v.startsWith("1.5") => Seq( diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index 06e159172ab51..cdf97ec73aaf9 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -19,12 +19,14 @@ from pyspark.rdd import RDD from pyspark.storagelevel import StorageLevel -from pyspark.serializers import PairDeserializer, NoOpSerializer +from pyspark.serializers import AutoBatchedSerializer, PickleSerializer, PairDeserializer, \ + NoOpSerializer from pyspark.streaming import DStream from pyspark.streaming.dstream import TransformedDStream from pyspark.streaming.util import TransformFunction -__all__ = ['Broker', 'KafkaUtils', 'OffsetRange', 'TopicAndPartition', 'utf8_decoder'] +__all__ = ['Broker', 'KafkaMessageAndMetadata', 'KafkaUtils', 'OffsetRange', + 'TopicAndPartition', 'utf8_decoder'] def utf8_decoder(s): @@ -82,7 +84,8 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams=None, @staticmethod def createDirectStream(ssc, topics, kafkaParams, fromOffsets=None, - keyDecoder=utf8_decoder, valueDecoder=utf8_decoder): + keyDecoder=utf8_decoder, valueDecoder=utf8_decoder, + messageHandler=None): """ .. 
note:: Experimental @@ -107,6 +110,8 @@ def createDirectStream(ssc, topics, kafkaParams, fromOffsets=None, point of the stream. :param keyDecoder: A function used to decode key (default is utf8_decoder). :param valueDecoder: A function used to decode value (default is utf8_decoder). + :param messageHandler: A function used to convert KafkaMessageAndMetadata. You can assess + meta using messageHandler (default is None). :return: A DStream object """ if fromOffsets is None: @@ -116,6 +121,14 @@ def createDirectStream(ssc, topics, kafkaParams, fromOffsets=None, if not isinstance(kafkaParams, dict): raise TypeError("kafkaParams should be dict") + def funcWithoutMessageHandler(k_v): + return (keyDecoder(k_v[0]), valueDecoder(k_v[1])) + + def funcWithMessageHandler(m): + m._set_key_decoder(keyDecoder) + m._set_value_decoder(valueDecoder) + return messageHandler(m) + try: helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper") @@ -123,20 +136,28 @@ def createDirectStream(ssc, topics, kafkaParams, fromOffsets=None, jfromOffsets = dict([(k._jTopicAndPartition(helper), v) for (k, v) in fromOffsets.items()]) - jstream = helper.createDirectStream(ssc._jssc, kafkaParams, set(topics), jfromOffsets) + if messageHandler is None: + ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) + func = funcWithoutMessageHandler + jstream = helper.createDirectStreamWithoutMessageHandler( + ssc._jssc, kafkaParams, set(topics), jfromOffsets) + else: + ser = AutoBatchedSerializer(PickleSerializer()) + func = funcWithMessageHandler + jstream = helper.createDirectStreamWithMessageHandler( + ssc._jssc, kafkaParams, set(topics), jfromOffsets) except Py4JJavaError as e: if 'ClassNotFoundException' in str(e.java_exception): KafkaUtils._printErrorMsg(ssc.sparkContext) raise e - ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) - stream = DStream(jstream, ssc, ser) \ - .map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) + stream = DStream(jstream, ssc, ser).map(func) return KafkaDStream(stream._jdstream, ssc, stream._jrdd_deserializer) @staticmethod def createRDD(sc, kafkaParams, offsetRanges, leaders=None, - keyDecoder=utf8_decoder, valueDecoder=utf8_decoder): + keyDecoder=utf8_decoder, valueDecoder=utf8_decoder, + messageHandler=None): """ .. note:: Experimental @@ -149,6 +170,8 @@ def createRDD(sc, kafkaParams, offsetRanges, leaders=None, map, in which case leaders will be looked up on the driver. :param keyDecoder: A function used to decode key (default is utf8_decoder) :param valueDecoder: A function used to decode value (default is utf8_decoder) + :param messageHandler: A function used to convert KafkaMessageAndMetadata. You can assess + meta using messageHandler (default is None). 
:return: A RDD object """ if leaders is None: @@ -158,6 +181,14 @@ def createRDD(sc, kafkaParams, offsetRanges, leaders=None, if not isinstance(offsetRanges, list): raise TypeError("offsetRanges should be list") + def funcWithoutMessageHandler(k_v): + return (keyDecoder(k_v[0]), valueDecoder(k_v[1])) + + def funcWithMessageHandler(m): + m._set_key_decoder(keyDecoder) + m._set_value_decoder(valueDecoder) + return messageHandler(m) + try: helperClass = sc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper") @@ -165,15 +196,21 @@ def createRDD(sc, kafkaParams, offsetRanges, leaders=None, joffsetRanges = [o._jOffsetRange(helper) for o in offsetRanges] jleaders = dict([(k._jTopicAndPartition(helper), v._jBroker(helper)) for (k, v) in leaders.items()]) - jrdd = helper.createRDD(sc._jsc, kafkaParams, joffsetRanges, jleaders) + if messageHandler is None: + jrdd = helper.createRDDWithoutMessageHandler( + sc._jsc, kafkaParams, joffsetRanges, jleaders) + ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) + rdd = RDD(jrdd, sc, ser).map(funcWithoutMessageHandler) + else: + jrdd = helper.createRDDWithMessageHandler( + sc._jsc, kafkaParams, joffsetRanges, jleaders) + rdd = RDD(jrdd, sc).map(funcWithMessageHandler) except Py4JJavaError as e: if 'ClassNotFoundException' in str(e.java_exception): KafkaUtils._printErrorMsg(sc) raise e - ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) - rdd = RDD(jrdd, sc, ser).map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) - return KafkaRDD(rdd._jrdd, rdd.ctx, rdd._jrdd_deserializer) + return KafkaRDD(rdd._jrdd, sc, rdd._jrdd_deserializer) @staticmethod def _printErrorMsg(sc): @@ -365,3 +402,53 @@ def _jdstream(self): dstream = self._sc._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(), jfunc) self._jdstream_val = dstream.asJavaDStream() return self._jdstream_val + + +class KafkaMessageAndMetadata(object): + """ + Kafka message and metadata information. Including topic, partition, offset and message + """ + + def __init__(self, topic, partition, offset, key, message): + """ + Python wrapper of Kafka MessageAndMetadata + :param topic: topic name of this Kafka message + :param partition: partition id of this Kafka message + :param offset: Offset of this Kafka message in the specific partition + :param key: key payload of this Kafka message, can be null if this Kafka message has no key + specified, the return data is undecoded bytearry. + :param message: actual message payload of this Kafka message, the return data is + undecoded bytearray. 
+ """ + self.topic = topic + self.partition = partition + self.offset = offset + self._rawKey = key + self._rawMessage = message + self._keyDecoder = utf8_decoder + self._valueDecoder = utf8_decoder + + def __str__(self): + return "KafkaMessageAndMetadata(topic: %s, partition: %d, offset: %d, key and message...)" \ + % (self.topic, self.partition, self.offset) + + def __repr__(self): + return self.__str__() + + def __reduce__(self): + return (KafkaMessageAndMetadata, + (self.topic, self.partition, self.offset, self._rawKey, self._rawMessage)) + + def _set_key_decoder(self, decoder): + self._keyDecoder = decoder + + def _set_value_decoder(self, decoder): + self._valueDecoder = decoder + + @property + def key(self): + return self._keyDecoder(self._rawKey) + + @property + def message(self): + return self._valueDecoder(self._rawMessage) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index ff95639146e59..0bcd1f15532b5 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -1042,6 +1042,41 @@ def test_topic_and_partition_equality(self): self.assertNotEqual(topic_and_partition_a, topic_and_partition_c) self.assertNotEqual(topic_and_partition_a, topic_and_partition_d) + @unittest.skipIf(sys.version >= "3", "long type not support") + def test_kafka_rdd_message_handler(self): + """Test Python direct Kafka RDD MessageHandler.""" + topic = self._randomTopic() + sendData = {"a": 1, "b": 1, "c": 2} + offsetRanges = [OffsetRange(topic, 0, long(0), long(sum(sendData.values())))] + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress()} + + def getKeyAndDoubleMessage(m): + return m and (m.key, m.message * 2) + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges, + messageHandler=getKeyAndDoubleMessage) + self._validateRddResult({"aa": 1, "bb": 1, "cc": 2}, rdd) + + @unittest.skipIf(sys.version >= "3", "long type not support") + def test_kafka_direct_stream_message_handler(self): + """Test the Python direct Kafka stream MessageHandler.""" + topic = self._randomTopic() + sendData = {"a": 1, "b": 2, "c": 3} + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress(), + "auto.offset.reset": "smallest"} + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + + def getKeyAndDoubleMessage(m): + return m and (m.key, m.message * 2) + + stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams, + messageHandler=getKeyAndDoubleMessage) + self._validateStreamResult({"aa": 1, "bb": 2, "cc": 3}, stream) + class FlumeStreamTests(PySparkStreamingTestCase): timeout = 20 # seconds From ed8d1531f93f697c54bbaecefe08c37c32b0d391 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 17 Nov 2015 19:02:44 -0800 Subject: [PATCH 078/149] [SPARK-11793][SQL] Dataset should set the resolved encoders internally for maps. I also wrote a test case -- but unfortunately the test case is not working due to SPARK-11795. Author: Reynold Xin Closes #9784 from rxin/SPARK-11503. 
--- .../src/main/scala/org/apache/spark/sql/Dataset.scala | 3 ++- .../scala/org/apache/spark/sql/DatasetSuite.scala | 11 +++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 4cc3aa2465f2e..bd01dd4dc5799 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -199,11 +199,12 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = { + encoderFor[T].assertUnresolved() new Dataset[U]( sqlContext, MapPartitions[T, U]( func, - encoderFor[T], + resolvedTEncoder, encoderFor[U], encoderFor[U].schema.toAttributes, logicalPlan)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index c23dd46d3767b..a3922340ccc9a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -73,6 +73,17 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ("a", 2), ("b", 3), ("c", 4)) } + ignore("Dataset should set the resolved encoders internally for maps") { + // TODO: Enable this once we fix SPARK-11793. + val ds: Dataset[(ClassData, Long)] = Seq(ClassData("one", 1), ClassData("two", 2)).toDS() + .map(c => ClassData(c.a, c.b + 1)) + .groupBy(p => p).count() + + checkAnswer( + ds, + (ClassData("one", 1), 1L), (ClassData("two", 2), 1L)) + } + test("select") { val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() checkAnswer( From bf25f9bdfc7bd8533890c7df1b35afa912dc6d3d Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 17 Nov 2015 19:39:39 -0800 Subject: [PATCH 079/149] [SPARK-11016] Move RoaringBitmap to explicit Kryo serializer Fix the serialization of RoaringBitmap with Kyro serializer This PR came from https://github.com/metamx/spark/pull/1, thanks to drcrallen Author: Davies Liu Author: Charles Allen Closes #9748 from davies/SPARK-11016. 
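Before the diff, a quick hedged sketch of what the explicit serializer buys at the user level: with RoaringBitmap registered against a dedicated Kryo serializer, a bitmap survives a round trip through Spark's KryoSerializer (the object name and values below are illustrative only):

import org.apache.spark.SparkConf
import org.apache.spark.serializer.KryoSerializer
import org.roaringbitmap.RoaringBitmap

object RoaringBitmapKryoRoundTrip {
  def main(args: Array[String]): Unit = {
    val ser = new KryoSerializer(new SparkConf()).newInstance()
    val bitmap = new RoaringBitmap()
    bitmap.add(0)
    bitmap.add(5)
    bitmap.add(100000)
    // Serialize to a ByteBuffer and read it back; the restored bitmap should equal the original.
    val restored = ser.deserialize[RoaringBitmap](ser.serialize(bitmap))
    assert(restored == bitmap)
  }
}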
--- .../spark/serializer/KryoSerializer.scala | 64 ++++++++++++++++--- 1 file changed, 55 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index c5195c1143a8f..1bcb3175a3016 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -17,7 +17,7 @@ package org.apache.spark.serializer -import java.io.{EOFException, IOException, InputStream, OutputStream} +import java.io.{EOFException, IOException, InputStream, OutputStream, DataInput, DataOutput} import java.nio.ByteBuffer import javax.annotation.Nullable @@ -25,12 +25,12 @@ import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag -import com.esotericsoftware.kryo.{Kryo, KryoException} +import com.esotericsoftware.kryo.{Kryo, KryoException, Serializer => KryoClassSerializer} import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer} import com.twitter.chill.{AllScalaRegistrar, EmptyScalaKryoInstantiator} import org.apache.avro.generic.{GenericData, GenericRecord} -import org.roaringbitmap.{ArrayContainer, BitmapContainer, RoaringArray, RoaringBitmap} +import org.roaringbitmap.RoaringBitmap import org.apache.spark._ import org.apache.spark.api.python.PythonBroadcast @@ -94,6 +94,9 @@ class KryoSerializer(conf: SparkConf) for (cls <- KryoSerializer.toRegister) { kryo.register(cls) } + for ((cls, ser) <- KryoSerializer.toRegisterSerializer) { + kryo.register(cls, ser) + } // For results returned by asJavaIterable. See JavaIterableWrapperSerializer. 
kryo.register(JavaIterableWrapperSerializer.wrapperClass, new JavaIterableWrapperSerializer) @@ -363,12 +366,6 @@ private[serializer] object KryoSerializer { classOf[StorageLevel], classOf[CompressedMapStatus], classOf[HighlyCompressedMapStatus], - classOf[RoaringBitmap], - classOf[RoaringArray], - classOf[RoaringArray.Element], - classOf[Array[RoaringArray.Element]], - classOf[ArrayContainer], - classOf[BitmapContainer], classOf[CompactBuffer[_]], classOf[BlockManagerId], classOf[Array[Byte]], @@ -377,6 +374,55 @@ private[serializer] object KryoSerializer { classOf[BoundedPriorityQueue[_]], classOf[SparkConf] ) + + private val toRegisterSerializer = Map[Class[_], KryoClassSerializer[_]]( + classOf[RoaringBitmap] -> new KryoClassSerializer[RoaringBitmap]() { + override def write(kryo: Kryo, output: KryoOutput, bitmap: RoaringBitmap): Unit = { + bitmap.serialize(new KryoOutputDataOutputBridge(output)) + } + override def read(kryo: Kryo, input: KryoInput, cls: Class[RoaringBitmap]): RoaringBitmap = { + val ret = new RoaringBitmap + ret.deserialize(new KryoInputDataInputBridge(input)) + ret + } + } + ) +} + +private[serializer] class KryoInputDataInputBridge(input: KryoInput) extends DataInput { + override def readLong(): Long = input.readLong() + override def readChar(): Char = input.readChar() + override def readFloat(): Float = input.readFloat() + override def readByte(): Byte = input.readByte() + override def readShort(): Short = input.readShort() + override def readUTF(): String = input.readString() // readString in kryo does utf8 + override def readInt(): Int = input.readInt() + override def readUnsignedShort(): Int = input.readShortUnsigned() + override def skipBytes(n: Int): Int = input.skip(n.toLong).toInt + override def readFully(b: Array[Byte]): Unit = input.read(b) + override def readFully(b: Array[Byte], off: Int, len: Int): Unit = input.read(b, off, len) + override def readLine(): String = throw new UnsupportedOperationException("readLine") + override def readBoolean(): Boolean = input.readBoolean() + override def readUnsignedByte(): Int = input.readByteUnsigned() + override def readDouble(): Double = input.readDouble() +} + +private[serializer] class KryoOutputDataOutputBridge(output: KryoOutput) extends DataOutput { + override def writeFloat(v: Float): Unit = output.writeFloat(v) + // There is no "readChars" counterpart, except maybe "readLine", which is not supported + override def writeChars(s: String): Unit = throw new UnsupportedOperationException("writeChars") + override def writeDouble(v: Double): Unit = output.writeDouble(v) + override def writeUTF(s: String): Unit = output.writeString(s) // writeString in kryo does UTF8 + override def writeShort(v: Int): Unit = output.writeShort(v) + override def writeInt(v: Int): Unit = output.writeInt(v) + override def writeBoolean(v: Boolean): Unit = output.writeBoolean(v) + override def write(b: Int): Unit = output.write(b) + override def write(b: Array[Byte]): Unit = output.write(b) + override def write(b: Array[Byte], off: Int, len: Int): Unit = output.write(b, off, len) + override def writeBytes(s: String): Unit = output.writeString(s) + override def writeChar(v: Int): Unit = output.writeChar(v.toChar) + override def writeLong(v: Long): Unit = output.writeLong(v) + override def writeByte(v: Int): Unit = output.writeByte(v) } /** From e33053ee0015025bbcfddb20cc9216c225bbe624 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Tue, 17 Nov 2015 19:44:29 -0800 Subject: [PATCH 080/149] [SPARK-11583] [CORE] MapStatus Using RoaringBitmap 
More Properly This PR upgrade the version of RoaringBitmap to 0.5.10, to optimize the memory layout, will be much smaller when most of blocks are empty. This PR is based on #9661 (fix conflicts), see all of the comments at https://github.com/apache/spark/pull/9661 . Author: Kent Yao Author: Davies Liu Author: Charles Allen Closes #9746 from davies/roaring_mapstatus. --- .../apache/spark/scheduler/MapStatus.scala | 5 +-- .../spark/serializer/KryoSerializer.scala | 6 ++-- .../spark/scheduler/MapStatusSuite.scala | 31 +++++++++++++++++++ pom.xml | 2 +- 4 files changed, 38 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 1efce124c0a6b..b2e9a97129f08 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -122,8 +122,7 @@ private[spark] class CompressedMapStatus( /** * A [[MapStatus]] implementation that only stores the average size of non-empty blocks, - * plus a bitmap for tracking which blocks are empty. During serialization, this bitmap - * is compressed. + * plus a bitmap for tracking which blocks are empty. * * @param loc location where the task is being executed * @param numNonEmptyBlocks the number of non-empty blocks @@ -194,6 +193,8 @@ private[spark] object HighlyCompressedMapStatus { } else { 0 } + emptyBlocks.trim() + emptyBlocks.runOptimize() new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize) } } diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 1bcb3175a3016..d5ba690ed04be 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -17,7 +17,7 @@ package org.apache.spark.serializer -import java.io.{EOFException, IOException, InputStream, OutputStream, DataInput, DataOutput} +import java.io.{DataInput, DataOutput, EOFException, IOException, InputStream, OutputStream} import java.nio.ByteBuffer import javax.annotation.Nullable @@ -25,9 +25,9 @@ import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag -import com.esotericsoftware.kryo.{Kryo, KryoException, Serializer => KryoClassSerializer} import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer} +import com.esotericsoftware.kryo.{Kryo, KryoException, Serializer => KryoClassSerializer} import com.twitter.chill.{AllScalaRegistrar, EmptyScalaKryoInstantiator} import org.apache.avro.generic.{GenericData, GenericRecord} import org.roaringbitmap.RoaringBitmap @@ -38,8 +38,8 @@ import org.apache.spark.broadcast.HttpBroadcast import org.apache.spark.network.util.ByteUnit import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus} import org.apache.spark.storage._ -import org.apache.spark.util.{Utils, BoundedPriorityQueue, SerializableConfiguration, SerializableJobConf} import org.apache.spark.util.collection.CompactBuffer +import org.apache.spark.util.{BoundedPriorityQueue, SerializableConfiguration, SerializableJobConf, Utils} /** * A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]]. 
diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala index b8e466fab4506..15c8de61b8240 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.storage.BlockManagerId import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer +import org.roaringbitmap.RoaringBitmap import scala.util.Random @@ -97,4 +98,34 @@ class MapStatusSuite extends SparkFunSuite { val buf = ser.newInstance().serialize(status) ser.newInstance().deserialize[MapStatus](buf) } + + test("RoaringBitmap: runOptimize succeeded") { + val r = new RoaringBitmap + (1 to 200000).foreach(i => + if (i % 200 != 0) { + r.add(i) + } + ) + val size1 = r.getSizeInBytes + val success = r.runOptimize() + r.trim() + val size2 = r.getSizeInBytes + assert(size1 > size2) + assert(success) + } + + test("RoaringBitmap: runOptimize failed") { + val r = new RoaringBitmap + (1 to 200000).foreach(i => + if (i % 200 == 0) { + r.add(i) + } + ) + val size1 = r.getSizeInBytes + val success = r.runOptimize() + r.trim() + val size2 = r.getSizeInBytes + assert(size1 === size2) + assert(!success) + } } diff --git a/pom.xml b/pom.xml index 2a8a445057174..940e2d8740bf1 100644 --- a/pom.xml +++ b/pom.xml @@ -637,7 +637,7 @@ org.roaringbitmap RoaringBitmap - 0.4.5 + 0.5.11 commons-net From 98be8169f07eb0f1b8f01776c71d0e1ed3d5e4d5 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 17 Nov 2015 19:50:02 -0800 Subject: [PATCH 081/149] [SPARK-11737] [SQL] Fix serialization of UTF8String with Kyro The default implementation of serialization UTF8String with Kyro may be not correct (BYTE_ARRAY_OFFSET could be different across JVM) Author: Davies Liu Closes #9704 from davies/kyro_string. --- unsafe/pom.xml | 4 ++++ .../apache/spark/unsafe/types/UTF8String.java | 24 +++++++++++++++++-- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/unsafe/pom.xml b/unsafe/pom.xml index caf1f77890b58..a1c1111364ee8 100644 --- a/unsafe/pom.xml +++ b/unsafe/pom.xml @@ -36,6 +36,10 @@ + + com.twitter + chill_${scala.binary.version} + diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index b7aecb5102ba6..4bd3fd7772079 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -24,6 +24,11 @@ import java.util.Arrays; import java.util.Map; +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.KryoSerializable; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; + import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; @@ -38,9 +43,9 @@ *

    * Note: This is not designed for general use cases, should not be used outside SQL. */ -public final class UTF8String implements Comparable, Externalizable { +public final class UTF8String implements Comparable, Externalizable, KryoSerializable { - // These are only updated by readExternal() + // These are only updated by readExternal() or read() @Nonnull private Object base; private long offset; @@ -1003,4 +1008,19 @@ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundExcept in.readFully((byte[]) base); } + @Override + public void write(Kryo kryo, Output out) { + byte[] bytes = getBytes(); + out.writeInt(bytes.length); + out.write(bytes); + } + + @Override + public void read(Kryo kryo, Input in) { + this.offset = BYTE_ARRAY_OFFSET; + this.numBytes = in.readInt(); + this.base = new byte[numBytes]; + in.read((byte[]) base); + } + } From 91f4b6f2db12650dfc33a576803ba8aeccf935dd Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 17 Nov 2015 21:40:58 -0800 Subject: [PATCH 082/149] [SPARK-11797][SQL] collect, first, and take should use encoders for serialization They were previously using Spark's default serializer for serialization. Author: Reynold Xin Closes #9787 from rxin/SPARK-11797. --- .../scala/org/apache/spark/sql/Dataset.scala | 17 +++++++---- .../org/apache/spark/sql/DatasetSuite.scala | 30 ++++++++++++++++++- 2 files changed, 41 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index bd01dd4dc5799..718ed812dd64c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -22,6 +22,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.RDD import org.apache.spark.api.java.function._ +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ @@ -199,7 +200,6 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = { - encoderFor[T].assertUnresolved() new Dataset[U]( sqlContext, MapPartitions[T, U]( @@ -519,7 +519,7 @@ class Dataset[T] private[sql]( * Returns the first element in this [[Dataset]]. * @since 1.6.0 */ - def first(): T = rdd.first() + def first(): T = take(1).head /** * Returns an array that contains all the elements in this [[Dataset]]. @@ -530,7 +530,14 @@ class Dataset[T] private[sql]( * For Java API, use [[collectAsList]]. * @since 1.6.0 */ - def collect(): Array[T] = rdd.collect() + def collect(): Array[T] = { + // This is different from Dataset.rdd in that it collects Rows, and then runs the encoders + // to convert the rows into objects of type T. + val tEnc = resolvedTEncoder + val input = queryExecution.analyzed.output + val bound = tEnc.bind(input) + queryExecution.toRdd.map(_.copy()).collect().map(bound.fromRow) + } /** * Returns an array that contains all the elements in this [[Dataset]]. @@ -541,7 +548,7 @@ class Dataset[T] private[sql]( * For Java API, use [[collectAsList]]. * @since 1.6.0 */ - def collectAsList(): java.util.List[T] = rdd.collect().toSeq.asJava + def collectAsList(): java.util.List[T] = collect().toSeq.asJava /** * Returns the first `num` elements of this [[Dataset]] as an array. 
@@ -551,7 +558,7 @@ class Dataset[T] private[sql]( * * @since 1.6.0 */ - def take(num: Int): Array[T] = rdd.take(num) + def take(num: Int): Array[T] = withPlan(Limit(Literal(num), _)).collect() /** * Returns the first `num` elements of this [[Dataset]] as an array. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index a3922340ccc9a..ea29428c55088 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import java.io.{ObjectInput, ObjectOutput, Externalizable} + import scala.language.postfixOps import org.apache.spark.sql.functions._ @@ -24,6 +26,20 @@ import org.apache.spark.sql.test.SharedSQLContext case class ClassData(a: String, b: Int) +/** + * A class used to test serialization using encoders. This class throws exceptions when using + * Java serialization -- so the only way it can be "serialized" is through our encoders. + */ +case class NonSerializableCaseClass(value: String) extends Externalizable { + override def readExternal(in: ObjectInput): Unit = { + throw new UnsupportedOperationException + } + + override def writeExternal(out: ObjectOutput): Unit = { + throw new UnsupportedOperationException + } +} + class DatasetSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -41,6 +57,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext { 1, 1, 1) } + test("collect, first, and take should use encoders for serialization") { + val item = NonSerializableCaseClass("abcd") + val ds = Seq(item).toDS() + assert(ds.collect().head == item) + assert(ds.collectAsList().get(0) == item) + assert(ds.first() == item) + assert(ds.take(1).head == item) + assert(ds.takeAsList(1).get(0) == item) + } + test("as tuple") { val data = Seq(("a", 1), ("b", 2)).toDF("a", "b") checkAnswer( @@ -75,6 +101,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ignore("Dataset should set the resolved encoders internally for maps") { // TODO: Enable this once we fix SPARK-11793. + // We inject a group by here to make sure this test case is future proof + // when we implement better pipelining and local execution mode. 
val ds: Dataset[(ClassData, Long)] = Seq(ClassData("one", 1), ClassData("two", 2)).toDS() .map(c => ClassData(c.a, c.b + 1)) .groupBy(p => p).count() @@ -219,7 +247,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ("a", 30), ("b", 3), ("c", 1)) } - test("groupBy function, fatMap") { + test("groupBy function, flatMap") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy(v => (v._1, "word")) val agged = grouped.flatMap { case (g, iter) => Iterator(g._1, iter.map(_._2).sum.toString) } From 8fb775ba874dd0488667bf299a7b49760062dc00 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 17 Nov 2015 22:13:15 -0800 Subject: [PATCH 083/149] [SPARK-11755][R] SparkR should export "predict" MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The bug described at [SPARK-11755](https://issues.apache.org/jira/browse/SPARK-11755), after exporting ```predict``` we can both get the help information from the SparkR and base R package like the following: ```Java > help(predict) Help on topic ‘predict’ was found in the following packages: Package Library SparkR /Users/yanboliang/data/trunk2/spark/R/lib stats /Library/Frameworks/R.framework/Versions/3.2/Resources/library Choose one 1: Make predictions from a model {SparkR} 2: Model Predictions {stats} ``` Author: Yanbo Liang Closes #9732 from yanboliang/spark-11755. --- R/pkg/R/generics.R | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 612e639f8ad99..afdeffc2abd83 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1054,6 +1054,10 @@ setGeneric("year", function(x) { standardGeneric("year") }) #' @export setGeneric("glm") +#' @rdname predict +#' @export +setGeneric("predict", function(object, ...) { standardGeneric("predict") }) + #' @rdname rbind #' @export setGeneric("rbind", signature = "...") From 446738e51fcda50cf2dc44123ff6bf12a1611dc0 Mon Sep 17 00:00:00 2001 From: tedyu Date: Tue, 17 Nov 2015 22:47:53 -0800 Subject: [PATCH 084/149] [SPARK-11761] Prevent the call to StreamingContext#stop() in the listener bus's thread See discussion toward the tail of https://github.com/apache/spark/pull/9723 From zsxwing : ``` The user should not call stop or other long-time work in a listener since it will block the listener thread, and prevent from stopping SparkContext/StreamingContext. I cannot see an approach since we need to stop the listener bus's thread before stopping SparkContext/StreamingContext totally. ``` Proposed solution is to prevent the call to StreamingContext#stop() in the listener bus's thread. Author: tedyu Closes #9741 from tedyu/master. 
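A small sketch of the guard pattern this change introduces (names invented for the example): a scala.util.DynamicVariable is set only while the dispatch loop runs on the listener thread, so any code invoked from that thread can detect it and refuse to perform a blocking shutdown.

import scala.util.DynamicVariable

object DispatchGuard {
  // true only for code running inside dispatchLoop on the dispatching thread
  val withinDispatchThread = new DynamicVariable[Boolean](false)

  def dispatchLoop(handle: String => Unit, events: Seq[String]): Unit =
    withinDispatchThread.withValue(true) {
      events.foreach(handle)
    }

  def stop(): Unit = {
    if (withinDispatchThread.value) {
      throw new IllegalStateException("stop() must not be called from the dispatch thread")
    }
    // a real stop() would signal the loop and join its thread here
  }
}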
--- .../spark/util/AsynchronousListenerBus.scala | 46 +++++++++++-------- .../spark/streaming/StreamingContext.scala | 6 ++- .../streaming/StreamingListenerSuite.scala | 34 ++++++++++++++ 3 files changed, 67 insertions(+), 19 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala b/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala index c20627b056bef..6c1fca71f2281 100644 --- a/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala @@ -19,6 +19,7 @@ package org.apache.spark.util import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean +import scala.util.DynamicVariable import org.apache.spark.SparkContext @@ -60,25 +61,27 @@ private[spark] abstract class AsynchronousListenerBus[L <: AnyRef, E](name: Stri private val listenerThread = new Thread(name) { setDaemon(true) override def run(): Unit = Utils.tryOrStopSparkContext(sparkContext) { - while (true) { - eventLock.acquire() - self.synchronized { - processingEvent = true - } - try { - val event = eventQueue.poll - if (event == null) { - // Get out of the while loop and shutdown the daemon thread - if (!stopped.get) { - throw new IllegalStateException("Polling `null` from eventQueue means" + - " the listener bus has been stopped. So `stopped` must be true") - } - return - } - postToAll(event) - } finally { + AsynchronousListenerBus.withinListenerThread.withValue(true) { + while (true) { + eventLock.acquire() self.synchronized { - processingEvent = false + processingEvent = true + } + try { + val event = eventQueue.poll + if (event == null) { + // Get out of the while loop and shutdown the daemon thread + if (!stopped.get) { + throw new IllegalStateException("Polling `null` from eventQueue means" + + " the listener bus has been stopped. So `stopped` must be true") + } + return + } + postToAll(event) + } finally { + self.synchronized { + processingEvent = false + } } } } @@ -177,3 +180,10 @@ private[spark] abstract class AsynchronousListenerBus[L <: AnyRef, E](name: Stri */ def onDropEvent(event: E): Unit } + +private[spark] object AsynchronousListenerBus { + /* Allows for Context to check whether stop() call is made within listener thread + */ + val withinListenerThread: DynamicVariable[Boolean] = new DynamicVariable[Boolean](false) +} + diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 97113835f3bd0..aee172a4f549a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -44,7 +44,7 @@ import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.receiver.{ActorReceiver, ActorSupervisorStrategy, Receiver} import org.apache.spark.streaming.scheduler.{JobScheduler, StreamingListener} import org.apache.spark.streaming.ui.{StreamingJobProgressListener, StreamingTab} -import org.apache.spark.util.{CallSite, ShutdownHookManager, ThreadUtils, Utils} +import org.apache.spark.util.{AsynchronousListenerBus, CallSite, ShutdownHookManager, ThreadUtils, Utils} /** * Main entry point for Spark Streaming functionality. 
It provides methods used to create @@ -693,6 +693,10 @@ class StreamingContext private[streaming] ( */ def stop(stopSparkContext: Boolean, stopGracefully: Boolean): Unit = { var shutdownHookRefToRemove: AnyRef = null + if (AsynchronousListenerBus.withinListenerThread.value) { + throw new SparkException("Cannot stop StreamingContext within listener thread of" + + " AsynchronousListenerBus") + } synchronized { try { state match { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index 5dc0472c7770c..df4575ab25aad 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -21,6 +21,7 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, SynchronizedBuffer, Synch import scala.concurrent.Future import scala.concurrent.ExecutionContext.Implicits.global +import org.apache.spark.SparkException import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.receiver.Receiver @@ -161,6 +162,14 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { } } + test("don't call ssc.stop in listener") { + ssc = new StreamingContext("local[2]", "ssc", Milliseconds(1000)) + val inputStream = ssc.receiverStream(new StreamingListenerSuiteReceiver) + inputStream.foreachRDD(_.count) + + startStreamingContextAndCallStop(ssc) + } + test("onBatchCompleted with successful batch") { ssc = new StreamingContext("local[2]", "test", Milliseconds(1000)) val inputStream = ssc.receiverStream(new StreamingListenerSuiteReceiver) @@ -207,6 +216,17 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { assert(failureReasons(1).contains("This is another failed job")) } + private def startStreamingContextAndCallStop(_ssc: StreamingContext): Unit = { + val contextStoppingCollector = new StreamingContextStoppingCollector(_ssc) + _ssc.addStreamingListener(contextStoppingCollector) + val batchCounter = new BatchCounter(_ssc) + _ssc.start() + // Make sure running at least one batch + batchCounter.waitUntilBatchesCompleted(expectedNumCompletedBatches = 1, timeout = 10000) + _ssc.stop() + assert(contextStoppingCollector.sparkExSeen) + } + private def startStreamingContextAndCollectFailureReasons( _ssc: StreamingContext, isFailed: Boolean = false): Map[Int, String] = { val failureReasonsCollector = new FailureReasonsCollector() @@ -320,3 +340,17 @@ class FailureReasonsCollector extends StreamingListener { } } } +/** + * A StreamingListener that calls StreamingContext.stop(). + */ +class StreamingContextStoppingCollector(val ssc: StreamingContext) extends StreamingListener { + @volatile var sparkExSeen = false + override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) { + try { + ssc.stop() + } catch { + case se: SparkException => + sparkExSeen = true + } + } +} From 67a5132c21bc8338adbae80b33b85e8fa0ddda34 Mon Sep 17 00:00:00 2001 From: RoyGaoVLIS Date: Tue, 17 Nov 2015 23:00:49 -0800 Subject: [PATCH 085/149] [SPARK-7013][ML][TEST] Add unit test for spark.ml StandardScaler I have added unit test for ML's StandardScaler By comparing with R's output, please review for me. Thx. Author: RoyGaoVLIS Closes #6665 from RoyGao/7013. 
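A quick by-hand check (not part of the patch) of where the expected vectors in the new test come from: center each column by its mean and divide by the sample (n-1) standard deviation, which reproduces the R reference values used below.

object StandardizeByHand {
  def main(args: Array[String]): Unit = {
    val col = Array(-2.0, 0.0, 1.7) // first feature column of the test data
    val mean = col.sum / col.length // -0.1
    val sd = math.sqrt(col.map(v => (v - mean) * (v - mean)).sum / (col.length - 1))
    println(col.map(v => (v - mean) / sd).mkString(", "))
    // prints approximately -1.0259, 0.0540, 0.9719, matching resWithBoth below
  }
}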
--- .../ml/feature/StandardScalerSuite.scala | 108 ++++++++++++++++++ 1 file changed, 108 insertions(+) create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala new file mode 100644 index 0000000000000..879a3ae875004 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{DataFrame, Row, SQLContext} + +class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext{ + + @transient var data: Array[Vector] = _ + @transient var resWithStd: Array[Vector] = _ + @transient var resWithMean: Array[Vector] = _ + @transient var resWithBoth: Array[Vector] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + + data = Array( + Vectors.dense(-2.0, 2.3, 0.0), + Vectors.dense(0.0, -5.1, 1.0), + Vectors.dense(1.7, -0.6, 3.3) + ) + resWithMean = Array( + Vectors.dense(-1.9, 3.433333333333, -1.433333333333), + Vectors.dense(0.1, -3.966666666667, -0.433333333333), + Vectors.dense(1.8, 0.533333333333, 1.866666666667) + ) + resWithStd = Array( + Vectors.dense(-1.079898494312, 0.616834091415, 0.0), + Vectors.dense(0.0, -1.367762550529, 0.590968109266), + Vectors.dense(0.917913720165, -0.160913241239, 1.950194760579) + ) + resWithBoth = Array( + Vectors.dense(-1.0259035695965, 0.920781324866, -0.8470542899497), + Vectors.dense(0.0539949247156, -1.063815317078, -0.256086180682), + Vectors.dense(0.9719086448809, 0.143033992212, 1.103140470631) + ) + } + + def assertResult(dataframe: DataFrame): Unit = { + dataframe.select("standarded_features", "expected").collect().foreach { + case Row(vector1: Vector, vector2: Vector) => + assert(vector1 ~== vector2 absTol 1E-5, + "The vector value is not correct after standardization.") + } + } + + test("Standardization with default parameter") { + val df0 = sqlContext.createDataFrame(data.zip(resWithStd)).toDF("features", "expected") + + val standardscaler0 = new StandardScaler() + .setInputCol("features") + .setOutputCol("standarded_features") + .fit(df0) + + assertResult(standardscaler0.transform(df0)) + } + + test("Standardization with setter") { + val df1 = sqlContext.createDataFrame(data.zip(resWithBoth)).toDF("features", "expected") + val df2 = 
sqlContext.createDataFrame(data.zip(resWithMean)).toDF("features", "expected") + val df3 = sqlContext.createDataFrame(data.zip(data)).toDF("features", "expected") + + val standardscaler1 = new StandardScaler() + .setInputCol("features") + .setOutputCol("standarded_features") + .setWithMean(true) + .setWithStd(true) + .fit(df1) + + val standardscaler2 = new StandardScaler() + .setInputCol("features") + .setOutputCol("standarded_features") + .setWithMean(true) + .setWithStd(false) + .fit(df2) + + val standardscaler3 = new StandardScaler() + .setInputCol("features") + .setOutputCol("standarded_features") + .setWithMean(false) + .setWithStd(false) + .fit(df3) + + assertResult(standardscaler1.transform(df1)) + assertResult(standardscaler2.transform(df2)) + assertResult(standardscaler3.transform(df3)) + } +} From 2f191c66b668fc97f82f44fd8336b6a4488c2f5d Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 17 Nov 2015 23:14:05 -0800 Subject: [PATCH 086/149] [SPARK-11643] [SQL] parse year with leading zero Support the years between 0 <= year < 1000 Author: Davies Liu Closes #9701 from davies/leading_zero. --- .../sql/catalyst/util/DateTimeUtils.scala | 20 +++++++++++++++++-- .../catalyst/util/DateTimeUtilsSuite.scala | 17 +++++++++++++--- 2 files changed, 32 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 8fb3f41f1bd6a..17a5527f3fb29 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -241,6 +241,10 @@ object DateTimeUtils { i += 3 } else if (i < 2) { if (b == '-') { + if (i == 0 && j != 4) { + // year should have exact four digits + return None + } segments(i) = currentSegmentValue currentSegmentValue = 0 i += 1 @@ -308,13 +312,17 @@ object DateTimeUtils { } segments(i) = currentSegmentValue + if (!justTime && i == 0 && j != 4) { + // year should have exact four digits + return None + } while (digitsMilli < 6) { segments(6) *= 10 digitsMilli += 1 } - if (!justTime && (segments(0) < 1000 || segments(0) > 9999 || segments(1) < 1 || + if (!justTime && (segments(0) < 0 || segments(0) > 9999 || segments(1) < 1 || segments(1) > 12 || segments(2) < 1 || segments(2) > 31)) { return None } @@ -368,6 +376,10 @@ object DateTimeUtils { while (j < bytes.length && (i < 3 && !(bytes(j) == ' ' || bytes(j) == 'T'))) { val b = bytes(j) if (i < 2 && b == '-') { + if (i == 0 && j != 4) { + // year should have exact four digits + return None + } segments(i) = currentSegmentValue currentSegmentValue = 0 i += 1 @@ -381,8 +393,12 @@ object DateTimeUtils { } j += 1 } + if (i == 0 && j != 4) { + // year should have exact four digits + return None + } segments(i) = currentSegmentValue - if (segments(0) < 1000 || segments(0) > 9999 || segments(1) < 1 || segments(1) > 12 || + if (segments(0) < 0 || segments(0) > 9999 || segments(1) < 1 || segments(1) > 12 || segments(2) < 1 || segments(2) > 31) { return None } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index 60d45422bc9b8..faca128badfd6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ 
-110,6 +110,10 @@ class DateTimeUtilsSuite extends SparkFunSuite { c.set(Calendar.MILLISECOND, 0) assert(stringToDate(UTF8String.fromString("2015")).get === millisToDays(c.getTimeInMillis)) + c.set(1, 0, 1, 0, 0, 0) + c.set(Calendar.MILLISECOND, 0) + assert(stringToDate(UTF8String.fromString("0001")).get === + millisToDays(c.getTimeInMillis)) c = Calendar.getInstance() c.set(2015, 2, 1, 0, 0, 0) c.set(Calendar.MILLISECOND, 0) @@ -134,11 +138,15 @@ class DateTimeUtilsSuite extends SparkFunSuite { assert(stringToDate(UTF8String.fromString("2015.03.18")).isEmpty) assert(stringToDate(UTF8String.fromString("20150318")).isEmpty) assert(stringToDate(UTF8String.fromString("2015-031-8")).isEmpty) + assert(stringToDate(UTF8String.fromString("02015-03-18")).isEmpty) + assert(stringToDate(UTF8String.fromString("015-03-18")).isEmpty) + assert(stringToDate(UTF8String.fromString("015")).isEmpty) + assert(stringToDate(UTF8String.fromString("02015")).isEmpty) } test("string to time") { // Tests with UTC. - var c = Calendar.getInstance(TimeZone.getTimeZone("UTC")) + val c = Calendar.getInstance(TimeZone.getTimeZone("UTC")) c.set(Calendar.MILLISECOND, 0) c.set(1900, 0, 1, 0, 0, 0) @@ -174,9 +182,9 @@ class DateTimeUtilsSuite extends SparkFunSuite { c.set(Calendar.MILLISECOND, 0) assert(stringToTimestamp(UTF8String.fromString("1969-12-31 16:00:00")).get === c.getTimeInMillis * 1000) - c.set(2015, 0, 1, 0, 0, 0) + c.set(1, 0, 1, 0, 0, 0) c.set(Calendar.MILLISECOND, 0) - assert(stringToTimestamp(UTF8String.fromString("2015")).get === + assert(stringToTimestamp(UTF8String.fromString("0001")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance() c.set(2015, 2, 1, 0, 0, 0) @@ -319,6 +327,7 @@ class DateTimeUtilsSuite extends SparkFunSuite { UTF8String.fromString("2011-05-06 07:08:09.1000")).get === c.getTimeInMillis * 1000) assert(stringToTimestamp(UTF8String.fromString("238")).isEmpty) + assert(stringToTimestamp(UTF8String.fromString("00238")).isEmpty) assert(stringToTimestamp(UTF8String.fromString("2015-03-18 123142")).isEmpty) assert(stringToTimestamp(UTF8String.fromString("2015-03-18T123123")).isEmpty) assert(stringToTimestamp(UTF8String.fromString("2015-03-18X")).isEmpty) @@ -326,6 +335,8 @@ class DateTimeUtilsSuite extends SparkFunSuite { assert(stringToTimestamp(UTF8String.fromString("2015.03.18")).isEmpty) assert(stringToTimestamp(UTF8String.fromString("20150318")).isEmpty) assert(stringToTimestamp(UTF8String.fromString("2015-031-8")).isEmpty) + assert(stringToTimestamp(UTF8String.fromString("02015-01-18")).isEmpty) + assert(stringToTimestamp(UTF8String.fromString("015-01-18")).isEmpty) assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03.17-20:0")).isEmpty) assert(stringToTimestamp( From 9154f89befb7a33d4853cea95efd7dc6b25d033b Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Tue, 17 Nov 2015 23:44:06 -0800 Subject: [PATCH 087/149] [SPARK-11728] Replace example code in ml-ensembles.md using include_example JIRA issue https://issues.apache.org/jira/browse/SPARK-11728. The ml-ensembles.md file contains `OneVsRestExample`. Instead of writing new code files of two `OneVsRestExample`s, I use two existing files in the examples directory, they are `OneVsRestExample.scala` and `JavaOneVsRestExample.scala`. Author: Xusen Yin Closes #9716 from yinxusen/SPARK-11728. 
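For readers unfamiliar with the mechanism: a tag such as {% include_example scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala %} pulls a marked region out of the referenced file under examples/src/main, so the docs can no longer drift from the runnable examples. A rough sketch of what a referenced Scala file looks like (the object name is invented; the on/off marker comments are assumed to be how the plugin delimits the excerpt):

package org.apache.spark.examples.ml

object IncludeExampleSketch {
  def main(args: Array[String]): Unit = {
    // $example on$
    // Only the code between these two marker comments is rendered into the HTML page.
    println("this block appears in the generated documentation")
    // $example off$
  }
}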
--- docs/ml-ensembles.md | 754 +----------------- ...aGradientBoostedTreeClassifierExample.java | 102 +++ ...vaGradientBoostedTreeRegressorExample.java | 90 +++ .../examples/ml/JavaOneVsRestExample.java | 4 + .../ml/JavaRandomForestClassifierExample.java | 101 +++ .../ml/JavaRandomForestRegressorExample.java | 90 +++ ...radient_boosted_tree_classifier_example.py | 77 ++ ...gradient_boosted_tree_regressor_example.py | 74 ++ .../ml/random_forest_classifier_example.py | 77 ++ .../ml/random_forest_regressor_example.py | 74 ++ ...GradientBoostedTreeClassifierExample.scala | 97 +++ .../GradientBoostedTreeRegressorExample.scala | 85 ++ .../spark/examples/ml/OneVsRestExample.scala | 4 + .../ml/RandomForestClassifierExample.scala | 97 +++ .../ml/RandomForestRegressorExample.scala | 84 ++ 15 files changed, 1070 insertions(+), 740 deletions(-) create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeClassifierExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeRegressorExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaRandomForestClassifierExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaRandomForestRegressorExample.java create mode 100644 examples/src/main/python/ml/gradient_boosted_tree_classifier_example.py create mode 100644 examples/src/main/python/ml/gradient_boosted_tree_regressor_example.py create mode 100644 examples/src/main/python/ml/random_forest_classifier_example.py create mode 100644 examples/src/main/python/ml/random_forest_regressor_example.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala diff --git a/docs/ml-ensembles.md b/docs/ml-ensembles.md index ce15f5e6466ec..f6c3c30d5334f 100644 --- a/docs/ml-ensembles.md +++ b/docs/ml-ensembles.md @@ -115,194 +115,21 @@ We use two feature transformers to prepare the data; these help index categories Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classification.RandomForestClassifier) for more details. -{% highlight scala %} -import org.apache.spark.ml.Pipeline -import org.apache.spark.ml.classification.RandomForestClassifier -import org.apache.spark.ml.classification.RandomForestClassificationModel -import org.apache.spark.ml.feature.{StringIndexer, IndexToString, VectorIndexer} -import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator - -// Load and parse the data file, converting it to a DataFrame. -val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - -// Index labels, adding metadata to the label column. -// Fit on whole dataset to include all labels in index. -val labelIndexer = new StringIndexer() - .setInputCol("label") - .setOutputCol("indexedLabel") - .fit(data) -// Automatically identify categorical features, and index them. -// Set maxCategories so features with > 4 distinct values are treated as continuous. 
-val featureIndexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexedFeatures") - .setMaxCategories(4) - .fit(data) - -// Split the data into training and test sets (30% held out for testing) -val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) - -// Train a RandomForest model. -val rf = new RandomForestClassifier() - .setLabelCol("indexedLabel") - .setFeaturesCol("indexedFeatures") - .setNumTrees(10) - -// Convert indexed labels back to original labels. -val labelConverter = new IndexToString() - .setInputCol("prediction") - .setOutputCol("predictedLabel") - .setLabels(labelIndexer.labels) - -// Chain indexers and forest in a Pipeline -val pipeline = new Pipeline() - .setStages(Array(labelIndexer, featureIndexer, rf, labelConverter)) - -// Train model. This also runs the indexers. -val model = pipeline.fit(trainingData) - -// Make predictions. -val predictions = model.transform(testData) - -// Select example rows to display. -predictions.select("predictedLabel", "label", "features").show(5) - -// Select (prediction, true label) and compute test error -val evaluator = new MulticlassClassificationEvaluator() - .setLabelCol("indexedLabel") - .setPredictionCol("prediction") - .setMetricName("precision") -val accuracy = evaluator.evaluate(predictions) -println("Test Error = " + (1.0 - accuracy)) - -val rfModel = model.stages(2).asInstanceOf[RandomForestClassificationModel] -println("Learned classification forest model:\n" + rfModel.toDebugString) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala %}

    Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/RandomForestClassifier.html) for more details. -{% highlight java %} -import org.apache.spark.ml.Pipeline; -import org.apache.spark.ml.PipelineModel; -import org.apache.spark.ml.PipelineStage; -import org.apache.spark.ml.classification.RandomForestClassifier; -import org.apache.spark.ml.classification.RandomForestClassificationModel; -import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; -import org.apache.spark.ml.feature.*; -import org.apache.spark.sql.DataFrame; - -// Load and parse the data file, converting it to a DataFrame. -DataFrame data = sqlContext.read().format("libsvm") - .load("data/mllib/sample_libsvm_data.txt"); - -// Index labels, adding metadata to the label column. -// Fit on whole dataset to include all labels in index. -StringIndexerModel labelIndexer = new StringIndexer() - .setInputCol("label") - .setOutputCol("indexedLabel") - .fit(data); -// Automatically identify categorical features, and index them. -// Set maxCategories so features with > 4 distinct values are treated as continuous. -VectorIndexerModel featureIndexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexedFeatures") - .setMaxCategories(4) - .fit(data); - -// Split the data into training and test sets (30% held out for testing) -DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3}); -DataFrame trainingData = splits[0]; -DataFrame testData = splits[1]; - -// Train a RandomForest model. -RandomForestClassifier rf = new RandomForestClassifier() - .setLabelCol("indexedLabel") - .setFeaturesCol("indexedFeatures"); - -// Convert indexed labels back to original labels. -IndexToString labelConverter = new IndexToString() - .setInputCol("prediction") - .setOutputCol("predictedLabel") - .setLabels(labelIndexer.labels()); - -// Chain indexers and forest in a Pipeline -Pipeline pipeline = new Pipeline() - .setStages(new PipelineStage[] {labelIndexer, featureIndexer, rf, labelConverter}); - -// Train model. This also runs the indexers. -PipelineModel model = pipeline.fit(trainingData); - -// Make predictions. -DataFrame predictions = model.transform(testData); - -// Select example rows to display. -predictions.select("predictedLabel", "label", "features").show(5); - -// Select (prediction, true label) and compute test error -MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator() - .setLabelCol("indexedLabel") - .setPredictionCol("prediction") - .setMetricName("precision"); -double accuracy = evaluator.evaluate(predictions); -System.out.println("Test Error = " + (1.0 - accuracy)); - -RandomForestClassificationModel rfModel = - (RandomForestClassificationModel)(model.stages()[2]); -System.out.println("Learned classification forest model:\n" + rfModel.toDebugString()); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaRandomForestClassifierExample.java %}
    Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classification.RandomForestClassifier) for more details. -{% highlight python %} -from pyspark.ml import Pipeline -from pyspark.ml.classification import RandomForestClassifier -from pyspark.ml.feature import StringIndexer, VectorIndexer -from pyspark.ml.evaluation import MulticlassClassificationEvaluator - -# Load and parse the data file, converting it to a DataFrame. -data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - -# Index labels, adding metadata to the label column. -# Fit on whole dataset to include all labels in index. -labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(data) -# Automatically identify categorical features, and index them. -# Set maxCategories so features with > 4 distinct values are treated as continuous. -featureIndexer =\ - VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) - -# Split the data into training and test sets (30% held out for testing) -(trainingData, testData) = data.randomSplit([0.7, 0.3]) - -# Train a RandomForest model. -rf = RandomForestClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures") - -# Chain indexers and forest in a Pipeline -pipeline = Pipeline(stages=[labelIndexer, featureIndexer, rf]) - -# Train model. This also runs the indexers. -model = pipeline.fit(trainingData) - -# Make predictions. -predictions = model.transform(testData) - -# Select example rows to display. -predictions.select("prediction", "indexedLabel", "features").show(5) - -# Select (prediction, true label) and compute test error -evaluator = MulticlassClassificationEvaluator( - labelCol="indexedLabel", predictionCol="prediction", metricName="precision") -accuracy = evaluator.evaluate(predictions) -print "Test Error = %g" % (1.0 - accuracy) - -rfModel = model.stages[2] -print rfModel # summary only -{% endhighlight %} +{% include_example python/ml/random_forest_classifier_example.py %}
    @@ -316,167 +143,21 @@ We use a feature transformer to index categorical features, adding metadata to t Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.regression.RandomForestRegressor) for more details. -{% highlight scala %} -import org.apache.spark.ml.Pipeline -import org.apache.spark.ml.regression.RandomForestRegressor -import org.apache.spark.ml.regression.RandomForestRegressionModel -import org.apache.spark.ml.feature.VectorIndexer -import org.apache.spark.ml.evaluation.RegressionEvaluator - -// Load and parse the data file, converting it to a DataFrame. -val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - -// Automatically identify categorical features, and index them. -// Set maxCategories so features with > 4 distinct values are treated as continuous. -val featureIndexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexedFeatures") - .setMaxCategories(4) - .fit(data) - -// Split the data into training and test sets (30% held out for testing) -val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) - -// Train a RandomForest model. -val rf = new RandomForestRegressor() - .setLabelCol("label") - .setFeaturesCol("indexedFeatures") - -// Chain indexer and forest in a Pipeline -val pipeline = new Pipeline() - .setStages(Array(featureIndexer, rf)) - -// Train model. This also runs the indexer. -val model = pipeline.fit(trainingData) - -// Make predictions. -val predictions = model.transform(testData) - -// Select example rows to display. -predictions.select("prediction", "label", "features").show(5) - -// Select (prediction, true label) and compute test error -val evaluator = new RegressionEvaluator() - .setLabelCol("label") - .setPredictionCol("prediction") - .setMetricName("rmse") -val rmse = evaluator.evaluate(predictions) -println("Root Mean Squared Error (RMSE) on test data = " + rmse) - -val rfModel = model.stages(1).asInstanceOf[RandomForestRegressionModel] -println("Learned regression forest model:\n" + rfModel.toDebugString) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala %}
    Refer to the [Java API docs](api/java/org/apache/spark/ml/regression/RandomForestRegressor.html) for more details. -{% highlight java %} -import org.apache.spark.ml.Pipeline; -import org.apache.spark.ml.PipelineModel; -import org.apache.spark.ml.PipelineStage; -import org.apache.spark.ml.evaluation.RegressionEvaluator; -import org.apache.spark.ml.feature.VectorIndexer; -import org.apache.spark.ml.feature.VectorIndexerModel; -import org.apache.spark.ml.regression.RandomForestRegressionModel; -import org.apache.spark.ml.regression.RandomForestRegressor; -import org.apache.spark.sql.DataFrame; - -// Load and parse the data file, converting it to a DataFrame. -DataFrame data = sqlContext.read().format("libsvm") - .load("data/mllib/sample_libsvm_data.txt"); - -// Automatically identify categorical features, and index them. -// Set maxCategories so features with > 4 distinct values are treated as continuous. -VectorIndexerModel featureIndexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexedFeatures") - .setMaxCategories(4) - .fit(data); - -// Split the data into training and test sets (30% held out for testing) -DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3}); -DataFrame trainingData = splits[0]; -DataFrame testData = splits[1]; - -// Train a RandomForest model. -RandomForestRegressor rf = new RandomForestRegressor() - .setLabelCol("label") - .setFeaturesCol("indexedFeatures"); - -// Chain indexer and forest in a Pipeline -Pipeline pipeline = new Pipeline() - .setStages(new PipelineStage[] {featureIndexer, rf}); - -// Train model. This also runs the indexer. -PipelineModel model = pipeline.fit(trainingData); - -// Make predictions. -DataFrame predictions = model.transform(testData); - -// Select example rows to display. -predictions.select("prediction", "label", "features").show(5); - -// Select (prediction, true label) and compute test error -RegressionEvaluator evaluator = new RegressionEvaluator() - .setLabelCol("label") - .setPredictionCol("prediction") - .setMetricName("rmse"); -double rmse = evaluator.evaluate(predictions); -System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse); - -RandomForestRegressionModel rfModel = - (RandomForestRegressionModel)(model.stages()[1]); -System.out.println("Learned regression forest model:\n" + rfModel.toDebugString()); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaRandomForestRegressorExample.java %}
    Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.regression.RandomForestRegressor) for more details. -{% highlight python %} -from pyspark.ml import Pipeline -from pyspark.ml.regression import RandomForestRegressor -from pyspark.ml.feature import VectorIndexer -from pyspark.ml.evaluation import RegressionEvaluator - -# Load and parse the data file, converting it to a DataFrame. -data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - -# Automatically identify categorical features, and index them. -# Set maxCategories so features with > 4 distinct values are treated as continuous. -featureIndexer =\ - VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) - -# Split the data into training and test sets (30% held out for testing) -(trainingData, testData) = data.randomSplit([0.7, 0.3]) - -# Train a RandomForest model. -rf = RandomForestRegressor(featuresCol="indexedFeatures") - -# Chain indexer and forest in a Pipeline -pipeline = Pipeline(stages=[featureIndexer, rf]) - -# Train model. This also runs the indexer. -model = pipeline.fit(trainingData) - -# Make predictions. -predictions = model.transform(testData) - -# Select example rows to display. -predictions.select("prediction", "label", "features").show(5) - -# Select (prediction, true label) and compute test error -evaluator = RegressionEvaluator( - labelCol="label", predictionCol="prediction", metricName="rmse") -rmse = evaluator.evaluate(predictions) -print "Root Mean Squared Error (RMSE) on test data = %g" % rmse - -rfModel = model.stages[1] -print rfModel # summary only -{% endhighlight %} +{% include_example python/ml/random_forest_regressor_example.py %}
    @@ -560,194 +241,21 @@ We use two feature transformers to prepare the data; these help index categories Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classification.GBTClassifier) for more details. -{% highlight scala %} -import org.apache.spark.ml.Pipeline -import org.apache.spark.ml.classification.GBTClassifier -import org.apache.spark.ml.classification.GBTClassificationModel -import org.apache.spark.ml.feature.{StringIndexer, IndexToString, VectorIndexer} -import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator - -// Load and parse the data file, converting it to a DataFrame. -val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - -// Index labels, adding metadata to the label column. -// Fit on whole dataset to include all labels in index. -val labelIndexer = new StringIndexer() - .setInputCol("label") - .setOutputCol("indexedLabel") - .fit(data) -// Automatically identify categorical features, and index them. -// Set maxCategories so features with > 4 distinct values are treated as continuous. -val featureIndexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexedFeatures") - .setMaxCategories(4) - .fit(data) - -// Split the data into training and test sets (30% held out for testing) -val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) - -// Train a GBT model. -val gbt = new GBTClassifier() - .setLabelCol("indexedLabel") - .setFeaturesCol("indexedFeatures") - .setMaxIter(10) - -// Convert indexed labels back to original labels. -val labelConverter = new IndexToString() - .setInputCol("prediction") - .setOutputCol("predictedLabel") - .setLabels(labelIndexer.labels) - -// Chain indexers and GBT in a Pipeline -val pipeline = new Pipeline() - .setStages(Array(labelIndexer, featureIndexer, gbt, labelConverter)) - -// Train model. This also runs the indexers. -val model = pipeline.fit(trainingData) - -// Make predictions. -val predictions = model.transform(testData) - -// Select example rows to display. -predictions.select("predictedLabel", "label", "features").show(5) - -// Select (prediction, true label) and compute test error -val evaluator = new MulticlassClassificationEvaluator() - .setLabelCol("indexedLabel") - .setPredictionCol("prediction") - .setMetricName("precision") -val accuracy = evaluator.evaluate(predictions) -println("Test Error = " + (1.0 - accuracy)) - -val gbtModel = model.stages(2).asInstanceOf[GBTClassificationModel] -println("Learned classification GBT model:\n" + gbtModel.toDebugString) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala %}
    Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/GBTClassifier.html) for more details. -{% highlight java %} -import org.apache.spark.ml.Pipeline; -import org.apache.spark.ml.PipelineModel; -import org.apache.spark.ml.PipelineStage; -import org.apache.spark.ml.classification.GBTClassifier; -import org.apache.spark.ml.classification.GBTClassificationModel; -import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; -import org.apache.spark.ml.feature.*; -import org.apache.spark.sql.DataFrame; - -// Load and parse the data file, converting it to a DataFrame. -DataFrame data sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); - -// Index labels, adding metadata to the label column. -// Fit on whole dataset to include all labels in index. -StringIndexerModel labelIndexer = new StringIndexer() - .setInputCol("label") - .setOutputCol("indexedLabel") - .fit(data); -// Automatically identify categorical features, and index them. -// Set maxCategories so features with > 4 distinct values are treated as continuous. -VectorIndexerModel featureIndexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexedFeatures") - .setMaxCategories(4) - .fit(data); - -// Split the data into training and test sets (30% held out for testing) -DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3}); -DataFrame trainingData = splits[0]; -DataFrame testData = splits[1]; - -// Train a GBT model. -GBTClassifier gbt = new GBTClassifier() - .setLabelCol("indexedLabel") - .setFeaturesCol("indexedFeatures") - .setMaxIter(10); - -// Convert indexed labels back to original labels. -IndexToString labelConverter = new IndexToString() - .setInputCol("prediction") - .setOutputCol("predictedLabel") - .setLabels(labelIndexer.labels()); - -// Chain indexers and GBT in a Pipeline -Pipeline pipeline = new Pipeline() - .setStages(new PipelineStage[] {labelIndexer, featureIndexer, gbt, labelConverter}); - -// Train model. This also runs the indexers. -PipelineModel model = pipeline.fit(trainingData); - -// Make predictions. -DataFrame predictions = model.transform(testData); - -// Select example rows to display. -predictions.select("predictedLabel", "label", "features").show(5); - -// Select (prediction, true label) and compute test error -MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator() - .setLabelCol("indexedLabel") - .setPredictionCol("prediction") - .setMetricName("precision"); -double accuracy = evaluator.evaluate(predictions); -System.out.println("Test Error = " + (1.0 - accuracy)); - -GBTClassificationModel gbtModel = - (GBTClassificationModel)(model.stages()[2]); -System.out.println("Learned classification GBT model:\n" + gbtModel.toDebugString()); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaGradientBoostedTreeClassifierExample.java %}
    Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classification.GBTClassifier) for more details. -{% highlight python %} -from pyspark.ml import Pipeline -from pyspark.ml.classification import GBTClassifier -from pyspark.ml.feature import StringIndexer, VectorIndexer -from pyspark.ml.evaluation import MulticlassClassificationEvaluator - -# Load and parse the data file, converting it to a DataFrame. -data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - -# Index labels, adding metadata to the label column. -# Fit on whole dataset to include all labels in index. -labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(data) -# Automatically identify categorical features, and index them. -# Set maxCategories so features with > 4 distinct values are treated as continuous. -featureIndexer =\ - VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) - -# Split the data into training and test sets (30% held out for testing) -(trainingData, testData) = data.randomSplit([0.7, 0.3]) - -# Train a GBT model. -gbt = GBTClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures", maxIter=10) - -# Chain indexers and GBT in a Pipeline -pipeline = Pipeline(stages=[labelIndexer, featureIndexer, gbt]) - -# Train model. This also runs the indexers. -model = pipeline.fit(trainingData) - -# Make predictions. -predictions = model.transform(testData) - -# Select example rows to display. -predictions.select("prediction", "indexedLabel", "features").show(5) - -# Select (prediction, true label) and compute test error -evaluator = MulticlassClassificationEvaluator( - labelCol="indexedLabel", predictionCol="prediction", metricName="precision") -accuracy = evaluator.evaluate(predictions) -print "Test Error = %g" % (1.0 - accuracy) - -gbtModel = model.stages[2] -print gbtModel # summary only -{% endhighlight %} +{% include_example python/ml/gradient_boosted_tree_classifier_example.py %}
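If parameter tuning is needed, the GBT stage of the pipeline above can be wrapped in a cross-validation loop. This is only a hedged sketch: `pipeline`, `gbt`, `evaluator`, and `trainingData` are assumed to be the values defined in the Scala example above (they are not redefined here), and the grid values are illustrative.

{% highlight scala %}
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

// `pipeline`, `gbt`, `evaluator` and `trainingData` are assumed to come from
// the gradient-boosted tree classifier example above.
val paramGrid = new ParamGridBuilder()
  .addGrid(gbt.maxDepth, Array(3, 5))
  .addGrid(gbt.maxIter, Array(10, 20))
  .build()

val cv = new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(evaluator)  // the MulticlassClassificationEvaluator from the example
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(3)

// Selects the parameter combination with the best cross-validated metric.
val cvModel = cv.fit(trainingData)
{% endhighlight %}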
    @@ -761,168 +269,21 @@ be true in general. Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.regression.GBTRegressor) for more details. -{% highlight scala %} -import org.apache.spark.ml.Pipeline -import org.apache.spark.ml.regression.GBTRegressor -import org.apache.spark.ml.regression.GBTRegressionModel -import org.apache.spark.ml.feature.VectorIndexer -import org.apache.spark.ml.evaluation.RegressionEvaluator - -// Load and parse the data file, converting it to a DataFrame. -val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - -// Automatically identify categorical features, and index them. -// Set maxCategories so features with > 4 distinct values are treated as continuous. -val featureIndexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexedFeatures") - .setMaxCategories(4) - .fit(data) - -// Split the data into training and test sets (30% held out for testing) -val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) - -// Train a GBT model. -val gbt = new GBTRegressor() - .setLabelCol("label") - .setFeaturesCol("indexedFeatures") - .setMaxIter(10) - -// Chain indexer and GBT in a Pipeline -val pipeline = new Pipeline() - .setStages(Array(featureIndexer, gbt)) - -// Train model. This also runs the indexer. -val model = pipeline.fit(trainingData) - -// Make predictions. -val predictions = model.transform(testData) - -// Select example rows to display. -predictions.select("prediction", "label", "features").show(5) - -// Select (prediction, true label) and compute test error -val evaluator = new RegressionEvaluator() - .setLabelCol("label") - .setPredictionCol("prediction") - .setMetricName("rmse") -val rmse = evaluator.evaluate(predictions) -println("Root Mean Squared Error (RMSE) on test data = " + rmse) - -val gbtModel = model.stages(1).asInstanceOf[GBTRegressionModel] -println("Learned regression GBT model:\n" + gbtModel.toDebugString) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala %}
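Beyond `toDebugString`, the fitted ensemble exposes its individual trees and their weights. A minimal sketch, assuming `model` is the `PipelineModel` fitted in the example above:

{% highlight scala %}
import org.apache.spark.ml.regression.GBTRegressionModel

// Assumes `model` is the PipelineModel produced by the regression example above.
val gbtModel = model.stages(1).asInstanceOf[GBTRegressionModel]
// One regression tree (plus a weight) is added per boosting iteration.
println("Trees in ensemble: " + gbtModel.trees.length)
println("Tree weights: " + gbtModel.treeWeights.mkString(", "))
{% endhighlight %}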
    Refer to the [Java API docs](api/java/org/apache/spark/ml/regression/GBTRegressor.html) for more details. -{% highlight java %} -import org.apache.spark.ml.Pipeline; -import org.apache.spark.ml.PipelineModel; -import org.apache.spark.ml.PipelineStage; -import org.apache.spark.ml.evaluation.RegressionEvaluator; -import org.apache.spark.ml.feature.VectorIndexer; -import org.apache.spark.ml.feature.VectorIndexerModel; -import org.apache.spark.ml.regression.GBTRegressionModel; -import org.apache.spark.ml.regression.GBTRegressor; -import org.apache.spark.sql.DataFrame; - -// Load and parse the data file, converting it to a DataFrame. -DataFrame data = sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); - -// Automatically identify categorical features, and index them. -// Set maxCategories so features with > 4 distinct values are treated as continuous. -VectorIndexerModel featureIndexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexedFeatures") - .setMaxCategories(4) - .fit(data); - -// Split the data into training and test sets (30% held out for testing) -DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3}); -DataFrame trainingData = splits[0]; -DataFrame testData = splits[1]; - -// Train a GBT model. -GBTRegressor gbt = new GBTRegressor() - .setLabelCol("label") - .setFeaturesCol("indexedFeatures") - .setMaxIter(10); - -// Chain indexer and GBT in a Pipeline -Pipeline pipeline = new Pipeline() - .setStages(new PipelineStage[] {featureIndexer, gbt}); - -// Train model. This also runs the indexer. -PipelineModel model = pipeline.fit(trainingData); - -// Make predictions. -DataFrame predictions = model.transform(testData); - -// Select example rows to display. -predictions.select("prediction", "label", "features").show(5); - -// Select (prediction, true label) and compute test error -RegressionEvaluator evaluator = new RegressionEvaluator() - .setLabelCol("label") - .setPredictionCol("prediction") - .setMetricName("rmse"); -double rmse = evaluator.evaluate(predictions); -System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse); - -GBTRegressionModel gbtModel = - (GBTRegressionModel)(model.stages()[1]); -System.out.println("Learned regression GBT model:\n" + gbtModel.toDebugString()); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaGradientBoostedTreeRegressorExample.java %}
    Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.regression.GBTRegressor) for more details. -{% highlight python %} -from pyspark.ml import Pipeline -from pyspark.ml.regression import GBTRegressor -from pyspark.ml.feature import VectorIndexer -from pyspark.ml.evaluation import RegressionEvaluator - -# Load and parse the data file, converting it to a DataFrame. -data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - -# Automatically identify categorical features, and index them. -# Set maxCategories so features with > 4 distinct values are treated as continuous. -featureIndexer =\ - VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) - -# Split the data into training and test sets (30% held out for testing) -(trainingData, testData) = data.randomSplit([0.7, 0.3]) - -# Train a GBT model. -gbt = GBTRegressor(featuresCol="indexedFeatures", maxIter=10) - -# Chain indexer and GBT in a Pipeline -pipeline = Pipeline(stages=[featureIndexer, gbt]) - -# Train model. This also runs the indexer. -model = pipeline.fit(trainingData) - -# Make predictions. -predictions = model.transform(testData) - -# Select example rows to display. -predictions.select("prediction", "label", "features").show(5) - -# Select (prediction, true label) and compute test error -evaluator = RegressionEvaluator( - labelCol="label", predictionCol="prediction", metricName="rmse") -rmse = evaluator.evaluate(predictions) -print "Root Mean Squared Error (RMSE) on test data = %g" % rmse - -gbtModel = model.stages[1] -print gbtModel # summary only -{% endhighlight %} +{% include_example python/ml/gradient_boosted_tree_regressor_example.py %}
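`RegressionEvaluator` is not limited to RMSE; the same predictions can be scored with the other built-in metrics. A small sketch, assuming `predictions` is the test-set DataFrame produced in the examples above:

{% highlight scala %}
import org.apache.spark.ml.evaluation.RegressionEvaluator

// Assumes `predictions` comes from the gradient-boosted tree regression example above.
val mae = new RegressionEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction")
  .setMetricName("mae")
  .evaluate(predictions)
val r2 = new RegressionEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction")
  .setMetricName("r2")
  .evaluate(predictions)
println("MAE on test data = " + mae)
println("R^2 on test data = " + r2)
{% endhighlight %}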
    @@ -945,100 +306,13 @@ The example below demonstrates how to load the Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classifier.OneVsRest) for more details. -{% highlight scala %} -import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest} -import org.apache.spark.mllib.evaluation.MulticlassMetrics -import org.apache.spark.sql.{Row, SQLContext} - -val sqlContext = new SQLContext(sc) - -// parse data into dataframe -val data = sqlContext.read.format("libsvm") - .load("data/mllib/sample_multiclass_classification_data.txt") -val Array(train, test) = data.randomSplit(Array(0.7, 0.3)) - -// instantiate multiclass learner and train -val ovr = new OneVsRest().setClassifier(new LogisticRegression) - -val ovrModel = ovr.fit(train) - -// score model on test data -val predictions = ovrModel.transform(test).select("prediction", "label") -val predictionsAndLabels = predictions.map {case Row(p: Double, l: Double) => (p, l)} - -// compute confusion matrix -val metrics = new MulticlassMetrics(predictionsAndLabels) -println(metrics.confusionMatrix) - -// the Iris DataSet has three classes -val numClasses = 3 - -println("label\tfpr\n") -(0 until numClasses).foreach { index => - val label = index.toDouble - println(label + "\t" + metrics.falsePositiveRate(label)) -} -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/OneVsRestExample.scala %}
    Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/OneVsRest.html) for more details. -{% highlight java %} -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.ml.classification.LogisticRegression; -import org.apache.spark.ml.classification.OneVsRest; -import org.apache.spark.ml.classification.OneVsRestModel; -import org.apache.spark.mllib.evaluation.MulticlassMetrics; -import org.apache.spark.mllib.linalg.Matrix; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.SQLContext; - -SparkConf conf = new SparkConf().setAppName("JavaOneVsRestExample"); -JavaSparkContext jsc = new JavaSparkContext(conf); -SQLContext jsql = new SQLContext(jsc); - -DataFrame dataFrame = sqlContext.read().format("libsvm") - .load("data/mllib/sample_multiclass_classification_data.txt"); - -DataFrame[] splits = dataFrame.randomSplit(new double[] {0.7, 0.3}, 12345); -DataFrame train = splits[0]; -DataFrame test = splits[1]; - -// instantiate the One Vs Rest Classifier -OneVsRest ovr = new OneVsRest().setClassifier(new LogisticRegression()); - -// train the multiclass model -OneVsRestModel ovrModel = ovr.fit(train.cache()); - -// score the model on test data -DataFrame predictions = ovrModel - .transform(test) - .select("prediction", "label"); - -// obtain metrics -MulticlassMetrics metrics = new MulticlassMetrics(predictions); -Matrix confusionMatrix = metrics.confusionMatrix(); - -// output the Confusion Matrix -System.out.println("Confusion Matrix"); -System.out.println(confusionMatrix); - -// compute the false positive rate per label -System.out.println(); -System.out.println("label\tfpr\n"); - -// the Iris DataSet has three classes -int numClasses = 3; -for (int index = 0; index < numClasses; index++) { - double label = (double) index; - System.out.print(label); - System.out.print("\t"); - System.out.print(metrics.falsePositiveRate(label)); - System.out.println(); -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaOneVsRestExample.java %}
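`OneVsRest` accepts any binary `spark.ml` classifier as its base learner, so the reduction can be reconfigured without touching the rest of the code. A hedged sketch, assuming the `train` DataFrame loaded in the example above; the iteration count and regularization value are illustrative only:

{% highlight scala %}
import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest}

// Assumes `train` is the training split loaded in the example above.
val baseClassifier = new LogisticRegression()
  .setMaxIter(10)
  .setRegParam(0.01)

// One binary model is trained per class; prediction picks the most confident one.
val ovr = new OneVsRest().setClassifier(baseClassifier)
val ovrModel = ovr.fit(train)
{% endhighlight %}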
    diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeClassifierExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeClassifierExample.java new file mode 100644 index 0000000000000..848fe6566c1ec --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeClassifierExample.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +// $example on$ +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineModel; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.classification.GBTClassificationModel; +import org.apache.spark.ml.classification.GBTClassifier; +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; +import org.apache.spark.ml.feature.*; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; +// $example off$ + +public class JavaGradientBoostedTreeClassifierExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaGradientBoostedTreeClassifierExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + // Load and parse the data file, converting it to a DataFrame. + DataFrame data = sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + + // Index labels, adding metadata to the label column. + // Fit on whole dataset to include all labels in index. + StringIndexerModel labelIndexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("indexedLabel") + .fit(data); + // Automatically identify categorical features, and index them. + // Set maxCategories so features with > 4 distinct values are treated as continuous. + VectorIndexerModel featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data); + + // Split the data into training and test sets (30% held out for testing) + DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3}); + DataFrame trainingData = splits[0]; + DataFrame testData = splits[1]; + + // Train a GBT model. + GBTClassifier gbt = new GBTClassifier() + .setLabelCol("indexedLabel") + .setFeaturesCol("indexedFeatures") + .setMaxIter(10); + + // Convert indexed labels back to original labels. 
+ IndexToString labelConverter = new IndexToString() + .setInputCol("prediction") + .setOutputCol("predictedLabel") + .setLabels(labelIndexer.labels()); + + // Chain indexers and GBT in a Pipeline + Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[] {labelIndexer, featureIndexer, gbt, labelConverter}); + + // Train model. This also runs the indexers. + PipelineModel model = pipeline.fit(trainingData); + + // Make predictions. + DataFrame predictions = model.transform(testData); + + // Select example rows to display. + predictions.select("predictedLabel", "label", "features").show(5); + + // Select (prediction, true label) and compute test error + MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator() + .setLabelCol("indexedLabel") + .setPredictionCol("prediction") + .setMetricName("precision"); + double accuracy = evaluator.evaluate(predictions); + System.out.println("Test Error = " + (1.0 - accuracy)); + + GBTClassificationModel gbtModel = (GBTClassificationModel)(model.stages()[2]); + System.out.println("Learned classification GBT model:\n" + gbtModel.toDebugString()); + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeRegressorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeRegressorExample.java new file mode 100644 index 0000000000000..1f67b0842db0d --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeRegressorExample.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +// $example on$ +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineModel; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.evaluation.RegressionEvaluator; +import org.apache.spark.ml.feature.VectorIndexer; +import org.apache.spark.ml.feature.VectorIndexerModel; +import org.apache.spark.ml.regression.GBTRegressionModel; +import org.apache.spark.ml.regression.GBTRegressor; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; +// $example off$ + +public class JavaGradientBoostedTreeRegressorExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaGradientBoostedTreeRegressorExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + // Load and parse the data file, converting it to a DataFrame. 
+ DataFrame data = sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + + // Automatically identify categorical features, and index them. + // Set maxCategories so features with > 4 distinct values are treated as continuous. + VectorIndexerModel featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data); + + // Split the data into training and test sets (30% held out for testing) + DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3}); + DataFrame trainingData = splits[0]; + DataFrame testData = splits[1]; + + // Train a GBT model. + GBTRegressor gbt = new GBTRegressor() + .setLabelCol("label") + .setFeaturesCol("indexedFeatures") + .setMaxIter(10); + + // Chain indexer and GBT in a Pipeline + Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] {featureIndexer, gbt}); + + // Train model. This also runs the indexer. + PipelineModel model = pipeline.fit(trainingData); + + // Make predictions. + DataFrame predictions = model.transform(testData); + + // Select example rows to display. + predictions.select("prediction", "label", "features").show(5); + + // Select (prediction, true label) and compute test error + RegressionEvaluator evaluator = new RegressionEvaluator() + .setLabelCol("label") + .setPredictionCol("prediction") + .setMetricName("rmse"); + double rmse = evaluator.evaluate(predictions); + System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse); + + GBTRegressionModel gbtModel = (GBTRegressionModel)(model.stages()[1]); + System.out.println("Learned regression GBT model:\n" + gbtModel.toDebugString()); + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java index f0d92a56bee73..42374e77acf01 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java @@ -21,6 +21,7 @@ import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaSparkContext; +// $example on$ import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.classification.OneVsRest; import org.apache.spark.ml.classification.OneVsRestModel; @@ -31,6 +32,7 @@ import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.types.StructField; +// $example off$ /** * An example runner for Multiclass to Binary Reduction with One Vs Rest. 
@@ -61,6 +63,7 @@ public static void main(String[] args) { JavaSparkContext jsc = new JavaSparkContext(conf); SQLContext jsql = new SQLContext(jsc); + // $example on$ // configure the base classifier LogisticRegression classifier = new LogisticRegression() .setMaxIter(params.maxIter) @@ -125,6 +128,7 @@ public static void main(String[] args) { System.out.println(confusionMatrix); System.out.println(); System.out.println(results); + // $example off$ jsc.stop(); } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaRandomForestClassifierExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaRandomForestClassifierExample.java new file mode 100644 index 0000000000000..5a62496660290 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaRandomForestClassifierExample.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +// $example on$ +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineModel; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.classification.RandomForestClassificationModel; +import org.apache.spark.ml.classification.RandomForestClassifier; +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; +import org.apache.spark.ml.feature.*; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; +// $example off$ + +public class JavaRandomForestClassifierExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaRandomForestClassifierExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + // Load and parse the data file, converting it to a DataFrame. + DataFrame data = sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + + // Index labels, adding metadata to the label column. + // Fit on whole dataset to include all labels in index. + StringIndexerModel labelIndexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("indexedLabel") + .fit(data); + // Automatically identify categorical features, and index them. + // Set maxCategories so features with > 4 distinct values are treated as continuous. + VectorIndexerModel featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data); + + // Split the data into training and test sets (30% held out for testing) + DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3}); + DataFrame trainingData = splits[0]; + DataFrame testData = splits[1]; + + // Train a RandomForest model. 
+ RandomForestClassifier rf = new RandomForestClassifier() + .setLabelCol("indexedLabel") + .setFeaturesCol("indexedFeatures"); + + // Convert indexed labels back to original labels. + IndexToString labelConverter = new IndexToString() + .setInputCol("prediction") + .setOutputCol("predictedLabel") + .setLabels(labelIndexer.labels()); + + // Chain indexers and forest in a Pipeline + Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[] {labelIndexer, featureIndexer, rf, labelConverter}); + + // Train model. This also runs the indexers. + PipelineModel model = pipeline.fit(trainingData); + + // Make predictions. + DataFrame predictions = model.transform(testData); + + // Select example rows to display. + predictions.select("predictedLabel", "label", "features").show(5); + + // Select (prediction, true label) and compute test error + MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator() + .setLabelCol("indexedLabel") + .setPredictionCol("prediction") + .setMetricName("precision"); + double accuracy = evaluator.evaluate(predictions); + System.out.println("Test Error = " + (1.0 - accuracy)); + + RandomForestClassificationModel rfModel = (RandomForestClassificationModel)(model.stages()[2]); + System.out.println("Learned classification forest model:\n" + rfModel.toDebugString()); + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaRandomForestRegressorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaRandomForestRegressorExample.java new file mode 100644 index 0000000000000..05782a0724a77 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaRandomForestRegressorExample.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +// $example on$ +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineModel; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.evaluation.RegressionEvaluator; +import org.apache.spark.ml.feature.VectorIndexer; +import org.apache.spark.ml.feature.VectorIndexerModel; +import org.apache.spark.ml.regression.RandomForestRegressionModel; +import org.apache.spark.ml.regression.RandomForestRegressor; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; +// $example off$ + +public class JavaRandomForestRegressorExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaRandomForestRegressorExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + // Load and parse the data file, converting it to a DataFrame. + DataFrame data = sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + + // Automatically identify categorical features, and index them. + // Set maxCategories so features with > 4 distinct values are treated as continuous. + VectorIndexerModel featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data); + + // Split the data into training and test sets (30% held out for testing) + DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3}); + DataFrame trainingData = splits[0]; + DataFrame testData = splits[1]; + + // Train a RandomForest model. + RandomForestRegressor rf = new RandomForestRegressor() + .setLabelCol("label") + .setFeaturesCol("indexedFeatures"); + + // Chain indexer and forest in a Pipeline + Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[] {featureIndexer, rf}); + + // Train model. This also runs the indexer. + PipelineModel model = pipeline.fit(trainingData); + + // Make predictions. + DataFrame predictions = model.transform(testData); + + // Select example rows to display. + predictions.select("prediction", "label", "features").show(5); + + // Select (prediction, true label) and compute test error + RegressionEvaluator evaluator = new RegressionEvaluator() + .setLabelCol("label") + .setPredictionCol("prediction") + .setMetricName("rmse"); + double rmse = evaluator.evaluate(predictions); + System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse); + + RandomForestRegressionModel rfModel = (RandomForestRegressionModel)(model.stages()[1]); + System.out.println("Learned regression forest model:\n" + rfModel.toDebugString()); + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/python/ml/gradient_boosted_tree_classifier_example.py b/examples/src/main/python/ml/gradient_boosted_tree_classifier_example.py new file mode 100644 index 0000000000000..028497651fbf9 --- /dev/null +++ b/examples/src/main/python/ml/gradient_boosted_tree_classifier_example.py @@ -0,0 +1,77 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Gradient Boosted Tree Classifier Example. +""" +from __future__ import print_function + +import sys + +from pyspark import SparkContext, SQLContext +# $example on$ +from pyspark.ml import Pipeline +from pyspark.ml.classification import GBTClassifier +from pyspark.ml.feature import StringIndexer, VectorIndexer +from pyspark.ml.evaluation import MulticlassClassificationEvaluator +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="gradient_boosted_tree_classifier_example") + sqlContext = SQLContext(sc) + + # $example on$ + # Load and parse the data file, converting it to a DataFrame. + data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + # Index labels, adding metadata to the label column. + # Fit on whole dataset to include all labels in index. + labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(data) + # Automatically identify categorical features, and index them. + # Set maxCategories so features with > 4 distinct values are treated as continuous. + featureIndexer =\ + VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) + + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) + + # Train a GBT model. + gbt = GBTClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures", maxIter=10) + + # Chain indexers and GBT in a Pipeline + pipeline = Pipeline(stages=[labelIndexer, featureIndexer, gbt]) + + # Train model. This also runs the indexers. + model = pipeline.fit(trainingData) + + # Make predictions. + predictions = model.transform(testData) + + # Select example rows to display. + predictions.select("prediction", "indexedLabel", "features").show(5) + + # Select (prediction, true label) and compute test error + evaluator = MulticlassClassificationEvaluator( + labelCol="indexedLabel", predictionCol="prediction", metricName="precision") + accuracy = evaluator.evaluate(predictions) + print("Test Error = %g" % (1.0 - accuracy)) + + gbtModel = model.stages[2] + print(gbtModel) # summary only + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/gradient_boosted_tree_regressor_example.py b/examples/src/main/python/ml/gradient_boosted_tree_regressor_example.py new file mode 100644 index 0000000000000..4246e133a9030 --- /dev/null +++ b/examples/src/main/python/ml/gradient_boosted_tree_regressor_example.py @@ -0,0 +1,74 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Gradient Boosted Tree Regressor Example. +""" +from __future__ import print_function + +import sys + +from pyspark import SparkContext, SQLContext +# $example on$ +from pyspark.ml import Pipeline +from pyspark.ml.regression import GBTRegressor +from pyspark.ml.feature import VectorIndexer +from pyspark.ml.evaluation import RegressionEvaluator +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="gradient_boosted_tree_regressor_example") + sqlContext = SQLContext(sc) + + # $example on$ + # Load and parse the data file, converting it to a DataFrame. + data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + # Automatically identify categorical features, and index them. + # Set maxCategories so features with > 4 distinct values are treated as continuous. + featureIndexer =\ + VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) + + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) + + # Train a GBT model. + gbt = GBTRegressor(featuresCol="indexedFeatures", maxIter=10) + + # Chain indexer and GBT in a Pipeline + pipeline = Pipeline(stages=[featureIndexer, gbt]) + + # Train model. This also runs the indexer. + model = pipeline.fit(trainingData) + + # Make predictions. + predictions = model.transform(testData) + + # Select example rows to display. + predictions.select("prediction", "label", "features").show(5) + + # Select (prediction, true label) and compute test error + evaluator = RegressionEvaluator( + labelCol="label", predictionCol="prediction", metricName="rmse") + rmse = evaluator.evaluate(predictions) + print("Root Mean Squared Error (RMSE) on test data = %g" % rmse) + + gbtModel = model.stages[1] + print(gbtModel) # summary only + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/random_forest_classifier_example.py b/examples/src/main/python/ml/random_forest_classifier_example.py new file mode 100644 index 0000000000000..b3530d4f41c8e --- /dev/null +++ b/examples/src/main/python/ml/random_forest_classifier_example.py @@ -0,0 +1,77 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Random Forest Classifier Example. 
+""" +from __future__ import print_function + +import sys + +from pyspark import SparkContext, SQLContext +# $example on$ +from pyspark.ml import Pipeline +from pyspark.ml.classification import RandomForestClassifier +from pyspark.ml.feature import StringIndexer, VectorIndexer +from pyspark.ml.evaluation import MulticlassClassificationEvaluator +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="random_forest_classifier_example") + sqlContext = SQLContext(sc) + + # $example on$ + # Load and parse the data file, converting it to a DataFrame. + data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + # Index labels, adding metadata to the label column. + # Fit on whole dataset to include all labels in index. + labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(data) + # Automatically identify categorical features, and index them. + # Set maxCategories so features with > 4 distinct values are treated as continuous. + featureIndexer =\ + VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) + + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) + + # Train a RandomForest model. + rf = RandomForestClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures") + + # Chain indexers and forest in a Pipeline + pipeline = Pipeline(stages=[labelIndexer, featureIndexer, rf]) + + # Train model. This also runs the indexers. + model = pipeline.fit(trainingData) + + # Make predictions. + predictions = model.transform(testData) + + # Select example rows to display. + predictions.select("prediction", "indexedLabel", "features").show(5) + + # Select (prediction, true label) and compute test error + evaluator = MulticlassClassificationEvaluator( + labelCol="indexedLabel", predictionCol="prediction", metricName="precision") + accuracy = evaluator.evaluate(predictions) + print("Test Error = %g" % (1.0 - accuracy)) + + rfModel = model.stages[2] + print(rfModel) # summary only + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/random_forest_regressor_example.py b/examples/src/main/python/ml/random_forest_regressor_example.py new file mode 100644 index 0000000000000..b59c7c9414841 --- /dev/null +++ b/examples/src/main/python/ml/random_forest_regressor_example.py @@ -0,0 +1,74 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Random Forest Regressor Example. 
+""" +from __future__ import print_function + +import sys + +from pyspark import SparkContext, SQLContext +# $example on$ +from pyspark.ml import Pipeline +from pyspark.ml.regression import RandomForestRegressor +from pyspark.ml.feature import VectorIndexer +from pyspark.ml.evaluation import RegressionEvaluator +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="random_forest_regressor_example") + sqlContext = SQLContext(sc) + + # $example on$ + # Load and parse the data file, converting it to a DataFrame. + data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + # Automatically identify categorical features, and index them. + # Set maxCategories so features with > 4 distinct values are treated as continuous. + featureIndexer =\ + VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) + + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) + + # Train a RandomForest model. + rf = RandomForestRegressor(featuresCol="indexedFeatures") + + # Chain indexer and forest in a Pipeline + pipeline = Pipeline(stages=[featureIndexer, rf]) + + # Train model. This also runs the indexer. + model = pipeline.fit(trainingData) + + # Make predictions. + predictions = model.transform(testData) + + # Select example rows to display. + predictions.select("prediction", "label", "features").show(5) + + # Select (prediction, true label) and compute test error + evaluator = RegressionEvaluator( + labelCol="label", predictionCol="prediction", metricName="rmse") + rmse = evaluator.evaluate(predictions) + print("Root Mean Squared Error (RMSE) on test data = %g" % rmse) + + rfModel = model.stages[1] + print(rfModel) # summary only + # $example off$ + + sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala new file mode 100644 index 0000000000000..474af7db4b49b --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier} +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator +import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer} +// $example off$ + +object GradientBoostedTreeClassifierExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("GradientBoostedTreeClassifierExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Load and parse the data file, converting it to a DataFrame. + val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + // Index labels, adding metadata to the label column. + // Fit on whole dataset to include all labels in index. + val labelIndexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("indexedLabel") + .fit(data) + // Automatically identify categorical features, and index them. + // Set maxCategories so features with > 4 distinct values are treated as continuous. + val featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data) + + // Split the data into training and test sets (30% held out for testing) + val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) + + // Train a GBT model. + val gbt = new GBTClassifier() + .setLabelCol("indexedLabel") + .setFeaturesCol("indexedFeatures") + .setMaxIter(10) + + // Convert indexed labels back to original labels. + val labelConverter = new IndexToString() + .setInputCol("prediction") + .setOutputCol("predictedLabel") + .setLabels(labelIndexer.labels) + + // Chain indexers and GBT in a Pipeline + val pipeline = new Pipeline() + .setStages(Array(labelIndexer, featureIndexer, gbt, labelConverter)) + + // Train model. This also runs the indexers. + val model = pipeline.fit(trainingData) + + // Make predictions. + val predictions = model.transform(testData) + + // Select example rows to display. + predictions.select("predictedLabel", "label", "features").show(5) + + // Select (prediction, true label) and compute test error + val evaluator = new MulticlassClassificationEvaluator() + .setLabelCol("indexedLabel") + .setPredictionCol("prediction") + .setMetricName("precision") + val accuracy = evaluator.evaluate(predictions) + println("Test Error = " + (1.0 - accuracy)) + + val gbtModel = model.stages(2).asInstanceOf[GBTClassificationModel] + println("Learned classification GBT model:\n" + gbtModel.toDebugString) + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala new file mode 100644 index 0000000000000..da1cd9c2ce525 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.evaluation.RegressionEvaluator +import org.apache.spark.ml.feature.VectorIndexer +import org.apache.spark.ml.regression.{GBTRegressionModel, GBTRegressor} +// $example off$ + +object GradientBoostedTreeRegressorExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("GradientBoostedTreeRegressorExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Load and parse the data file, converting it to a DataFrame. + val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + // Automatically identify categorical features, and index them. + // Set maxCategories so features with > 4 distinct values are treated as continuous. + val featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data) + + // Split the data into training and test sets (30% held out for testing) + val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) + + // Train a GBT model. + val gbt = new GBTRegressor() + .setLabelCol("label") + .setFeaturesCol("indexedFeatures") + .setMaxIter(10) + + // Chain indexer and GBT in a Pipeline + val pipeline = new Pipeline() + .setStages(Array(featureIndexer, gbt)) + + // Train model. This also runs the indexer. + val model = pipeline.fit(trainingData) + + // Make predictions. + val predictions = model.transform(testData) + + // Select example rows to display. 
+ predictions.select("prediction", "label", "features").show(5) + + // Select (prediction, true label) and compute test error + val evaluator = new RegressionEvaluator() + .setLabelCol("label") + .setPredictionCol("prediction") + .setMetricName("rmse") + val rmse = evaluator.evaluate(predictions) + println("Root Mean Squared Error (RMSE) on test data = " + rmse) + + val gbtModel = model.stages(1).asInstanceOf[GBTRegressionModel] + println("Learned regression GBT model:\n" + gbtModel.toDebugString) + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala index 8e4f1b09a24b5..b46faea5713fb 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala @@ -23,12 +23,14 @@ import java.util.concurrent.TimeUnit.{NANOSECONDS => NANO} import scopt.OptionParser import org.apache.spark.{SparkContext, SparkConf} +// $example on$ import org.apache.spark.examples.mllib.AbstractParams import org.apache.spark.ml.classification.{OneVsRest, LogisticRegression} import org.apache.spark.ml.util.MetadataUtils import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.mllib.linalg.Vector import org.apache.spark.sql.DataFrame +// $example off$ import org.apache.spark.sql.SQLContext /** @@ -112,6 +114,7 @@ object OneVsRestExample { val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) + // $example on$ val inputData = sqlContext.read.format("libsvm").load(params.input) // compute the train/test split: if testInput is not provided use part of input. val data = params.testInput match { @@ -172,6 +175,7 @@ object OneVsRestExample { println("label\tfpr") println(fprs.map {case (label, fpr) => label + "\t" + fpr}.mkString("\n")) + // $example off$ sc.stop() } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala new file mode 100644 index 0000000000000..e79176ca6ca1c --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier} +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator +import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer} +// $example off$ + +object RandomForestClassifierExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("RandomForestClassifierExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Load and parse the data file, converting it to a DataFrame. + val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + // Index labels, adding metadata to the label column. + // Fit on whole dataset to include all labels in index. + val labelIndexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("indexedLabel") + .fit(data) + // Automatically identify categorical features, and index them. + // Set maxCategories so features with > 4 distinct values are treated as continuous. + val featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data) + + // Split the data into training and test sets (30% held out for testing) + val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) + + // Train a RandomForest model. + val rf = new RandomForestClassifier() + .setLabelCol("indexedLabel") + .setFeaturesCol("indexedFeatures") + .setNumTrees(10) + + // Convert indexed labels back to original labels. + val labelConverter = new IndexToString() + .setInputCol("prediction") + .setOutputCol("predictedLabel") + .setLabels(labelIndexer.labels) + + // Chain indexers and forest in a Pipeline + val pipeline = new Pipeline() + .setStages(Array(labelIndexer, featureIndexer, rf, labelConverter)) + + // Train model. This also runs the indexers. + val model = pipeline.fit(trainingData) + + // Make predictions. + val predictions = model.transform(testData) + + // Select example rows to display. + predictions.select("predictedLabel", "label", "features").show(5) + + // Select (prediction, true label) and compute test error + val evaluator = new MulticlassClassificationEvaluator() + .setLabelCol("indexedLabel") + .setPredictionCol("prediction") + .setMetricName("precision") + val accuracy = evaluator.evaluate(predictions) + println("Test Error = " + (1.0 - accuracy)) + + val rfModel = model.stages(2).asInstanceOf[RandomForestClassificationModel] + println("Learned classification forest model:\n" + rfModel.toDebugString) + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala new file mode 100644 index 0000000000000..acec1437a1af5 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.evaluation.RegressionEvaluator +import org.apache.spark.ml.feature.VectorIndexer +import org.apache.spark.ml.regression.{RandomForestRegressionModel, RandomForestRegressor} +// $example off$ + +object RandomForestRegressorExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("RandomForestRegressorExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Load and parse the data file, converting it to a DataFrame. + val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + // Automatically identify categorical features, and index them. + // Set maxCategories so features with > 4 distinct values are treated as continuous. + val featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data) + + // Split the data into training and test sets (30% held out for testing) + val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) + + // Train a RandomForest model. + val rf = new RandomForestRegressor() + .setLabelCol("label") + .setFeaturesCol("indexedFeatures") + + // Chain indexer and forest in a Pipeline + val pipeline = new Pipeline() + .setStages(Array(featureIndexer, rf)) + + // Train model. This also runs the indexer. + val model = pipeline.fit(trainingData) + + // Make predictions. + val predictions = model.transform(testData) + + // Select example rows to display. + predictions.select("prediction", "label", "features").show(5) + + // Select (prediction, true label) and compute test error + val evaluator = new RegressionEvaluator() + .setLabelCol("label") + .setPredictionCol("prediction") + .setMetricName("rmse") + val rmse = evaluator.evaluate(predictions) + println("Root Mean Squared Error (RMSE) on test data = " + rmse) + + val rfModel = model.stages(1).asInstanceOf[RandomForestRegressionModel] + println("Learned regression forest model:\n" + rfModel.toDebugString) + // $example off$ + + sc.stop() + } +} +// scalastyle:on println From 8019f66df5c65e21d6e4e7e8fbfb7d0471ba3e37 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 17 Nov 2015 23:51:05 -0800 Subject: [PATCH 088/149] [SPARK-10186][SQL][FOLLOW-UP] simplify test Author: Wenchen Fan Closes #9783 from cloud-fan/postgre. 
--- .../org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index 2e18d0a2baa1c..6eb6b3391a4a4 100644 --- a/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -88,7 +88,7 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { df.write.jdbc(jdbcUrl, "public.barcopy", new Properties) // Test write null values. df.select(df.queryExecution.analyzed.output.map { a => - Column(If(Literal(true), Literal(null), a)).as(a.name) + Column(Literal.create(null, a.dataType)).as(a.name) }: _*).write.jdbc(jdbcUrl, "public.barcopy2", new Properties) } } From 5e2b44474c2b838bebeffe5ba5cd72961b0cd31e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 18 Nov 2015 00:09:29 -0800 Subject: [PATCH 089/149] [SPARK-11802][SQL] Kryo-based encoder for opaque types in Datasets I also found a bug with self-joins returning incorrect results in the Dataset API. Two test cases attached and filed SPARK-11803. Author: Reynold Xin Closes #9789 from rxin/SPARK-11802. --- .../scala/org/apache/spark/sql/Encoder.scala | 31 +++++++- .../catalyst/encoders/ExpressionEncoder.scala | 4 +- .../catalyst/encoders/ProductEncoder.scala | 2 +- .../sql/catalyst/expressions/objects.scala | 69 +++++++++++++++++- .../catalyst/encoders/FlatEncoderSuite.scala | 18 +++++ .../scala/org/apache/spark/sql/Dataset.scala | 6 ++ .../org/apache/spark/sql/GroupedDataset.scala | 1 - .../org/apache/spark/sql/DatasetSuite.scala | 70 +++++++++++++++---- 8 files changed, 178 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index c8b017e251637..79c2255641c06 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql -import scala.reflect.ClassTag +import scala.reflect.{ClassTag, classTag} import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.catalyst.expressions.{DeserializeWithKryo, BoundReference, SerializeWithKryo} +import org.apache.spark.sql.types._ /** * Used to convert a JVM object of type `T` to and from the internal Spark SQL representation. @@ -37,7 +38,33 @@ trait Encoder[T] extends Serializable { def clsTag: ClassTag[T] } +/** + * Methods for creating encoders. + */ object Encoders { + + /** + * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo. + * This encoder maps T into a single byte array (binary) field. + */ + def kryo[T: ClassTag]: Encoder[T] = { + val ser = SerializeWithKryo(BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true)) + val deser = DeserializeWithKryo[T](BoundReference(0, BinaryType, nullable = true), classTag[T]) + ExpressionEncoder[T]( + schema = new StructType().add("value", BinaryType), + flat = true, + toRowExpressions = Seq(ser), + fromRowExpression = deser, + clsTag = classTag[T] + ) + } + + /** + * Creates an encoder that serializes objects of type T using Kryo. 
+ * This encoder maps T into a single byte array (binary) field. + */ + def kryo[T](clazz: Class[T]): Encoder[T] = kryo(ClassTag[T](clazz)) + def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true) def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true) def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 9a1a8f5cbbdc3..b977f278c5b5c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -161,7 +161,9 @@ case class ExpressionEncoder[T]( @transient private lazy val extractProjection = GenerateUnsafeProjection.generate(toRowExpressions) - private val inputRow = new GenericMutableRow(1) + + @transient + private lazy val inputRow = new GenericMutableRow(1) @transient private lazy val constructProjection = GenerateSafeProjection.generate(fromRowExpression :: Nil) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala index 414adb21168ed..55c4ee11b20f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala @@ -230,7 +230,7 @@ object ProductEncoder { Invoke(inputObject, "booleanValue", BooleanType) case other => - throw new UnsupportedOperationException(s"Extractor for type $other is not supported") + throw new UnsupportedOperationException(s"Encoder for type $other is not supported") } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 5cd19de68391c..489c6126f8cd3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -17,13 +17,15 @@ package org.apache.spark.sql.catalyst.expressions +import scala.language.existentials +import scala.reflect.ClassTag + +import org.apache.spark.SparkConf +import org.apache.spark.serializer.{KryoSerializerInstance, KryoSerializer} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation} import org.apache.spark.sql.catalyst.util.GenericArrayData - -import scala.language.existentials - import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.types._ @@ -514,3 +516,64 @@ case class GetInternalRowField(child: Expression, ordinal: Int, dataType: DataTy """ } } + +/** Serializes an input object using Kryo serializer. 
*/ +case class SerializeWithKryo(child: Expression) extends UnaryExpression { + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported") + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val input = child.gen(ctx) + val kryo = ctx.freshName("kryoSerializer") + val kryoClass = classOf[KryoSerializer].getName + val kryoInstanceClass = classOf[KryoSerializerInstance].getName + val sparkConfClass = classOf[SparkConf].getName + ctx.addMutableState( + kryoInstanceClass, + kryo, + s"$kryo = ($kryoInstanceClass) new $kryoClass(new $sparkConfClass()).newInstance();") + + s""" + ${input.code} + final boolean ${ev.isNull} = ${input.isNull}; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.value} = $kryo.serialize(${input.value}, null).array(); + } + """ + } + + override def dataType: DataType = BinaryType +} + +/** + * Deserializes an input object using Kryo serializer. Note that the ClassTag is not an implicit + * parameter because TreeNode cannot copy implicit parameters. + */ +case class DeserializeWithKryo[T](child: Expression, tag: ClassTag[T]) extends UnaryExpression { + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val input = child.gen(ctx) + val kryo = ctx.freshName("kryoSerializer") + val kryoClass = classOf[KryoSerializer].getName + val kryoInstanceClass = classOf[KryoSerializerInstance].getName + val sparkConfClass = classOf[SparkConf].getName + ctx.addMutableState( + kryoInstanceClass, + kryo, + s"$kryo = ($kryoInstanceClass) new $kryoClass(new $sparkConfClass()).newInstance();") + + s""" + ${input.code} + final boolean ${ev.isNull} = ${input.isNull}; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.value} = (${ctx.javaType(dataType)}) + $kryo.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null); + } + """ + } + + override def dataType: DataType = ObjectType(tag.runtimeClass) +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala index 55821c4370684..2729db84897a2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.encoders import java.sql.{Date, Timestamp} +import org.apache.spark.sql.Encoders class FlatEncoderSuite extends ExpressionEncoderSuite { encodeDecodeTest(false, FlatEncoder[Boolean], "primitive boolean") @@ -71,4 +72,21 @@ class FlatEncoderSuite extends ExpressionEncoderSuite { encodeDecodeTest(Map(1 -> "a", 2 -> null), FlatEncoder[Map[Int, String]], "map with null") encodeDecodeTest(Map(1 -> Map("a" -> 1), 2 -> Map("b" -> 2)), FlatEncoder[Map[Int, Map[String, Int]]], "map of map") + + // Kryo encoders + encodeDecodeTest( + "hello", + Encoders.kryo[String].asInstanceOf[ExpressionEncoder[String]], + "kryo string") + encodeDecodeTest( + new NotJavaSerializable(15), + Encoders.kryo[NotJavaSerializable].asInstanceOf[ExpressionEncoder[NotJavaSerializable]], + "kryo object serialization") +} + + +class NotJavaSerializable(val value: Int) { + override def equals(other: Any): Boolean = { + this.value == other.asInstanceOf[NotJavaSerializable].value + } } 
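As a usage sketch of the new `Encoders.kryo` (hedged: the `Point` class and the wrapper method are illustrative assumptions, not part of the patch; `createDataset` is the existing `SQLContext` API used by the suites):

```scala
import org.apache.spark.sql.{Encoders, SQLContext}

// An opaque class with no case-class structure and no Java serialization support.
class Point(val x: Int, val y: Int)

def kryoDatasetExample(sqlContext: SQLContext): Unit = {
  // Opt in to Kryo: each Point is serialized into a single binary column.
  implicit val pointEncoder = Encoders.kryo[Point]
  val ds = sqlContext.createDataset(Seq(new Point(1, 2), new Point(3, 4)))
  ds.collect().foreach(p => println(s"(${p.x}, ${p.y})"))
}
```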
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 718ed812dd64c..817c20fdbb9f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -147,6 +147,12 @@ class Dataset[T] private[sql]( } } + /** + * Returns the number of elements in the [[Dataset]]. + * @since 1.6.0 + */ + def count(): Long = toDF().count() + /* *********************** * * Functional Operations * * *********************** */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 467cd42b9b8dc..c66162ee2148a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql - import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index ea29428c55088..a522894c374f9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -24,21 +24,6 @@ import scala.language.postfixOps import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext -case class ClassData(a: String, b: Int) - -/** - * A class used to test serialization using encoders. This class throws exceptions when using - * Java serialization -- so the only way it can be "serialized" is through our encoders. - */ -case class NonSerializableCaseClass(value: String) extends Externalizable { - override def readExternal(in: ObjectInput): Unit = { - throw new UnsupportedOperationException - } - - override def writeExternal(out: ObjectOutput): Unit = { - throw new UnsupportedOperationException - } -} class DatasetSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -362,8 +347,63 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkAnswer(joined, ("2", 2)) } + ignore("self join") { + val ds = Seq("1", "2").toDS().as("a") + val joined = ds.joinWith(ds, lit(true)) + checkAnswer(joined, ("1", "1"), ("1", "2"), ("2", "1"), ("2", "2")) + } + test("toString") { val ds = Seq((1, 2)).toDS() assert(ds.toString == "[_1: int, _2: int]") } + + test("kryo encoder") { + implicit val kryoEncoder = Encoders.kryo[KryoData] + val ds = sqlContext.createDataset(Seq(KryoData(1), KryoData(2))) + + assert(ds.groupBy(p => p).count().collect().toSeq == + Seq((KryoData(1), 1L), (KryoData(2), 1L))) + } + + ignore("kryo encoder self join") { + implicit val kryoEncoder = Encoders.kryo[KryoData] + val ds = sqlContext.createDataset(Seq(KryoData(1), KryoData(2))) + assert(ds.joinWith(ds, lit(true)).collect().toSet == + Set( + (KryoData(1), KryoData(1)), + (KryoData(1), KryoData(2)), + (KryoData(2), KryoData(1)), + (KryoData(2), KryoData(2)))) + } +} + + +case class ClassData(a: String, b: Int) + +/** + * A class used to test serialization using encoders. This class throws exceptions when using + * Java serialization -- so the only way it can be "serialized" is through our encoders. 
+ */ +case class NonSerializableCaseClass(value: String) extends Externalizable { + override def readExternal(in: ObjectInput): Unit = { + throw new UnsupportedOperationException + } + + override def writeExternal(out: ObjectOutput): Unit = { + throw new UnsupportedOperationException + } +} + +/** Used to test Kryo encoder. */ +class KryoData(val a: Int) { + override def equals(other: Any): Boolean = { + a == other.asInstanceOf[KryoData].a + } + override def hashCode: Int = a + override def toString: String = s"KryoData($a)" +} + +object KryoData { + def apply(a: Int): KryoData = new KryoData(a) } From 1714350bddd78cd1398e1a816f675ab729001081 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 18 Nov 2015 00:42:52 -0800 Subject: [PATCH 090/149] [SPARK-11792][SQL] SizeEstimator cannot provide a good size estimation of UnsafeHashedRelations https://issues.apache.org/jira/browse/SPARK-11792 Right now, SizeEstimator will "think" a small UnsafeHashedRelation is several GBs. Author: Yin Huai Closes #9788 from yhuai/SPARK-11792. --- .../spark/memory/TaskMemoryManager.java | 3 +++ .../org/apache/spark/util/SizeEstimator.scala | 26 ++++++++++++++++--- .../spark/util/SizeEstimatorSuite.scala | 22 ++++++++++++++++ .../sql/execution/joins/HashedRelation.scala | 10 +++++-- 4 files changed, 55 insertions(+), 6 deletions(-) diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index 5f743b28857b4..d31eb449eb82e 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -215,6 +215,9 @@ public void showMemoryUsage() { logger.info( "{} bytes of memory were used by task {} but are not associated with specific consumers", memoryNotAccountedFor, taskAttemptId); + logger.info( + "{} bytes of memory are used for execution and {} bytes of memory are used for storage", + memoryManager.executionMemoryUsed(), memoryManager.storageMemoryUsed()); } } diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index 23ee4eff0881b..c3a2675ee5f45 100644 --- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -31,6 +31,16 @@ import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.collection.OpenHashSet +/** + * A trait that allows a class to give [[SizeEstimator]] more accurate size estimation. + * When a class extends it, [[SizeEstimator]] will query the `estimatedSize` first. + * If `estimatedSize` does not return [[None]], [[SizeEstimator]] will use the returned size + * as the size of the object. Otherwise, [[SizeEstimator]] will do the estimation work. + */ +private[spark] trait SizeEstimation { + def estimatedSize: Option[Long] +} + /** * :: DeveloperApi :: * Estimates the sizes of Java objects (number of bytes of memory they occupy), for use in @@ -199,10 +209,18 @@ object SizeEstimator extends Logging { // the size estimator since it references the whole REPL. Do nothing in this case. In // general all ClassLoaders and Classes will be shared between objects anyway. 
} else { - val classInfo = getClassInfo(cls) - state.size += alignSize(classInfo.shellSize) - for (field <- classInfo.pointerFields) { - state.enqueue(field.get(obj)) + val estimatedSize = obj match { + case s: SizeEstimation => s.estimatedSize + case _ => None + } + if (estimatedSize.isDefined) { + state.size += estimatedSize.get + } else { + val classInfo = getClassInfo(cls) + state.size += alignSize(classInfo.shellSize) + for (field <- classInfo.pointerFields) { + state.enqueue(field.get(obj)) + } } } } diff --git a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala index 20550178fb1bd..9b6261af123e6 100644 --- a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala @@ -60,6 +60,18 @@ class DummyString(val arr: Array[Char]) { @transient val hash32: Int = 0 } +class DummyClass8 extends SizeEstimation { + val x: Int = 0 + + override def estimatedSize: Option[Long] = Some(2015) +} + +class DummyClass9 extends SizeEstimation { + val x: Int = 0 + + override def estimatedSize: Option[Long] = None +} + class SizeEstimatorSuite extends SparkFunSuite with BeforeAndAfterEach @@ -214,4 +226,14 @@ class SizeEstimatorSuite // Class should be 32 bytes on s390x if recognised as 64 bit platform assertResult(32)(SizeEstimator.estimate(new DummyClass7)) } + + test("SizeEstimation can provide the estimated size") { + // DummyClass8 provides its size estimation. + assertResult(2015)(SizeEstimator.estimate(new DummyClass8)) + assertResult(20206)(SizeEstimator.estimate(Array.fill(10)(new DummyClass8))) + + // DummyClass9 does not provide its size estimation. + assertResult(16)(SizeEstimator.estimate(new DummyClass9)) + assertResult(216)(SizeEstimator.estimate(Array.fill(10)(new DummyClass9))) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index cc8abb1ba463c..49ae09bf53782 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics} import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.unsafe.memory.MemoryLocation -import org.apache.spark.util.Utils +import org.apache.spark.util.{SizeEstimation, Utils} import org.apache.spark.util.collection.CompactBuffer import org.apache.spark.{SparkConf, SparkEnv} @@ -189,7 +189,9 @@ private[execution] object HashedRelation { */ private[joins] final class UnsafeHashedRelation( private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]]) - extends HashedRelation with Externalizable { + extends HashedRelation + with SizeEstimation + with Externalizable { private[joins] def this() = this(null) // Needed for serialization @@ -215,6 +217,10 @@ private[joins] final class UnsafeHashedRelation( } } + override def estimatedSize: Option[Long] = { + Option(binaryMap).map(_.getTotalMemoryConsumption) + } + override def get(key: InternalRow): Seq[InternalRow] = { val unsafeKey = key.asInstanceOf[UnsafeRow] From b8f4379ba1c5c1a8f3b4c88bd97031dc8ad2dfea Mon Sep 17 00:00:00 2001 From: somideshmukh Date: Wed, 18 Nov 2015 08:51:01 +0000 Subject: [PATCH 091/149] [SPARK-10946][SQL] JDBC - Use 
Statement.executeUpdate instead of PreparedStatement.executeUpdate for DDLs New changes with JDBCRDD Author: somideshmukh Closes #9733 from somideshmukh/SomilBranch-1.1. --- .../src/main/scala/org/apache/spark/sql/DataFrameWriter.scala | 2 +- .../apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index e63a4d5e8b10b..03867beb78224 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -297,7 +297,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { if (!tableExists) { val schema = JdbcUtils.schemaString(df, url) val sql = s"CREATE TABLE $table ($schema)" - conn.prepareStatement(sql).executeUpdate() + conn.createStatement.executeUpdate(sql) } } finally { conn.close() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 32d28e59377a1..7375a5c09123f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -55,7 +55,7 @@ object JdbcUtils extends Logging { * Drops a table from the JDBC database. */ def dropTable(conn: Connection, table: String): Unit = { - conn.prepareStatement(s"DROP TABLE $table").executeUpdate() + conn.createStatement.executeUpdate(s"DROP TABLE $table") } /** From e62820c85fe02c70f9ed51b2e68d41ff8cfecd40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jean-Baptiste=20Onofr=C3=A9?= Date: Wed, 18 Nov 2015 08:57:58 +0000 Subject: [PATCH 092/149] [SPARK-6541] Sort executors by ID (numeric) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit "Force" the executor ID sort with Int. Author: Jean-Baptiste Onofré Closes #9165 from jbonofre/SPARK-6541. 
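The symptom is easy to reproduce outside the UI. A standalone sketch of the ordering problem (the IDs are made up; this is not the sorttable.js/ExecutorTable change itself):

```scala
val ids = Seq("1", "10", "2", "driver")

// Default string ordering is lexicographic, which is what the executors table
// showed before the fix: 1, 10, 2, driver
println(ids.sorted.mkString(", "))

// Sorting the numeric IDs as Ints gives the intended order: 1, 2, 10, driver
val (numeric, named) = ids.partition(id => id.nonEmpty && id.forall(_.isDigit))
println((numeric.sortBy(_.toInt) ++ named.sorted).mkString(", "))
```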
--- .../org/apache/spark/ui/static/sorttable.js | 2 +- .../org/apache/spark/ui/jobs/ExecutorTable.scala | 13 +++++++++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/sorttable.js b/core/src/main/resources/org/apache/spark/ui/static/sorttable.js index dde6069000bc4..a73d9a5cbc215 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/sorttable.js +++ b/core/src/main/resources/org/apache/spark/ui/static/sorttable.js @@ -89,7 +89,7 @@ sorttable = { // make it clickable to sort headrow[i].sorttable_columnindex = i; headrow[i].sorttable_tbody = table.tBodies[0]; - dean_addEvent(headrow[i],"click", function(e) { + dean_addEvent(headrow[i],"click", sorttable.innerSortFunction = function(e) { if (this.className.search(/\bsorttable_sorted\b/) != -1) { // if we're already sorted by this column, just diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index be144f6065baa..1268f44596f8a 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -18,7 +18,7 @@ package org.apache.spark.ui.jobs import scala.collection.mutable -import scala.xml.Node +import scala.xml.{Unparsed, Node} import org.apache.spark.ui.{ToolTips, UIUtils} import org.apache.spark.ui.jobs.UIData.StageUIData @@ -52,7 +52,7 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: Stage - + @@ -89,6 +89,15 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: Stage {createExecutorTable()}
    Executor IDExecutor ID Address Task Time Total Tasks
    + } private def createExecutorTable() : Seq[Node] = { From 9631ca35275b0ce8a5219f975907ac36ed11f528 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 18 Nov 2015 08:59:20 +0000 Subject: [PATCH 093/149] [SPARK-11652][CORE] Remote code execution with InvokerTransformer Update to Commons Collections 3.2.2 to avoid any potential remote code execution vulnerability Author: Sean Owen Closes #9731 from srowen/SPARK-11652. --- pom.xml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/pom.xml b/pom.xml index 940e2d8740bf1..ad849112ce76c 100644 --- a/pom.xml +++ b/pom.xml @@ -162,6 +162,8 @@ 3.1 3.4.1 + + 3.2.2 2.10.5 2.10 ${scala.version} @@ -475,6 +477,11 @@ commons-math3 ${commons.math3.version}
    + + org.apache.commons + commons-collections + ${commons.collections.version} + org.apache.ivy ivy From 1429e0a2b562469146b6fa06051c85a00092e5b8 Mon Sep 17 00:00:00 2001 From: Viveka Kulharia Date: Wed, 18 Nov 2015 09:10:15 +0000 Subject: [PATCH 094/149] rmse was wrongly calculated It was multiplying with U instaed of dividing by U Author: Viveka Kulharia Closes #9771 from vivkul/patch-1. --- examples/src/main/python/als.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/src/main/python/als.py b/examples/src/main/python/als.py index 1c3a787bd0e94..205ca02962bee 100755 --- a/examples/src/main/python/als.py +++ b/examples/src/main/python/als.py @@ -36,7 +36,7 @@ def rmse(R, ms, us): diff = R - ms * us.T - return np.sqrt(np.sum(np.power(diff, 2)) / M * U) + return np.sqrt(np.sum(np.power(diff, 2)) / (M * U)) def update(i, vec, mat, ratings): From 3a6807fdf07b0e73d76502a6bd91ad979fde8b61 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Wed, 18 Nov 2015 08:18:54 -0800 Subject: [PATCH 095/149] =?UTF-8?q?[SPARK-11804]=20[PYSPARK]=20Exception?= =?UTF-8?q?=20raise=20when=20using=20Jdbc=20predicates=20opt=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …ion in PySpark Author: Jeff Zhang Closes #9791 from zjffdu/SPARK-11804. --- python/pyspark/sql/readwriter.py | 10 +++++----- python/pyspark/sql/utils.py | 13 +++++++++++++ 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 7b8ddb9feba34..e8f0d7ec77035 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -26,6 +26,7 @@ from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.column import _to_seq from pyspark.sql.types import * +from pyspark.sql import utils __all__ = ["DataFrameReader", "DataFrameWriter"] @@ -131,9 +132,7 @@ def load(self, path=None, format=None, schema=None, **options): if type(path) == list: paths = path gateway = self._sqlContext._sc._gateway - jpaths = gateway.new_array(gateway.jvm.java.lang.String, len(paths)) - for i in range(0, len(paths)): - jpaths[i] = paths[i] + jpaths = utils.toJArray(gateway, gateway.jvm.java.lang.String, paths) return self._df(self._jreader.load(jpaths)) else: return self._df(self._jreader.load(path)) @@ -269,8 +268,9 @@ def jdbc(self, url, table, column=None, lowerBound=None, upperBound=None, numPar return self._df(self._jreader.jdbc(url, table, column, int(lowerBound), int(upperBound), int(numPartitions), jprop)) if predicates is not None: - arr = self._sqlContext._sc._jvm.PythonUtils.toArray(predicates) - return self._df(self._jreader.jdbc(url, table, arr, jprop)) + gateway = self._sqlContext._sc._gateway + jpredicates = utils.toJArray(gateway, gateway.jvm.java.lang.String, predicates) + return self._df(self._jreader.jdbc(url, table, jpredicates, jprop)) return self._df(self._jreader.jdbc(url, table, jprop)) diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index c4fda8bd3b891..b0a0373372d20 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -71,3 +71,16 @@ def install_exception_handler(): patched = capture_sql_exception(original) # only patch the one used in in py4j.java_gateway (call Java API) py4j.java_gateway.get_return_value = patched + + +def toJArray(gateway, jtype, arr): + """ + Convert python list to java type array + :param gateway: Py4j Gateway + :param jtype: java type of element in array + :param arr: python type list + """ + jarr = 
gateway.new_array(jtype, len(arr)) + for i in range(0, len(arr)): + jarr[i] = arr[i] + return jarr From a97d6f3a5861e9f2bbe36957e3b39f835f3e214c Mon Sep 17 00:00:00 2001 From: zero323 Date: Wed, 18 Nov 2015 08:32:03 -0800 Subject: [PATCH 096/149] [SPARK-11281][SPARKR] Add tests covering the issue. The goal of this PR is to add tests covering the issue to ensure that is was resolved by [SPARK-11086](https://issues.apache.org/jira/browse/SPARK-11086). Author: zero323 Closes #9743 from zero323/SPARK-11281-tests. --- R/pkg/inst/tests/test_sparkSQL.R | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 8ff06276599e2..87ab33f6384b1 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -229,7 +229,7 @@ test_that("create DataFrame from list or data.frame", { df <- createDataFrame(sqlContext, l, c("a", "b")) expect_equal(columns(df), c("a", "b")) - l <- list(list(a=1, b=2), list(a=3, b=4)) + l <- list(list(a = 1, b = 2), list(a = 3, b = 4)) df <- createDataFrame(sqlContext, l) expect_equal(columns(df), c("a", "b")) @@ -292,11 +292,15 @@ test_that("create DataFrame with complex types", { }) test_that("create DataFrame from a data.frame with complex types", { - ldf <- data.frame(row.names=1:2) + ldf <- data.frame(row.names = 1:2) ldf$a_list <- list(list(1, 2), list(3, 4)) + ldf$an_envir <- c(as.environment(list(a = 1, b = 2)), as.environment(list(c = 3))) + sdf <- createDataFrame(sqlContext, ldf) + collected <- collect(sdf) - expect_equivalent(ldf, collect(sdf)) + expect_identical(ldf[, 1, FALSE], collected[, 1, FALSE]) + expect_equal(ldf$an_envir, collected$an_envir) }) # For test map type and struct type in DataFrame From 224723e6a8b198ef45d6c5ca5d2f9c61188ada8f Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Wed, 18 Nov 2015 08:41:45 -0800 Subject: [PATCH 097/149] [SPARK-11773][SPARKR] Implement collection functions in SparkR. Author: Sun Rui Closes #9764 from sun-rui/SPARK-11773. --- R/pkg/NAMESPACE | 2 + R/pkg/R/DataFrame.R | 2 +- R/pkg/R/functions.R | 109 ++++++++++++++++++++++--------- R/pkg/R/generics.R | 10 ++- R/pkg/R/utils.R | 2 +- R/pkg/inst/tests/test_sparkSQL.R | 10 +++ 6 files changed, 100 insertions(+), 35 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 2ee7d6f94f1bc..260c9edce62e0 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -98,6 +98,7 @@ exportMethods("%in%", "add_months", "alias", "approxCountDistinct", + "array_contains", "asc", "ascii", "asin", @@ -215,6 +216,7 @@ exportMethods("%in%", "sinh", "size", "skewness", + "sort_array", "soundex", "stddev", "stddev_pop", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index fd105ba5bc9bb..34177e3cdd94f 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2198,4 +2198,4 @@ setMethod("coltypes", rTypes[naIndices] <- types[naIndices] rTypes - }) \ No newline at end of file + }) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 3d0255a62f155..ff0f438045c14 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -373,22 +373,6 @@ setMethod("exp", column(jc) }) -#' explode -#' -#' Creates a new row for each element in the given array or map column. 
-#' -#' @rdname explode -#' @name explode -#' @family collection_funcs -#' @export -#' @examples \dontrun{explode(df$c)} -setMethod("explode", - signature(x = "Column"), - function(x) { - jc <- callJStatic("org.apache.spark.sql.functions", "explode", x@jc) - column(jc) - }) - #' expm1 #' #' Computes the exponential of the given value minus one. @@ -980,22 +964,6 @@ setMethod("sinh", column(jc) }) -#' size -#' -#' Returns length of array or map. -#' -#' @rdname size -#' @name size -#' @family collection_funcs -#' @export -#' @examples \dontrun{size(df$c)} -setMethod("size", - signature(x = "Column"), - function(x) { - jc <- callJStatic("org.apache.spark.sql.functions", "size", x@jc) - column(jc) - }) - #' skewness #' #' Aggregate function: returns the skewness of the values in a group. @@ -2365,3 +2333,80 @@ setMethod("rowNumber", jc <- callJStatic("org.apache.spark.sql.functions", "rowNumber") column(jc) }) + +###################### Collection functions###################### + +#' array_contains +#' +#' Returns true if the array contain the value. +#' +#' @param x A Column +#' @param value A value to be checked if contained in the column +#' @rdname array_contains +#' @name array_contains +#' @family collection_funcs +#' @export +#' @examples \dontrun{array_contains(df$c, 1)} +setMethod("array_contains", + signature(x = "Column", value = "ANY"), + function(x, value) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_contains", x@jc, value) + column(jc) + }) + +#' explode +#' +#' Creates a new row for each element in the given array or map column. +#' +#' @rdname explode +#' @name explode +#' @family collection_funcs +#' @export +#' @examples \dontrun{explode(df$c)} +setMethod("explode", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "explode", x@jc) + column(jc) + }) + +#' size +#' +#' Returns length of array or map. +#' +#' @rdname size +#' @name size +#' @family collection_funcs +#' @export +#' @examples \dontrun{size(df$c)} +setMethod("size", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "size", x@jc) + column(jc) + }) + +#' sort_array +#' +#' Sorts the input array for the given column in ascending order, +#' according to the natural ordering of the array elements. +#' +#' @param x A Column to sort +#' @param asc A logical flag indicating the sorting order. +#' TRUE, sorting is in ascending order. +#' FALSE, sorting is in descending order. +#' @rdname sort_array +#' @name sort_array +#' @family collection_funcs +#' @export +#' @examples +#' \dontrun{ +#' sort_array(df$c) +#' sort_array(df$c, FALSE) +#' } +setMethod("sort_array", + signature(x = "Column"), + function(x, asc = TRUE) { + jc <- callJStatic("org.apache.spark.sql.functions", "sort_array", x@jc, asc) + column(jc) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index afdeffc2abd83..0dcd05438222b 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -644,6 +644,10 @@ setGeneric("add_months", function(y, x) { standardGeneric("add_months") }) #' @export setGeneric("approxCountDistinct", function(x, ...) 
{ standardGeneric("approxCountDistinct") }) +#' @rdname array_contains +#' @export +setGeneric("array_contains", function(x, value) { standardGeneric("array_contains") }) + #' @rdname ascii #' @export setGeneric("ascii", function(x) { standardGeneric("ascii") }) @@ -961,6 +965,10 @@ setGeneric("size", function(x) { standardGeneric("size") }) #' @export setGeneric("skewness", function(x) { standardGeneric("skewness") }) +#' @rdname sort_array +#' @export +setGeneric("sort_array", function(x, asc = TRUE) { standardGeneric("sort_array") }) + #' @rdname soundex #' @export setGeneric("soundex", function(x) { standardGeneric("soundex") }) @@ -1076,4 +1084,4 @@ setGeneric("with") #' @rdname coltypes #' @export -setGeneric("coltypes", function(x) { standardGeneric("coltypes") }) \ No newline at end of file +setGeneric("coltypes", function(x) { standardGeneric("coltypes") }) diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index db3b2c4bbd799..45c77a86c9582 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -635,4 +635,4 @@ assignNewEnv <- function(data) { assign(x = cols[i], value = data[, cols[i]], envir = env) } env -} \ No newline at end of file +} diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 87ab33f6384b1..d9a94faff7ac0 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -878,6 +878,16 @@ test_that("column functions", { df4 <- createDataFrame(sqlContext, list(list(a = "010101"))) expect_equal(collect(select(df4, conv(df4$a, 2, 16)))[1, 1], "15") + + # Test array_contains() and sort_array() + df <- createDataFrame(sqlContext, list(list(list(1L, 2L, 3L)), list(list(6L, 5L, 4L)))) + result <- collect(select(df, array_contains(df[[1]], 1L)))[[1]] + expect_equal(result, c(TRUE, FALSE)) + + result <- collect(select(df, sort_array(df[[1]], FALSE)))[[1]] + expect_equal(result, list(list(3L, 2L, 1L), list(6L, 5L, 4L))) + result <- collect(select(df, sort_array(df[[1]])))[[1]] + expect_equal(result, list(list(1L, 2L, 3L), list(4L, 5L, 6L))) }) # test_that("column binary mathfunctions", { From 3cca5ffb3d60d5de9a54bc71cf0b8279898936d2 Mon Sep 17 00:00:00 2001 From: Hurshal Patel Date: Wed, 18 Nov 2015 09:28:59 -0800 Subject: [PATCH 098/149] [SPARK-11195][CORE] Use correct classloader for TaskResultGetter Make sure we are using the context classloader when deserializing failed TaskResults instead of the Spark classloader. The issue is that `enqueueFailedTask` was using the incorrect classloader which results in `ClassNotFoundException`. Adds a test in TaskResultGetterSuite that compiles a custom exception, throws it on the executor, and asserts that Spark handles the TaskResult deserialization instead of returning `UnknownReason`. See #9367 for previous comments See SPARK-11195 for a full repro Author: Hurshal Patel Closes #9779 from choochootrain/spark-11195-master. 
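The rule the fix applies, as a hedged standalone sketch (not the actual `TaskResultGetter` code; `repro.MyException` is the class compiled by the new test below, and running this requires that jar on the classpath):

```scala
object ClassLoaderSketch {
  // Prefer the thread's context classloader, which can see user-added jars,
  // and fall back to the classloader that loaded Spark's own classes.
  def contextOrDefaultLoader(): ClassLoader =
    Option(Thread.currentThread().getContextClassLoader)
      .getOrElse(getClass.getClassLoader)

  def main(args: Array[String]): Unit = {
    // A user-defined exception lives in a user jar, so it resolves through the
    // context classloader but may be invisible to Spark's own classloader.
    val exceptionClass = Class.forName("repro.MyException", true, contextOrDefaultLoader())
    println(exceptionClass.getName)
  }
}
```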
--- .../scala/org/apache/spark/TestUtils.scala | 11 ++-- .../spark/scheduler/TaskResultGetter.scala | 4 +- .../scheduler/TaskResultGetterSuite.scala | 65 ++++++++++++++++++- 3 files changed, 72 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index acfe751f6c746..43c89b258f2fa 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -20,6 +20,7 @@ package org.apache.spark import java.io.{ByteArrayInputStream, File, FileInputStream, FileOutputStream} import java.net.{URI, URL} import java.nio.charset.StandardCharsets +import java.nio.file.Paths import java.util.Arrays import java.util.jar.{JarEntry, JarOutputStream} @@ -83,15 +84,15 @@ private[spark] object TestUtils { } /** - * Create a jar file that contains this set of files. All files will be located at the root - * of the jar. + * Create a jar file that contains this set of files. All files will be located in the specified + * directory or at the root of the jar. */ - def createJar(files: Seq[File], jarFile: File): URL = { + def createJar(files: Seq[File], jarFile: File, directoryPrefix: Option[String] = None): URL = { val jarFileStream = new FileOutputStream(jarFile) val jarStream = new JarOutputStream(jarFileStream, new java.util.jar.Manifest()) for (file <- files) { - val jarEntry = new JarEntry(file.getName) + val jarEntry = new JarEntry(Paths.get(directoryPrefix.getOrElse(""), file.getName).toString) jarStream.putNextEntry(jarEntry) val in = new FileInputStream(file) @@ -123,7 +124,7 @@ private[spark] object TestUtils { classpathUrls: Seq[URL]): File = { val compiler = ToolProvider.getSystemJavaCompiler - // Calling this outputs a class file in pwd. It's easier to just rename the file than + // Calling this outputs a class file in pwd. It's easier to just rename the files than // build a custom FileManager that controls the output location. val options = if (classpathUrls.nonEmpty) { Seq("-classpath", classpathUrls.map { _.getFile }.mkString(File.pathSeparator)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index 46a6f6537e2ee..f4965994d8277 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -103,16 +103,16 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul try { getTaskResultExecutor.execute(new Runnable { override def run(): Unit = Utils.logUncaughtExceptions { + val loader = Utils.getContextOrSparkClassLoader try { if (serializedData != null && serializedData.limit() > 0) { reason = serializer.get().deserialize[TaskEndReason]( - serializedData, Utils.getSparkClassLoader) + serializedData, loader) } } catch { case cnd: ClassNotFoundException => // Log an error but keep going here -- the task failed, so not catastrophic // if we can't deserialize the reason. 
- val loader = Utils.getContextOrSparkClassLoader logError( "Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader) case ex: Exception => {} diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala index 815caa79ff529..bc72c3685e8c1 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.scheduler +import java.io.File +import java.net.URL import java.nio.ByteBuffer import scala.concurrent.duration._ @@ -26,8 +28,10 @@ import scala.util.control.NonFatal import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite} +import org.apache.spark._ import org.apache.spark.storage.TaskResultBlockId +import org.apache.spark.TestUtils.JavaSourceFromString +import org.apache.spark.util.{MutableURLClassLoader, Utils} /** * Removes the TaskResult from the BlockManager before delegating to a normal TaskResultGetter. @@ -119,5 +123,64 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with Local // Make sure two tasks were run (one failed one, and a second retried one). assert(scheduler.nextTaskId.get() === 2) } + + /** + * Make sure we are using the context classloader when deserializing failed TaskResults instead + * of the Spark classloader. + + * This test compiles a jar containing an exception and tests that when it is thrown on the + * executor, enqueueFailedTask can correctly deserialize the failure and identify the thrown + * exception as the cause. + + * Before this fix, enqueueFailedTask would throw a ClassNotFoundException when deserializing + * the exception, resulting in an UnknownReason for the TaskEndResult. + */ + test("failed task deserialized with the correct classloader (SPARK-11195)") { + // compile a small jar containing an exception that will be thrown on an executor. + val tempDir = Utils.createTempDir() + val srcDir = new File(tempDir, "repro/") + srcDir.mkdirs() + val excSource = new JavaSourceFromString(new File(srcDir, "MyException").getAbsolutePath, + """package repro; + | + |public class MyException extends Exception { + |} + """.stripMargin) + val excFile = TestUtils.createCompiledClass("MyException", srcDir, excSource, Seq.empty) + val jarFile = new File(tempDir, "testJar-%s.jar".format(System.currentTimeMillis())) + TestUtils.createJar(Seq(excFile), jarFile, directoryPrefix = Some("repro")) + + // ensure we reset the classloader after the test completes + val originalClassLoader = Thread.currentThread.getContextClassLoader + try { + // load the exception from the jar + val loader = new MutableURLClassLoader(new Array[URL](0), originalClassLoader) + loader.addURL(jarFile.toURI.toURL) + Thread.currentThread().setContextClassLoader(loader) + val excClass: Class[_] = Utils.classForName("repro.MyException") + + // NOTE: we must run the cluster with "local" so that the executor can load the compiled + // jar. + sc = new SparkContext("local", "test", conf) + val rdd = sc.parallelize(Seq(1), 1).map { _ => + val exc = excClass.newInstance().asInstanceOf[Exception] + throw exc + } + + // the driver should not have any problems resolving the exception class and determining + // why the task failed. 
+ val exceptionMessage = intercept[SparkException] { + rdd.collect() + }.getMessage + + val expectedFailure = """(?s).*Lost task.*: repro.MyException.*""".r + val unknownFailure = """(?s).*Lost task.*: UnknownReason.*""".r + + assert(expectedFailure.findFirstMatchIn(exceptionMessage).isDefined) + assert(unknownFailure.findFirstMatchIn(exceptionMessage).isEmpty) + } finally { + Thread.currentThread.setContextClassLoader(originalClassLoader) + } + } } From cffb899c4397ecccedbcc41e7cf3da91f953435a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 18 Nov 2015 10:15:50 -0800 Subject: [PATCH 099/149] [SPARK-11803][SQL] fix Dataset self-join When we resolve the join operator, we may change the output of right side if self-join is detected. So in `Dataset.joinWith`, we should resolve the join operator first, and then get the left output and right output from it, instead of using `left.output` and `right.output` directly. Author: Wenchen Fan Closes #9806 from cloud-fan/self-join. --- .../main/scala/org/apache/spark/sql/Dataset.scala | 14 +++++++++----- .../scala/org/apache/spark/sql/DatasetSuite.scala | 8 ++++---- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 817c20fdbb9f3..b644f6ad3096d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -498,13 +498,17 @@ class Dataset[T] private[sql]( val left = this.logicalPlan val right = other.logicalPlan + val joined = sqlContext.executePlan(Join(left, right, Inner, Some(condition.expr))) + val leftOutput = joined.analyzed.output.take(left.output.length) + val rightOutput = joined.analyzed.output.takeRight(right.output.length) + val leftData = this.unresolvedTEncoder match { - case e if e.flat => Alias(left.output.head, "_1")() - case _ => Alias(CreateStruct(left.output), "_1")() + case e if e.flat => Alias(leftOutput.head, "_1")() + case _ => Alias(CreateStruct(leftOutput), "_1")() } val rightData = other.unresolvedTEncoder match { - case e if e.flat => Alias(right.output.head, "_2")() - case _ => Alias(CreateStruct(right.output), "_2")() + case e if e.flat => Alias(rightOutput.head, "_2")() + case _ => Alias(CreateStruct(rightOutput), "_2")() } @@ -513,7 +517,7 @@ class Dataset[T] private[sql]( withPlan[(T, U)](other) { (left, right) => Project( leftData :: rightData :: Nil, - Join(left, right, Inner, Some(condition.expr))) + joined.analyzed) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index a522894c374f9..198962b8fb750 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -347,7 +347,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkAnswer(joined, ("2", 2)) } - ignore("self join") { + test("self join") { val ds = Seq("1", "2").toDS().as("a") val joined = ds.joinWith(ds, lit(true)) checkAnswer(joined, ("1", "1"), ("1", "2"), ("2", "1"), ("2", "2")) @@ -360,15 +360,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("kryo encoder") { implicit val kryoEncoder = Encoders.kryo[KryoData] - val ds = sqlContext.createDataset(Seq(KryoData(1), KryoData(2))) + val ds = Seq(KryoData(1), KryoData(2)).toDS() assert(ds.groupBy(p => p).count().collect().toSeq == Seq((KryoData(1), 1L), (KryoData(2), 1L))) 
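For reference, the self-join shape that returned incorrect results before this change, lifted from the re-enabled test below (assumes a SQLContext with `implicits._` imported; `lit(true)` is just an always-true join condition):

```scala
import org.apache.spark.sql.functions.lit

val ds = Seq("1", "2").toDS().as("a")

// Before the fix both join sides resolved against the same attribute IDs, so the
// result collapsed; resolving the join first and reading the left/right outputs
// back from the resolved plan yields the full cross product:
// ("1","1"), ("1","2"), ("2","1"), ("2","2")
val joined = ds.joinWith(ds, lit(true))
joined.collect().foreach(println)
```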
} - ignore("kryo encoder self join") { + test("kryo encoder self join") { implicit val kryoEncoder = Encoders.kryo[KryoData] - val ds = sqlContext.createDataset(Seq(KryoData(1), KryoData(2))) + val ds = Seq(KryoData(1), KryoData(2)).toDS() assert(ds.joinWith(ds, lit(true)).collect().toSet == Set( (KryoData(1), KryoData(1)), From 33b837333435ceb0c04d1f361a5383c4fe6a5a75 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 18 Nov 2015 10:23:12 -0800 Subject: [PATCH 100/149] [SPARK-11725][SQL] correctly handle null inputs for UDF If user use primitive parameters in UDF, there is no way for him to do the null-check for primitive inputs, so we are assuming the primitive input is null-propagatable for this case and return null if the input is null. Author: Wenchen Fan Closes #9770 from cloud-fan/udf. --- .../spark/sql/catalyst/ScalaReflection.scala | 9 ++++ .../sql/catalyst/analysis/Analyzer.scala | 32 +++++++++++++- .../sql/catalyst/expressions/ScalaUDF.scala | 6 +++ .../sql/catalyst/ScalaReflectionSuite.scala | 17 +++++++ .../sql/catalyst/analysis/AnalysisSuite.scala | 44 +++++++++++++++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 14 ++++++ 6 files changed, 121 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 0b3dd351e38e8..38828e59a2152 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -719,6 +719,15 @@ trait ScalaReflection { } } + /** + * Returns classes of input parameters of scala function object. + */ + def getParameterTypes(func: AnyRef): Seq[Class[_]] = { + val methods = func.getClass.getMethods.filter(m => m.getName == "apply" && !m.isBridge) + assert(methods.length == 1) + methods.head.getParameterTypes + } + def typeOfObject: PartialFunction[Any, DataType] = { // The data type can be determined without ambiguity. case obj: Boolean => BooleanType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 2f4670b55bdba..f00c451b5981a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreeNodeRef -import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf} +import org.apache.spark.sql.catalyst.{ScalaReflection, SimpleCatalystConf, CatalystConf} import org.apache.spark.sql.types._ /** @@ -85,6 +85,8 @@ class Analyzer( extendedResolutionRules : _*), Batch("Nondeterministic", Once, PullOutNondeterministic), + Batch("UDF", Once, + HandleNullInputsForUDF), Batch("Cleanup", fixedPoint, CleanupAliases) ) @@ -1063,6 +1065,34 @@ class Analyzer( Project(p.output, newPlan.withNewChildren(newChild :: Nil)) } } + + /** + * Correctly handle null primitive inputs for UDF by adding extra [[If]] expression to do the + * null check. 
When user defines a UDF with primitive parameters, there is no way to tell if the + * primitive parameter is null or not, so here we assume the primitive input is null-propagatable + * and we should return null if the input is null. + */ + object HandleNullInputsForUDF extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case p if !p.resolved => p // Skip unresolved nodes. + + case plan => plan transformExpressionsUp { + + case udf @ ScalaUDF(func, _, inputs, _) => + val parameterTypes = ScalaReflection.getParameterTypes(func) + assert(parameterTypes.length == inputs.length) + + val inputsNullCheck = parameterTypes.zip(inputs) + // TODO: skip null handling for not-nullable primitive inputs after we can completely + // trust the `nullable` information. + // .filter { case (cls, expr) => cls.isPrimitive && expr.nullable } + .filter { case (cls, _) => cls.isPrimitive } + .map { case (_, expr) => IsNull(expr) } + .reduceLeftOption[Expression]((e1, e2) => Or(e1, e2)) + inputsNullCheck.map(If(_, Literal.create(null, udf.dataType), udf)).getOrElse(udf) + } + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 3388cc20a9803..03b89221ef2d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -24,7 +24,13 @@ import org.apache.spark.sql.types.DataType /** * User-defined function. + * @param function The user defined scala function to run. + * Note that if you use primitive parameters, you are not able to check if it is + * null or not, and the UDF will return null for you if the primitive input is + * null. Use boxed type or [[Option]] if you wanna do the null-handling yourself. * @param dataType Return type of function. + * @param children The input expressions of this UDF. + * @param inputTypes The expected input types of this UDF. 
*/ case class ScalaUDF( function: AnyRef, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 3b848cfdf737f..4ea410d492b01 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -280,4 +280,21 @@ class ScalaReflectionSuite extends SparkFunSuite { assert(s.fields.map(_.dataType) === Seq(IntegerType, StringType, DoubleType)) } } + + test("get parameter type from a function object") { + val primitiveFunc = (i: Int, j: Long) => "x" + val primitiveTypes = getParameterTypes(primitiveFunc) + assert(primitiveTypes.forall(_.isPrimitive)) + assert(primitiveTypes === Seq(classOf[Int], classOf[Long])) + + val boxedFunc = (i: java.lang.Integer, j: java.lang.Long) => "x" + val boxedTypes = getParameterTypes(boxedFunc) + assert(boxedTypes.forall(!_.isPrimitive)) + assert(boxedTypes === Seq(classOf[java.lang.Integer], classOf[java.lang.Long])) + + val anyFunc = (i: Any, j: AnyRef) => "x" + val anyTypes = getParameterTypes(anyFunc) + assert(anyTypes.forall(!_.isPrimitive)) + assert(anyTypes === Seq(classOf[java.lang.Object], classOf[java.lang.Object])) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 65f09b46afae1..08586a97411ae 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -174,4 +174,48 @@ class AnalysisSuite extends AnalysisTest { ) assertAnalysisError(plan, Seq("data type mismatch: Arguments must be same type")) } + + test("SPARK-11725: correctly handle null inputs for ScalaUDF") { + val string = testRelation2.output(0) + val double = testRelation2.output(2) + val short = testRelation2.output(4) + val nullResult = Literal.create(null, StringType) + + def checkUDF(udf: Expression, transformed: Expression): Unit = { + checkAnalysis( + Project(Alias(udf, "")() :: Nil, testRelation2), + Project(Alias(transformed, "")() :: Nil, testRelation2) + ) + } + + // non-primitive parameters do not need special null handling + val udf1 = ScalaUDF((s: String) => "x", StringType, string :: Nil) + val expected1 = udf1 + checkUDF(udf1, expected1) + + // only primitive parameter needs special null handling + val udf2 = ScalaUDF((s: String, d: Double) => "x", StringType, string :: double :: Nil) + val expected2 = If(IsNull(double), nullResult, udf2) + checkUDF(udf2, expected2) + + // special null handling should apply to all primitive parameters + val udf3 = ScalaUDF((s: Short, d: Double) => "x", StringType, short :: double :: Nil) + val expected3 = If( + IsNull(short) || IsNull(double), + nullResult, + udf3) + checkUDF(udf3, expected3) + + // we can skip special null handling for primitive parameters that are not nullable + // TODO: this is disabled for now as we can not completely trust `nullable`. 
+ val udf4 = ScalaUDF( + (s: Short, d: Double) => "x", + StringType, + short :: double.withNullability(false) :: Nil) + val expected4 = If( + IsNull(short), + nullResult, + udf4) + // checkUDF(udf4, expected4) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 35cdab50bdec9..5a7f24684d1b7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1115,4 +1115,18 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { checkAnswer(df.select(df("*")), Row(1, "a")) checkAnswer(df.withColumnRenamed("d^'a.", "a"), Row(1, "a")) } + + test("SPARK-11725: correctly handle null inputs for ScalaUDF") { + val df = Seq( + new java.lang.Integer(22) -> "John", + null.asInstanceOf[java.lang.Integer] -> "Lucy").toDF("age", "name") + + val boxedUDF = udf[java.lang.Integer, java.lang.Integer] { + (i: java.lang.Integer) => if (i == null) null else i * 2 + } + checkAnswer(df.select(boxedUDF($"age")), Row(44) :: Row(null) :: Nil) + + val primitiveUDF = udf((i: Int) => i * 2) + checkAnswer(df.select(primitiveUDF($"age")), Row(44) :: Row(null) :: Nil) + } } From dbf428c87ab34b6f76c75946043bdf5f60c9b1b3 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 18 Nov 2015 10:33:17 -0800 Subject: [PATCH 101/149] [SPARK-11795][SQL] combine grouping attributes into a single NamedExpression MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit we use `ExpressionEncoder.tuple` to build the result encoder, which assumes the input encoder should point to a struct type field if it’s non-flat. However, our keyEncoder always point to a flat field/fields: `groupingAttributes`, we should combine them into a single `NamedExpression`. Author: Wenchen Fan Closes #9792 from cloud-fan/agg. 
--- .../main/scala/org/apache/spark/sql/GroupedDataset.scala | 9 +++++++-- .../test/scala/org/apache/spark/sql/DatasetSuite.scala | 5 ++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index c66162ee2148a..3f84e22a1025b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -22,7 +22,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.function._ import org.apache.spark.sql.catalyst.encoders.{FlatEncoder, ExpressionEncoder, encoderFor} -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Alias, CreateStruct, Attribute} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution @@ -187,7 +187,12 @@ class GroupedDataset[K, T] private[sql]( val namedColumns = columns.map( _.withInputType(resolvedTEncoder, dataAttributes).named) - val aggregate = Aggregate(groupingAttributes, groupingAttributes ++ namedColumns, logicalPlan) + val keyColumn = if (groupingAttributes.length > 1) { + Alias(CreateStruct(groupingAttributes), "key")() + } else { + groupingAttributes.head + } + val aggregate = Aggregate(groupingAttributes, keyColumn +: namedColumns, logicalPlan) val execution = new QueryExecution(sqlContext, aggregate) new Dataset( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 198962b8fb750..b6db583dfe01f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -84,8 +84,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ("a", 2), ("b", 3), ("c", 4)) } - ignore("Dataset should set the resolved encoders internally for maps") { - // TODO: Enable this once we fix SPARK-11793. + test("map and group by with class data") { // We inject a group by here to make sure this test case is future proof // when we implement better pipelining and local execution mode. val ds: Dataset[(ClassData, Long)] = Seq(ClassData("one", 1), ClassData("two", 2)).toDS() @@ -94,7 +93,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkAnswer( ds, - (ClassData("one", 1), 1L), (ClassData("two", 2), 1L)) + (ClassData("one", 2), 1L), (ClassData("two", 3), 1L)) } test("select") { From 90a7519daaa7f4ee3be7c5a9aa244120811ff6eb Mon Sep 17 00:00:00 2001 From: Jakob Odersky Date: Wed, 18 Nov 2015 11:35:41 -0800 Subject: [PATCH 102/149] [MINOR][BUILD] Ignore ensime cache Using ENSIME, I often have `.ensime_cache` polluting my source tree. This PR simply adds the cache directory to `.gitignore` Author: Jakob Odersky Closes #9708 from jodersky/master. 
--- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 08f2d8f7543f0..07524bc429e92 100644 --- a/.gitignore +++ b/.gitignore @@ -50,6 +50,7 @@ spark-tests.log streaming-tests.log dependency-reduced-pom.xml .ensime +.ensime_cache/ .ensime_lucene checkpoint derby.log From 6f99522d13d8db9fcc767f7c3189557b9a53d283 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 18 Nov 2015 11:49:12 -0800 Subject: [PATCH 103/149] [SPARK-11792] [SQL] [FOLLOW-UP] Change SizeEstimation to KnownSizeEstimation and make estimatedSize return Long instead of Option[Long] https://issues.apache.org/jira/browse/SPARK-11792 The main changes include: * Renaming `SizeEstimation` to `KnownSizeEstimation`. Hopefully this new name has more information. * Making `estimatedSize` return `Long` instead of `Option[Long]`. * In `UnsaveHashedRelation`, `estimatedSize` will delegate the work to `SizeEstimator` if we have not created a `BytesToBytesMap`. Since we will put `UnsaveHashedRelation` to `BlockManager`, it is generally good to let it provide a more accurate size estimation. Also, if we do not put `BytesToBytesMap` directly into `BlockerManager`, I feel it is not really necessary to make `BytesToBytesMap` extends `KnownSizeEstimation`. Author: Yin Huai Closes #9813 from yhuai/SPARK-11792-followup. --- .../org/apache/spark/util/SizeEstimator.scala | 30 ++++++++++--------- .../spark/util/SizeEstimatorSuite.scala | 14 ++------- .../sql/execution/joins/HashedRelation.scala | 12 +++++--- 3 files changed, 26 insertions(+), 30 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index c3a2675ee5f45..09864e3f8392d 100644 --- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -36,9 +36,14 @@ import org.apache.spark.util.collection.OpenHashSet * When a class extends it, [[SizeEstimator]] will query the `estimatedSize` first. * If `estimatedSize` does not return [[None]], [[SizeEstimator]] will use the returned size * as the size of the object. Otherwise, [[SizeEstimator]] will do the estimation work. + * The difference between a [[KnownSizeEstimation]] and + * [[org.apache.spark.util.collection.SizeTracker]] is that, a + * [[org.apache.spark.util.collection.SizeTracker]] still uses [[SizeEstimator]] to + * estimate the size. However, a [[KnownSizeEstimation]] can provide a better estimation without + * using [[SizeEstimator]]. */ -private[spark] trait SizeEstimation { - def estimatedSize: Option[Long] +private[spark] trait KnownSizeEstimation { + def estimatedSize: Long } /** @@ -209,18 +214,15 @@ object SizeEstimator extends Logging { // the size estimator since it references the whole REPL. Do nothing in this case. In // general all ClassLoaders and Classes will be shared between objects anyway. 
} else { - val estimatedSize = obj match { - case s: SizeEstimation => s.estimatedSize - case _ => None - } - if (estimatedSize.isDefined) { - state.size += estimatedSize.get - } else { - val classInfo = getClassInfo(cls) - state.size += alignSize(classInfo.shellSize) - for (field <- classInfo.pointerFields) { - state.enqueue(field.get(obj)) - } + obj match { + case s: KnownSizeEstimation => + state.size += s.estimatedSize + case _ => + val classInfo = getClassInfo(cls) + state.size += alignSize(classInfo.shellSize) + for (field <- classInfo.pointerFields) { + state.enqueue(field.get(obj)) + } } } } diff --git a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala index 9b6261af123e6..101610e38014e 100644 --- a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala @@ -60,16 +60,10 @@ class DummyString(val arr: Array[Char]) { @transient val hash32: Int = 0 } -class DummyClass8 extends SizeEstimation { +class DummyClass8 extends KnownSizeEstimation { val x: Int = 0 - override def estimatedSize: Option[Long] = Some(2015) -} - -class DummyClass9 extends SizeEstimation { - val x: Int = 0 - - override def estimatedSize: Option[Long] = None + override def estimatedSize: Long = 2015 } class SizeEstimatorSuite @@ -231,9 +225,5 @@ class SizeEstimatorSuite // DummyClass8 provides its size estimation. assertResult(2015)(SizeEstimator.estimate(new DummyClass8)) assertResult(20206)(SizeEstimator.estimate(Array.fill(10)(new DummyClass8))) - - // DummyClass9 does not provide its size estimation. - assertResult(16)(SizeEstimator.estimate(new DummyClass9)) - assertResult(216)(SizeEstimator.estimate(Array.fill(10)(new DummyClass9))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 49ae09bf53782..aebfea5832402 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics} import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.unsafe.memory.MemoryLocation -import org.apache.spark.util.{SizeEstimation, Utils} +import org.apache.spark.util.{SizeEstimator, KnownSizeEstimation, Utils} import org.apache.spark.util.collection.CompactBuffer import org.apache.spark.{SparkConf, SparkEnv} @@ -190,7 +190,7 @@ private[execution] object HashedRelation { private[joins] final class UnsafeHashedRelation( private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]]) extends HashedRelation - with SizeEstimation + with KnownSizeEstimation with Externalizable { private[joins] def this() = this(null) // Needed for serialization @@ -217,8 +217,12 @@ private[joins] final class UnsafeHashedRelation( } } - override def estimatedSize: Option[Long] = { - Option(binaryMap).map(_.getTotalMemoryConsumption) + override def estimatedSize: Long = { + if (binaryMap != null) { + binaryMap.getTotalMemoryConsumption + } else { + SizeEstimator.estimate(hashTable) + } } override def get(key: InternalRow): Seq[InternalRow] = { From 94624eacb0fdbbe210894151a956f8150cdf527e Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 18 Nov 2015 11:53:28 -0800 Subject: [PATCH 
104/149] [SPARK-11739][SQL] clear the instantiated SQLContext Currently, if the first SQLContext is not removed after stopping SparkContext, a SQLContext could set there forever. This patch make this more robust. Author: Davies Liu Closes #9706 from davies/clear_context. --- .../scala/org/apache/spark/sql/SQLContext.scala | 17 +++++++++++------ .../spark/sql/MultiSQLContextsSuite.scala | 5 ++--- .../execution/ExchangeCoordinatorSuite.scala | 2 +- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index cd1fdc4edb39d..39471d2fb79a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -1229,7 +1229,7 @@ class SQLContext private[sql]( // construction of the instance. sparkContext.addSparkListener(new SparkListener { override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = { - SQLContext.clearInstantiatedContext(self) + SQLContext.clearInstantiatedContext() } }) @@ -1270,13 +1270,13 @@ object SQLContext { */ def getOrCreate(sparkContext: SparkContext): SQLContext = { val ctx = activeContext.get() - if (ctx != null) { + if (ctx != null && !ctx.sparkContext.isStopped) { return ctx } synchronized { val ctx = instantiatedContext.get() - if (ctx == null) { + if (ctx == null || ctx.sparkContext.isStopped) { new SQLContext(sparkContext) } else { ctx @@ -1284,12 +1284,17 @@ object SQLContext { } } - private[sql] def clearInstantiatedContext(sqlContext: SQLContext): Unit = { - instantiatedContext.compareAndSet(sqlContext, null) + private[sql] def clearInstantiatedContext(): Unit = { + instantiatedContext.set(null) } private[sql] def setInstantiatedContext(sqlContext: SQLContext): Unit = { - instantiatedContext.compareAndSet(null, sqlContext) + synchronized { + val ctx = instantiatedContext.get() + if (ctx == null || ctx.sparkContext.isStopped) { + instantiatedContext.set(sqlContext) + } + } } private[sql] def getInstantiatedContextOption(): Option[SQLContext] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala index 0e8fcb6a858b1..34c5c68fd1c18 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala @@ -31,7 +31,7 @@ class MultiSQLContextsSuite extends SparkFunSuite with BeforeAndAfterAll { originalInstantiatedSQLContext = SQLContext.getInstantiatedContextOption() SQLContext.clearActive() - originalInstantiatedSQLContext.foreach(ctx => SQLContext.clearInstantiatedContext(ctx)) + SQLContext.clearInstantiatedContext() sparkConf = new SparkConf(false) .setMaster("local[*]") @@ -89,10 +89,9 @@ class MultiSQLContextsSuite extends SparkFunSuite with BeforeAndAfterAll { testNewSession(rootSQLContext) testNewSession(rootSQLContext) testCreatingNewSQLContext(allowMultipleSQLContexts) - - SQLContext.clearInstantiatedContext(rootSQLContext) } finally { sc.stop() + SQLContext.clearInstantiatedContext() } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala index 25f2f5caeed15..b96d50a70b85c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala @@ -34,7 +34,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { originalInstantiatedSQLContext = SQLContext.getInstantiatedContextOption() SQLContext.clearActive() - originalInstantiatedSQLContext.foreach(ctx => SQLContext.clearInstantiatedContext(ctx)) + SQLContext.clearInstantiatedContext() } override protected def afterAll(): Unit = { From 31921e0f0bd559d042148d1ea32f865fb3068f38 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 18 Nov 2015 12:09:54 -0800 Subject: [PATCH 105/149] [SPARK-4557][STREAMING] Spark Streaming foreachRDD Java API method should accept a VoidFunction<...> Currently streaming foreachRDD Java API uses a function prototype requiring a return value of null. This PR deprecates the old method and uses VoidFunction to allow for more concise declaration. Also added VoidFunction2 to Java API in order to use in Streaming methods. Unit test is added for using foreachRDD with VoidFunction, and changes have been tested with Java 7 and Java 8 using lambdas. Author: Bryan Cutler Closes #9488 from BryanCutler/foreachRDD-VoidFunction-SPARK-4557. --- .../api/java/function/VoidFunction2.java | 27 ++++++++++++ .../apache/spark/streaming/Java8APISuite.java | 26 ++++++++++++ project/MimaExcludes.scala | 4 ++ .../streaming/api/java/JavaDStreamLike.scala | 24 ++++++++++- .../apache/spark/streaming/JavaAPISuite.java | 41 ++++++++++++++++++- 5 files changed, 120 insertions(+), 2 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java diff --git a/core/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java b/core/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java new file mode 100644 index 0000000000000..6c576ab678455 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; + +/** + * A two-argument function that takes arguments of type T1 and T2 with no return value. 
+ */ +public interface VoidFunction2 extends Serializable { + public void call(T1 v1, T2 v2) throws Exception; +} diff --git a/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java b/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java index 163ae92c12c6d..4eee97bc89613 100644 --- a/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java +++ b/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java @@ -28,6 +28,7 @@ import org.junit.Assert; import org.junit.Test; +import org.apache.spark.Accumulator; import org.apache.spark.HashPartitioner; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; @@ -360,6 +361,31 @@ public void testFlatMap() { assertOrderInvariantEquals(expected, result); } + @Test + public void testForeachRDD() { + final Accumulator accumRdd = ssc.sc().accumulator(0); + final Accumulator accumEle = ssc.sc().accumulator(0); + List> inputData = Arrays.asList( + Arrays.asList(1,1,1), + Arrays.asList(1,1,1)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaTestUtils.attachTestOutputStream(stream.count()); // dummy output + + stream.foreachRDD(rdd -> { + accumRdd.add(1); + rdd.foreach(x -> accumEle.add(1)); + }); + + // This is a test to make sure foreachRDD(VoidFunction2) can be called from Java + stream.foreachRDD((rdd, time) -> null); + + JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(2, accumRdd.value().intValue()); + Assert.assertEquals(6, accumEle.value().intValue()); + } + @Test public void testPairFlatMap() { List> inputData = Arrays.asList( diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index eb70d27c34c20..bb45d1bb12146 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -142,6 +142,10 @@ object MimaExcludes { "org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper.createDirectStream"), ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper.createRDD") + ) ++ Seq( + // SPARK-4557 Changed foreachRDD to use VoidFunction + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.api.java.JavaDStreamLike.foreachRDD") ) case v if v.startsWith("1.5") => Seq( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala index edfa474677f15..84acec7d8e330 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala @@ -27,7 +27,7 @@ import scala.reflect.ClassTag import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaRDDLike} import org.apache.spark.api.java.JavaPairRDD._ import org.apache.spark.api.java.JavaSparkContext.fakeClassTag -import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, Function3 => JFunction3, _} +import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, Function3 => JFunction3, VoidFunction => JVoidFunction, VoidFunction2 => JVoidFunction2, _} import org.apache.spark.rdd.RDD import org.apache.spark.streaming._ import org.apache.spark.streaming.api.java.JavaDStream._ @@ -308,7 +308,10 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T /** * Apply a function to each RDD in this DStream. 
This is an output operator, so * 'this' DStream will be registered as an output stream and therefore materialized. + * + * @deprecated As of release 1.6.0, replaced by foreachRDD(JVoidFunction) */ + @deprecated("Use foreachRDD(foreachFunc: JVoidFunction[R])", "1.6.0") def foreachRDD(foreachFunc: JFunction[R, Void]) { dstream.foreachRDD(rdd => foreachFunc.call(wrapRDD(rdd))) } @@ -316,11 +319,30 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T /** * Apply a function to each RDD in this DStream. This is an output operator, so * 'this' DStream will be registered as an output stream and therefore materialized. + * + * @deprecated As of release 1.6.0, replaced by foreachRDD(JVoidFunction2) */ + @deprecated("Use foreachRDD(foreachFunc: JVoidFunction2[R, Time])", "1.6.0") def foreachRDD(foreachFunc: JFunction2[R, Time, Void]) { dstream.foreachRDD((rdd, time) => foreachFunc.call(wrapRDD(rdd), time)) } + /** + * Apply a function to each RDD in this DStream. This is an output operator, so + * 'this' DStream will be registered as an output stream and therefore materialized. + */ + def foreachRDD(foreachFunc: JVoidFunction[R]) { + dstream.foreachRDD(rdd => foreachFunc.call(wrapRDD(rdd))) + } + + /** + * Apply a function to each RDD in this DStream. This is an output operator, so + * 'this' DStream will be registered as an output stream and therefore materialized. + */ + def foreachRDD(foreachFunc: JVoidFunction2[R, Time]) { + dstream.foreachRDD((rdd, time) => foreachFunc.call(wrapRDD(rdd), time)) + } + /** * Return a new DStream in which each RDD is generated by applying a function * on each RDD of 'this' DStream. diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index c5217149224e4..609bb4413b6b1 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -37,7 +37,9 @@ import com.google.common.io.Files; import com.google.common.collect.Sets; +import org.apache.spark.Accumulator; import org.apache.spark.HashPartitioner; +import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; @@ -45,7 +47,6 @@ import org.apache.spark.storage.StorageLevel; import org.apache.spark.streaming.api.java.*; import org.apache.spark.util.Utils; -import org.apache.spark.SparkConf; // The test suite itself is Serializable so that anonymous Function implementations can be // serialized, as an alternative to converting these anonymous classes to static inner classes; @@ -768,6 +769,44 @@ public Iterable call(String x) { assertOrderInvariantEquals(expected, result); } + @SuppressWarnings("unchecked") + @Test + public void testForeachRDD() { + final Accumulator accumRdd = ssc.sc().accumulator(0); + final Accumulator accumEle = ssc.sc().accumulator(0); + List> inputData = Arrays.asList( + Arrays.asList(1,1,1), + Arrays.asList(1,1,1)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaTestUtils.attachTestOutputStream(stream.count()); // dummy output + + stream.foreachRDD(new VoidFunction>() { + @Override + public void call(JavaRDD rdd) { + accumRdd.add(1); + rdd.foreach(new VoidFunction() { + @Override + public void call(Integer i) { + accumEle.add(1); + } + }); + } + }); + + // This is a test to make sure foreachRDD(VoidFunction2) can be called 
from Java + stream.foreachRDD(new VoidFunction2, Time>() { + @Override + public void call(JavaRDD rdd, Time time) { + } + }); + + JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(2, accumRdd.value().intValue()); + Assert.assertEquals(6, accumEle.value().intValue()); + } + @SuppressWarnings("unchecked") @Test public void testPairFlatMap() { From a416e41e285700f861559d710dbf857405bfddf6 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 18 Nov 2015 12:50:29 -0800 Subject: [PATCH 106/149] [SPARK-11809] Switch the default Mesos mode to coarse-grained mode Based on my conversions with people, I believe the consensus is that the coarse-grained mode is more stable and easier to reason about. It is best to use that as the default rather than the more flaky fine-grained mode. Author: Reynold Xin Closes #9795 from rxin/SPARK-11809. --- .../scala/org/apache/spark/SparkContext.scala | 2 +- docs/job-scheduling.md | 2 +- docs/running-on-mesos.md | 27 ++++++++++++------- 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index b5645b08f92d4..ab374cb71286a 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -2710,7 +2710,7 @@ object SparkContext extends Logging { case mesosUrl @ MESOS_REGEX(_) => MesosNativeLibrary.load() val scheduler = new TaskSchedulerImpl(sc) - val coarseGrained = sc.conf.getBoolean("spark.mesos.coarse", false) + val coarseGrained = sc.conf.getBoolean("spark.mesos.coarse", defaultValue = true) val url = mesosUrl.stripPrefix("mesos://") // strip scheme from raw Mesos URLs val backend = if (coarseGrained) { new CoarseMesosSchedulerBackend(scheduler, sc, url, sc.env.securityManager) diff --git a/docs/job-scheduling.md b/docs/job-scheduling.md index a3c34cb6796fa..36327c6efeaf3 100644 --- a/docs/job-scheduling.md +++ b/docs/job-scheduling.md @@ -47,7 +47,7 @@ application is not running tasks on a machine, other applications may run tasks is useful when you expect large numbers of not overly active applications, such as shell sessions from separate users. However, it comes with a risk of less predictable latency, because it may take a while for an application to gain back cores on one node when it has work to do. To use this mode, simply use a -`mesos://` URL without setting `spark.mesos.coarse` to true. +`mesos://` URL and set `spark.mesos.coarse` to false. Note that none of the modes currently provide memory sharing across applications. If you would like to share data this way, we recommend running a single server application that can serve multiple requests by querying diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 5be208cf3461e..a197d0e373027 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -161,21 +161,15 @@ Note that jars or python files that are passed to spark-submit should be URIs re # Mesos Run Modes -Spark can run over Mesos in two modes: "fine-grained" (default) and "coarse-grained". +Spark can run over Mesos in two modes: "coarse-grained" (default) and "fine-grained". -In "fine-grained" mode (default), each Spark task runs as a separate Mesos task. This allows -multiple instances of Spark (and other frameworks) to share machines at a very fine granularity, -where each application gets more or fewer machines as it ramps up and down, but it comes with an -additional overhead in launching each task. 
This mode may be inappropriate for low-latency -requirements like interactive queries or serving web requests. - -The "coarse-grained" mode will instead launch only *one* long-running Spark task on each Mesos +The "coarse-grained" mode will launch only *one* long-running Spark task on each Mesos machine, and dynamically schedule its own "mini-tasks" within it. The benefit is much lower startup overhead, but at the cost of reserving the Mesos resources for the complete duration of the application. -To run in coarse-grained mode, set the `spark.mesos.coarse` property in your -[SparkConf](configuration.html#spark-properties): +Coarse-grained is the default mode. You can also set the `spark.mesos.coarse` property to true +to turn it on explicitly in [SparkConf](configuration.html#spark-properties): {% highlight scala %} conf.set("spark.mesos.coarse", "true") @@ -186,6 +180,19 @@ acquire. By default, it will acquire *all* cores in the cluster (that get offere only makes sense if you run just one application at a time. You can cap the maximum number of cores using `conf.set("spark.cores.max", "10")` (for example). +In "fine-grained" mode, each Spark task runs as a separate Mesos task. This allows +multiple instances of Spark (and other frameworks) to share machines at a very fine granularity, +where each application gets more or fewer machines as it ramps up and down, but it comes with an +additional overhead in launching each task. This mode may be inappropriate for low-latency +requirements like interactive queries or serving web requests. + +To run in fine-grained mode, set the `spark.mesos.coarse` property to false in your +[SparkConf](configuration.html#spark-properties): + +{% highlight scala %} +conf.set("spark.mesos.coarse", "false") +{% endhighlight %} + You may also make use of `spark.mesos.constraints` to set attribute based constraints on mesos resource offers. By default, all resource offers will be accepted. {% highlight scala %} From 7c5b641808740ba5eed05ba8204cdbaf3fc579f5 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 18 Nov 2015 12:53:22 -0800 Subject: [PATCH 107/149] [SPARK-10745][CORE] Separate configs between shuffle and RPC [SPARK-6028](https://issues.apache.org/jira/browse/SPARK-6028) uses network module to implement RPC. However, there are some configurations named with `spark.shuffle` prefix in the network module. This PR refactors them to make sure the user can control them in shuffle and RPC separately. The user can use `spark.rpc.*` to set the configuration for netty RPC. Author: Shixiong Zhu Closes #9481 from zsxwing/SPARK-10745. 
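As a usage sketch (illustration only; the thread counts below are arbitrary), the RPC and shuffle transports can now be tuned independently through their own namespaces:

```scala
import org.apache.spark.SparkConf

// spark.rpc.io.* is read by the netty RPC transport, spark.shuffle.io.* by the
// block transfer / external shuffle transports; neither affects the other.
val conf = new SparkConf()
  .set("spark.rpc.io.serverThreads", "8")
  .set("spark.rpc.io.clientThreads", "8")
  .set("spark.shuffle.io.serverThreads", "16")
  .set("spark.shuffle.io.clientThreads", "16")
```

Before this change both code paths read the `spark.shuffle.io.*` keys.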
--- .../spark/deploy/ExternalShuffleService.scala | 3 +- .../netty/NettyBlockTransferService.scala | 2 +- .../network/netty/SparkTransportConf.scala | 12 ++-- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 8 +-- .../mesos/CoarseMesosSchedulerBackend.scala | 2 +- .../shuffle/FileShuffleBlockResolver.scala | 2 +- .../shuffle/IndexShuffleBlockResolver.scala | 2 +- .../apache/spark/storage/BlockManager.scala | 2 +- .../spark/ExternalShuffleServiceSuite.scala | 2 +- .../spark/network/util/TransportConf.java | 65 ++++++++++++++----- .../network/ChunkFetchIntegrationSuite.java | 2 +- .../RequestTimeoutIntegrationSuite.java | 2 +- .../spark/network/RpcIntegrationSuite.java | 2 +- .../org/apache/spark/network/StreamSuite.java | 2 +- .../network/TransportClientFactorySuite.java | 6 +- .../spark/network/sasl/SparkSaslSuite.java | 6 +- .../network/sasl/SaslIntegrationSuite.java | 2 +- .../ExternalShuffleBlockResolverSuite.java | 2 +- .../shuffle/ExternalShuffleCleanupSuite.java | 2 +- .../ExternalShuffleIntegrationSuite.java | 2 +- .../shuffle/ExternalShuffleSecuritySuite.java | 2 +- .../shuffle/RetryingBlockFetcherSuite.java | 2 +- .../network/yarn/YarnShuffleService.java | 2 +- 23 files changed, 84 insertions(+), 50 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala index a039d543c35e7..e8a1e35c3fc48 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala @@ -45,7 +45,8 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana private val port = sparkConf.getInt("spark.shuffle.service.port", 7337) private val useSasl: Boolean = securityManager.isAuthenticationEnabled() - private val transportConf = SparkTransportConf.fromSparkConf(sparkConf, numUsableCores = 0) + private val transportConf = + SparkTransportConf.fromSparkConf(sparkConf, "shuffle", numUsableCores = 0) private val blockHandler = newShuffleBlockHandler(transportConf) private val transportContext: TransportContext = new TransportContext(transportConf, blockHandler, true) diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 70a42f9045e6b..b0694e3c6c8af 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -41,7 +41,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage // TODO: Don't use Java serialization, use a more cross-version compatible serialization format. 
private val serializer = new JavaSerializer(conf) private val authEnabled = securityManager.isAuthenticationEnabled() - private val transportConf = SparkTransportConf.fromSparkConf(conf, numCores) + private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numCores) private[this] var transportContext: TransportContext = _ private[this] var server: TransportServer = _ diff --git a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala index cef203006d685..84833f59d7afe 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala @@ -40,23 +40,23 @@ object SparkTransportConf { /** * Utility for creating a [[TransportConf]] from a [[SparkConf]]. + * @param _conf the [[SparkConf]] + * @param module the module name * @param numUsableCores if nonzero, this will restrict the server and client threads to only * use the given number of cores, rather than all of the machine's cores. * This restriction will only occur if these properties are not already set. */ - def fromSparkConf(_conf: SparkConf, numUsableCores: Int = 0): TransportConf = { + def fromSparkConf(_conf: SparkConf, module: String, numUsableCores: Int = 0): TransportConf = { val conf = _conf.clone // Specify thread configuration based on our JVM's allocation of cores (rather than necessarily // assuming we have all the machine's cores). // NB: Only set if serverThreads/clientThreads not already set. val numThreads = defaultNumThreads(numUsableCores) - conf.set("spark.shuffle.io.serverThreads", - conf.get("spark.shuffle.io.serverThreads", numThreads.toString)) - conf.set("spark.shuffle.io.clientThreads", - conf.get("spark.shuffle.io.clientThreads", numThreads.toString)) + conf.setIfMissing(s"spark.$module.io.serverThreads", numThreads.toString) + conf.setIfMissing(s"spark.$module.io.clientThreads", numThreads.toString) - new TransportConf(new ConfigProvider { + new TransportConf(module, new ConfigProvider { override def get(name: String): String = conf.get(name) }) } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 09093819bb22c..3e0c497969502 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -22,16 +22,13 @@ import java.net.{InetSocketAddress, URI} import java.nio.ByteBuffer import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean -import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy +import javax.annotation.Nullable -import scala.collection.mutable import scala.concurrent.{Future, Promise} import scala.reflect.ClassTag import scala.util.{DynamicVariable, Failure, Success} import scala.util.control.NonFatal -import com.google.common.base.Preconditions import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.network.TransportContext import org.apache.spark.network.client._ @@ -49,7 +46,8 @@ private[netty] class NettyRpcEnv( securityManager: SecurityManager) extends RpcEnv(conf) with Logging { private val transportConf = SparkTransportConf.fromSparkConf( - conf.clone.set("spark.shuffle.io.numConnectionsPerPeer", "1"), + conf.clone.set("spark.rpc.io.numConnectionsPerPeer", "1"), + "rpc", conf.getInt("spark.rpc.io.threads", 0)) private val 
dispatcher: Dispatcher = new Dispatcher(this) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 2de9b6a651692..7d08eae0b4871 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -109,7 +109,7 @@ private[spark] class CoarseMesosSchedulerBackend( private val mesosExternalShuffleClient: Option[MesosExternalShuffleClient] = { if (shuffleServiceEnabled) { Some(new MesosExternalShuffleClient( - SparkTransportConf.fromSparkConf(conf), + SparkTransportConf.fromSparkConf(conf, "shuffle"), securityManager, securityManager.isAuthenticationEnabled(), securityManager.isSaslEncryptionEnabled())) diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala index 39fadd8783518..cc5f933393adf 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala @@ -46,7 +46,7 @@ private[spark] trait ShuffleWriterGroup { private[spark] class FileShuffleBlockResolver(conf: SparkConf) extends ShuffleBlockResolver with Logging { - private val transportConf = SparkTransportConf.fromSparkConf(conf) + private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle") private lazy val blockManager = SparkEnv.get.blockManager diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index 05b1eed7f3bef..fadb8fe7ed0ab 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -47,7 +47,7 @@ private[spark] class IndexShuffleBlockResolver( private lazy val blockManager = Option(_blockManager).getOrElse(SparkEnv.get.blockManager) - private val transportConf = SparkTransportConf.fromSparkConf(conf) + private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle") def getDataFile(shuffleId: Int, mapId: Int): File = { blockManager.diskBlockManager.getFile(ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID)) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 661c706af32b1..ab0007fb78993 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -122,7 +122,7 @@ private[spark] class BlockManager( // Client to read other executors' shuffle files. This is either an external service, or just the // standard BlockTransferService to directly connect to other Executors. 
private[spark] val shuffleClient = if (externalShuffleServiceEnabled) { - val transConf = SparkTransportConf.fromSparkConf(conf, numUsableCores) + val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores) new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled(), securityManager.isSaslEncryptionEnabled()) } else { diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index 231f4631e0a47..1c775bcb3d9c1 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -35,7 +35,7 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { var rpcHandler: ExternalShuffleBlockHandler = _ override def beforeAll() { - val transportConf = SparkTransportConf.fromSparkConf(conf, numUsableCores = 2) + val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores = 2) rpcHandler = new ExternalShuffleBlockHandler(transportConf, null) val transportContext = new TransportContext(transportConf, rpcHandler) server = transportContext.createServer() diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java index 3b2eff377955a..115135d44adbd 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -23,18 +23,53 @@ * A central location that tracks all the settings we expose to users. */ public class TransportConf { + + private final String SPARK_NETWORK_IO_MODE_KEY; + private final String SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY; + private final String SPARK_NETWORK_IO_CONNECTIONTIMEOUT_KEY; + private final String SPARK_NETWORK_IO_BACKLOG_KEY; + private final String SPARK_NETWORK_IO_NUMCONNECTIONSPERPEER_KEY; + private final String SPARK_NETWORK_IO_SERVERTHREADS_KEY; + private final String SPARK_NETWORK_IO_CLIENTTHREADS_KEY; + private final String SPARK_NETWORK_IO_RECEIVEBUFFER_KEY; + private final String SPARK_NETWORK_IO_SENDBUFFER_KEY; + private final String SPARK_NETWORK_SASL_TIMEOUT_KEY; + private final String SPARK_NETWORK_IO_MAXRETRIES_KEY; + private final String SPARK_NETWORK_IO_RETRYWAIT_KEY; + private final String SPARK_NETWORK_IO_LAZYFD_KEY; + private final ConfigProvider conf; - public TransportConf(ConfigProvider conf) { + private final String module; + + public TransportConf(String module, ConfigProvider conf) { + this.module = module; this.conf = conf; + SPARK_NETWORK_IO_MODE_KEY = getConfKey("io.mode"); + SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY = getConfKey("io.preferDirectBufs"); + SPARK_NETWORK_IO_CONNECTIONTIMEOUT_KEY = getConfKey("io.connectionTimeout"); + SPARK_NETWORK_IO_BACKLOG_KEY = getConfKey("io.backLog"); + SPARK_NETWORK_IO_NUMCONNECTIONSPERPEER_KEY = getConfKey("io.numConnectionsPerPeer"); + SPARK_NETWORK_IO_SERVERTHREADS_KEY = getConfKey("io.serverThreads"); + SPARK_NETWORK_IO_CLIENTTHREADS_KEY = getConfKey("io.clientThreads"); + SPARK_NETWORK_IO_RECEIVEBUFFER_KEY = getConfKey("io.receiveBuffer"); + SPARK_NETWORK_IO_SENDBUFFER_KEY = getConfKey("io.sendBuffer"); + SPARK_NETWORK_SASL_TIMEOUT_KEY = getConfKey("sasl.timeout"); + SPARK_NETWORK_IO_MAXRETRIES_KEY = getConfKey("io.maxRetries"); + SPARK_NETWORK_IO_RETRYWAIT_KEY = getConfKey("io.retryWait"); + 
SPARK_NETWORK_IO_LAZYFD_KEY = getConfKey("io.lazyFD"); + } + + private String getConfKey(String suffix) { + return "spark." + module + "." + suffix; } /** IO mode: nio or epoll */ - public String ioMode() { return conf.get("spark.shuffle.io.mode", "NIO").toUpperCase(); } + public String ioMode() { return conf.get(SPARK_NETWORK_IO_MODE_KEY, "NIO").toUpperCase(); } /** If true, we will prefer allocating off-heap byte buffers within Netty. */ public boolean preferDirectBufs() { - return conf.getBoolean("spark.shuffle.io.preferDirectBufs", true); + return conf.getBoolean(SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY, true); } /** Connect timeout in milliseconds. Default 120 secs. */ @@ -42,23 +77,23 @@ public int connectionTimeoutMs() { long defaultNetworkTimeoutS = JavaUtils.timeStringAsSec( conf.get("spark.network.timeout", "120s")); long defaultTimeoutMs = JavaUtils.timeStringAsSec( - conf.get("spark.shuffle.io.connectionTimeout", defaultNetworkTimeoutS + "s")) * 1000; + conf.get(SPARK_NETWORK_IO_CONNECTIONTIMEOUT_KEY, defaultNetworkTimeoutS + "s")) * 1000; return (int) defaultTimeoutMs; } /** Number of concurrent connections between two nodes for fetching data. */ public int numConnectionsPerPeer() { - return conf.getInt("spark.shuffle.io.numConnectionsPerPeer", 1); + return conf.getInt(SPARK_NETWORK_IO_NUMCONNECTIONSPERPEER_KEY, 1); } /** Requested maximum length of the queue of incoming connections. Default -1 for no backlog. */ - public int backLog() { return conf.getInt("spark.shuffle.io.backLog", -1); } + public int backLog() { return conf.getInt(SPARK_NETWORK_IO_BACKLOG_KEY, -1); } /** Number of threads used in the server thread pool. Default to 0, which is 2x#cores. */ - public int serverThreads() { return conf.getInt("spark.shuffle.io.serverThreads", 0); } + public int serverThreads() { return conf.getInt(SPARK_NETWORK_IO_SERVERTHREADS_KEY, 0); } /** Number of threads used in the client thread pool. Default to 0, which is 2x#cores. */ - public int clientThreads() { return conf.getInt("spark.shuffle.io.clientThreads", 0); } + public int clientThreads() { return conf.getInt(SPARK_NETWORK_IO_CLIENTTHREADS_KEY, 0); } /** * Receive buffer size (SO_RCVBUF). @@ -67,28 +102,28 @@ public int numConnectionsPerPeer() { * Assuming latency = 1ms, network_bandwidth = 10Gbps * buffer size should be ~ 1.25MB */ - public int receiveBuf() { return conf.getInt("spark.shuffle.io.receiveBuffer", -1); } + public int receiveBuf() { return conf.getInt(SPARK_NETWORK_IO_RECEIVEBUFFER_KEY, -1); } /** Send buffer size (SO_SNDBUF). */ - public int sendBuf() { return conf.getInt("spark.shuffle.io.sendBuffer", -1); } + public int sendBuf() { return conf.getInt(SPARK_NETWORK_IO_SENDBUFFER_KEY, -1); } /** Timeout for a single round trip of SASL token exchange, in milliseconds. */ public int saslRTTimeoutMs() { - return (int) JavaUtils.timeStringAsSec(conf.get("spark.shuffle.sasl.timeout", "30s")) * 1000; + return (int) JavaUtils.timeStringAsSec(conf.get(SPARK_NETWORK_SASL_TIMEOUT_KEY, "30s")) * 1000; } /** * Max number of times we will try IO exceptions (such as connection timeouts) per request. * If set to 0, we will not do any retries. */ - public int maxIORetries() { return conf.getInt("spark.shuffle.io.maxRetries", 3); } + public int maxIORetries() { return conf.getInt(SPARK_NETWORK_IO_MAXRETRIES_KEY, 3); } /** * Time (in milliseconds) that we will wait in order to perform a retry after an IOException. * Only relevant if maxIORetries > 0. 
*/ public int ioRetryWaitTimeMs() { - return (int) JavaUtils.timeStringAsSec(conf.get("spark.shuffle.io.retryWait", "5s")) * 1000; + return (int) JavaUtils.timeStringAsSec(conf.get(SPARK_NETWORK_IO_RETRYWAIT_KEY, "5s")) * 1000; } /** @@ -101,11 +136,11 @@ public int memoryMapBytes() { } /** - * Whether to initialize shuffle FileDescriptor lazily or not. If true, file descriptors are + * Whether to initialize FileDescriptor lazily or not. If true, file descriptors are * created only when data is going to be transferred. This can reduce the number of open files. */ public boolean lazyFileDescriptor() { - return conf.getBoolean("spark.shuffle.io.lazyFD", true); + return conf.getBoolean(SPARK_NETWORK_IO_LAZYFD_KEY, true); } /** diff --git a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java index dfb7740344ed0..dc5fa1cee69bc 100644 --- a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java @@ -83,7 +83,7 @@ public static void setUp() throws Exception { fp.write(fileContent); fp.close(); - final TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + final TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); fileChunk = new FileSegmentManagedBuffer(conf, testFile, 10, testFile.length() - 25); streamManager = new StreamManager() { diff --git a/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java index 84ebb337e6d54..42955ef69235a 100644 --- a/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java @@ -60,7 +60,7 @@ public class RequestTimeoutIntegrationSuite { public void setUp() throws Exception { Map configMap = Maps.newHashMap(); configMap.put("spark.shuffle.io.connectionTimeout", "2s"); - conf = new TransportConf(new MapConfigProvider(configMap)); + conf = new TransportConf("shuffle", new MapConfigProvider(configMap)); defaultManager = new StreamManager() { @Override diff --git a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java index 64b457b4b3f01..8eb56bdd9846f 100644 --- a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -49,7 +49,7 @@ public class RpcIntegrationSuite { @BeforeClass public static void setUp() throws Exception { - TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); rpcHandler = new RpcHandler() { @Override public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { diff --git a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java index 6dcec831dec71..00158fd081626 100644 --- a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java @@ -89,7 +89,7 @@ public static void setUp() 
throws Exception { fp.close(); } - final TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + final TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); final StreamManager streamManager = new StreamManager() { @Override public ManagedBuffer getChunk(long streamId, int chunkIndex) { diff --git a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java index f447137419306..dac7d4a5b0a09 100644 --- a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java +++ b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java @@ -52,7 +52,7 @@ public class TransportClientFactorySuite { @Before public void setUp() { - conf = new TransportConf(new SystemPropertyConfigProvider()); + conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); RpcHandler rpcHandler = new NoOpRpcHandler(); context = new TransportContext(conf, rpcHandler); server1 = context.createServer(); @@ -76,7 +76,7 @@ private void testClientReuse(final int maxConnections, boolean concurrent) Map configMap = Maps.newHashMap(); configMap.put("spark.shuffle.io.numConnectionsPerPeer", Integer.toString(maxConnections)); - TransportConf conf = new TransportConf(new MapConfigProvider(configMap)); + TransportConf conf = new TransportConf("shuffle", new MapConfigProvider(configMap)); RpcHandler rpcHandler = new NoOpRpcHandler(); TransportContext context = new TransportContext(conf, rpcHandler); @@ -182,7 +182,7 @@ public void closeBlockClientsWithFactory() throws IOException { @Test public void closeIdleConnectionForRequestTimeOut() throws IOException, InterruptedException { - TransportConf conf = new TransportConf(new ConfigProvider() { + TransportConf conf = new TransportConf("shuffle", new ConfigProvider() { @Override public String get(String name) { diff --git a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index 3469e84e7f4da..b146899670180 100644 --- a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -207,7 +207,7 @@ public void testEncryptedMessage() throws Exception { public void testEncryptedMessageChunking() throws Exception { File file = File.createTempFile("sasltest", ".txt"); try { - TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); byte[] data = new byte[8 * 1024]; new Random().nextBytes(data); @@ -242,7 +242,7 @@ public void testFileRegionEncryption() throws Exception { final File file = File.createTempFile("sasltest", ".txt"); SaslTestCtx ctx = null; try { - final TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + final TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); StreamManager sm = mock(StreamManager.class); when(sm.getChunk(anyLong(), anyInt())).thenAnswer(new Answer() { @Override @@ -368,7 +368,7 @@ private static class SaslTestCtx { boolean disableClientEncryption) throws Exception { - TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); SecretKeyHolder 
keyHolder = mock(SecretKeyHolder.class); when(keyHolder.getSaslUser(anyString())).thenReturn("user"); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index c393a5e1e6810..1c2fa4d0d462c 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -70,7 +70,7 @@ public class SaslIntegrationSuite { @BeforeClass public static void beforeAll() throws IOException { - conf = new TransportConf(new SystemPropertyConfigProvider()); + conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); context = new TransportContext(conf, new TestRpcHandler()); secretKeyHolder = mock(SecretKeyHolder.class); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java index 3c6cb367dea46..a9958232a1d28 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java @@ -42,7 +42,7 @@ public class ExternalShuffleBlockResolverSuite { static TestShuffleDataContext dataContext; - static TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + static TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); @BeforeClass public static void beforeAll() throws IOException { diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java index 2f4f1d0df478b..532d7ab8d01bd 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java @@ -35,7 +35,7 @@ public class ExternalShuffleCleanupSuite { // Same-thread Executor used to ensure cleanup happens synchronously in test thread. 
Executor sameThreadExecutor = MoreExecutors.sameThreadExecutor(); - TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); @Test public void noCleanupAndCleanup() throws IOException { diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index a3f9a38b1aeb9..2095f41d79c16 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -91,7 +91,7 @@ public static void beforeAll() throws IOException { dataContext1.create(); dataContext1.insertHashShuffleData(1, 0, exec1Blocks); - conf = new TransportConf(new SystemPropertyConfigProvider()); + conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); handler = new ExternalShuffleBlockHandler(conf, null); TransportContext transportContext = new TransportContext(conf, handler); server = transportContext.createServer(); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java index aa99efda94948..08ddb3755bd08 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java @@ -39,7 +39,7 @@ public class ExternalShuffleSecuritySuite { - TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); TransportServer server; @Before diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java index 06e46f9241094..3a6ef0d3f8476 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java @@ -254,7 +254,7 @@ private static void performInteractions(List> inte BlockFetchingListener listener) throws IOException { - TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); BlockFetchStarter fetchStarter = mock(BlockFetchStarter.class); Stubber stub = null; diff --git a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java index 11ea7f3fd3cfe..ba6d30a74c673 100644 --- a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java +++ b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java @@ -120,7 +120,7 @@ protected void serviceInit(Configuration conf) { registeredExecutorFile = findRegisteredExecutorFile(conf.getStrings("yarn.nodemanager.local-dirs")); - TransportConf transportConf = new TransportConf(new HadoopConfigProvider(conf)); + TransportConf transportConf = new TransportConf("shuffle", new HadoopConfigProvider(conf)); // If authentication is enabled, set up the shuffle 
server to use a // special RPC handler that filters out unauthenticated fetch requests boolean authEnabled = conf.getBoolean(SPARK_AUTHENTICATE_KEY, DEFAULT_SPARK_AUTHENTICATE); From 09ad9533d5760652de59fa4830c24cb8667958ac Mon Sep 17 00:00:00 2001 From: JihongMa Date: Wed, 18 Nov 2015 13:03:37 -0800 Subject: [PATCH 108/149] [SPARK-11720][SQL][ML] Handle edge cases when count = 0 or 1 for Stats function return Double.NaN for mean/average when count == 0 for all numeric types that is converted to Double, Decimal type continue to return null. Author: JihongMa Closes #9705 from JihongMA/SPARK-11720. --- python/pyspark/sql/dataframe.py | 2 +- .../aggregate/CentralMomentAgg.scala | 2 +- .../expressions/aggregate/Kurtosis.scala | 9 +++++---- .../expressions/aggregate/Skewness.scala | 9 +++++---- .../expressions/aggregate/Stddev.scala | 18 ++++++++++++++---- .../expressions/aggregate/Variance.scala | 18 ++++++++++++++---- .../spark/sql/DataFrameAggregateSuite.scala | 18 ++++++++++++------ .../org/apache/spark/sql/DataFrameSuite.scala | 2 +- 8 files changed, 53 insertions(+), 25 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index ad6ad0235a904..0dd75ba7ca820 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -761,7 +761,7 @@ def describe(self, *cols): +-------+------------------+-----+ | count| 2| 2| | mean| 3.5| null| - | stddev|2.1213203435596424| NaN| + | stddev|2.1213203435596424| null| | min| 2|Alice| | max| 5| Bob| +-------+------------------+-----+ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala index de5872ab11eb1..d07d4c338cdfe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -206,7 +206,7 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w * @param centralMoments Length `momentOrder + 1` array of central moments (un-normalized) * needed to compute the aggregate stat. 
*/ - def getStatistic(n: Double, mean: Double, centralMoments: Array[Double]): Double + def getStatistic(n: Double, mean: Double, centralMoments: Array[Double]): Any override final def eval(buffer: InternalRow): Any = { val n = buffer.getDouble(nOffset) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala index 8fa3aac9f1a51..c2bf2cb94116c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala @@ -37,16 +37,17 @@ case class Kurtosis(child: Expression, override protected val momentOrder = 4 // NOTE: this is the formula for excess kurtosis, which is default for R and SciPy - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { require(moments.length == momentOrder + 1, s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") val m2 = moments(2) val m4 = moments(4) - if (n == 0.0 || m2 == 0.0) { + if (n == 0.0) { + null + } else if (m2 == 0.0) { Double.NaN - } - else { + } else { n * m4 / (m2 * m2) - 3.0 } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala index e1c01a5b82781..9411bcea2539a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala @@ -36,16 +36,17 @@ case class Skewness(child: Expression, override protected val momentOrder = 3 - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { require(moments.length == momentOrder + 1, s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") val m2 = moments(2) val m3 = moments(3) - if (n == 0.0 || m2 == 0.0) { + if (n == 0.0) { + null + } else if (m2 == 0.0) { Double.NaN - } - else { + } else { math.sqrt(n) * m3 / math.sqrt(m2 * m2 * m2) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala index 05dd5e3b22543..eec79a9033e36 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala @@ -36,11 +36,17 @@ case class StddevSamp(child: Expression, override protected val momentOrder = 2 - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { require(moments.length == momentOrder + 1, s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") - if (n == 0.0 || n == 1.0) Double.NaN else math.sqrt(moments(2) / (n - 1.0)) + if (n == 0.0) { + null + } else if (n == 1.0) { + Double.NaN + } else { + math.sqrt(moments(2) / (n - 1.0)) + } } } @@ -62,10 +68,14 @@ case class StddevPop( override protected val 
momentOrder = 2 - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { require(moments.length == momentOrder + 1, s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") - if (n == 0.0) Double.NaN else math.sqrt(moments(2) / n) + if (n == 0.0) { + null + } else { + math.sqrt(moments(2) / n) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala index ede2da2805966..cf3a740305391 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala @@ -36,11 +36,17 @@ case class VarianceSamp(child: Expression, override protected val momentOrder = 2 - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { require(moments.length == momentOrder + 1, s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") - if (n == 0.0 || n == 1.0) Double.NaN else moments(2) / (n - 1.0) + if (n == 0.0) { + null + } else if (n == 1.0) { + Double.NaN + } else { + moments(2) / (n - 1.0) + } } } @@ -62,10 +68,14 @@ case class VariancePop( override protected val momentOrder = 2 - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { require(moments.length == momentOrder + 1, s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") - if (n == 0.0) Double.NaN else moments(2) / n + if (n == 0.0) { + null + } else { + moments(2) / n + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 432e8d17623a4..71adf2148a403 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -205,7 +205,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") checkAnswer( emptyTableData.agg(stddev('a), stddev_pop('a), stddev_samp('a)), - Row(Double.NaN, Double.NaN, Double.NaN)) + Row(null, null, null)) } test("zero sum") { @@ -244,17 +244,23 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { test("zero moments") { val input = Seq((1, 2)).toDF("a", "b") checkAnswer( - input.agg(variance('a), var_samp('a), var_pop('a), skewness('a), kurtosis('a)), - Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN)) + input.agg(stddev('a), stddev_samp('a), stddev_pop('a), variance('a), + var_samp('a), var_pop('a), skewness('a), kurtosis('a)), + Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN, 0.0, + Double.NaN, Double.NaN)) checkAnswer( input.agg( + expr("stddev(a)"), + expr("stddev_samp(a)"), + expr("stddev_pop(a)"), expr("variance(a)"), expr("var_samp(a)"), expr("var_pop(a)"), expr("skewness(a)"), expr("kurtosis(a)")), - Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN)) + Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN, 0.0, + Double.NaN, Double.NaN)) } 
test("null moments") { @@ -262,7 +268,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { checkAnswer( emptyTableData.agg(variance('a), var_samp('a), var_pop('a), skewness('a), kurtosis('a)), - Row(Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN)) + Row(null, null, null, null, null)) checkAnswer( emptyTableData.agg( @@ -271,6 +277,6 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { expr("var_pop(a)"), expr("skewness(a)"), expr("kurtosis(a)")), - Row(Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN)) + Row(null, null, null, null, null)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 5a7f24684d1b7..6399b0165c4c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -459,7 +459,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val emptyDescribeResult = Seq( Row("count", "0", "0"), Row("mean", null, null), - Row("stddev", "NaN", "NaN"), + Row("stddev", null, null), Row("min", null, null), Row("max", null, null)) From 045a4f045821dcf60442f0600c2df1b79bddb536 Mon Sep 17 00:00:00 2001 From: Wenjian Huang Date: Wed, 18 Nov 2015 13:06:25 -0800 Subject: [PATCH 109/149] [SPARK-6790][ML] Add spark.ml LinearRegression import/export This replaces [https://github.com/apache/spark/pull/9656] with updates. fayeshine should be the main author when this PR is committed. CC: mengxr fayeshine Author: Wenjian Huang Author: Joseph K. Bradley Closes #9814 from jkbradley/fayeshine-patch-6790. --- .../ml/regression/LinearRegression.scala | 77 ++++++++++++++++++- .../ml/regression/LinearRegressionSuite.scala | 34 +++++++- 2 files changed, 106 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 913140e581983..ca55d5915e688 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -22,6 +22,7 @@ import scala.collection.mutable import breeze.linalg.{DenseVector => BDV} import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} import breeze.stats.distributions.StudentsT +import org.apache.hadoop.fs.Path import org.apache.spark.{Logging, SparkException} import org.apache.spark.ml.feature.Instance @@ -30,7 +31,7 @@ import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.mllib.evaluation.RegressionMetrics import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS._ @@ -65,7 +66,7 @@ private[regression] trait LinearRegressionParams extends PredictorParams @Experimental class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String) extends Regressor[Vector, LinearRegression, LinearRegressionModel] - with LinearRegressionParams with Logging { + with LinearRegressionParams with Writable with Logging { @Since("1.4.0") def this() = this(Identifiable.randomUID("linReg")) @@ -341,6 +342,19 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") 
override val uid: String @Since("1.4.0") override def copy(extra: ParamMap): LinearRegression = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object LinearRegression extends Readable[LinearRegression] { + + @Since("1.6.0") + override def read: Reader[LinearRegression] = new DefaultParamsReader[LinearRegression] + + @Since("1.6.0") + override def load(path: String): LinearRegression = read.load(path) } /** @@ -354,7 +368,7 @@ class LinearRegressionModel private[ml] ( val coefficients: Vector, val intercept: Double) extends RegressionModel[Vector, LinearRegressionModel] - with LinearRegressionParams { + with LinearRegressionParams with Writable { private var trainingSummary: Option[LinearRegressionTrainingSummary] = None @@ -422,6 +436,63 @@ class LinearRegressionModel private[ml] ( if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get) newModel.setParent(parent) } + + /** + * Returns a [[Writer]] instance for this ML instance. + * + * For [[LinearRegressionModel]], this does NOT currently save the training [[summary]]. + * An option to save [[summary]] may be added in the future. + * + * This also does not save the [[parent]] currently. + */ + @Since("1.6.0") + override def write: Writer = new LinearRegressionModel.LinearRegressionModelWriter(this) +} + +@Since("1.6.0") +object LinearRegressionModel extends Readable[LinearRegressionModel] { + + @Since("1.6.0") + override def read: Reader[LinearRegressionModel] = new LinearRegressionModelReader + + @Since("1.6.0") + override def load(path: String): LinearRegressionModel = read.load(path) + + /** [[Writer]] instance for [[LinearRegressionModel]] */ + private[LinearRegressionModel] class LinearRegressionModelWriter(instance: LinearRegressionModel) + extends Writer with Logging { + + private case class Data(intercept: Double, coefficients: Vector) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: intercept, coefficients + val data = Data(instance.intercept, instance.coefficients) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).write.format("parquet").save(dataPath) + } + } + + private class LinearRegressionModelReader extends Reader[LinearRegressionModel] { + + /** Checked against metadata when loading model */ + private val className = "org.apache.spark.ml.regression.LinearRegressionModel" + + override def load(path: String): LinearRegressionModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.format("parquet").load(dataPath) + .select("intercept", "coefficients").head() + val intercept = data.getDouble(0) + val coefficients = data.getAs[Vector](1) + val model = new LinearRegressionModel(metadata.uid, coefficients, intercept) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index a1d86fe8fedad..2bdc0e184d734 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -22,14 +22,15 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import 
org.apache.spark.ml.feature.Instance import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.linalg.{Vector, DenseVector, Vectors} import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} -class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { +class LinearRegressionSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { private val seed: Int = 42 @transient var datasetWithDenseFeature: DataFrame = _ @@ -854,4 +855,33 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { model.summary.tValues.zip(tValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } model.summary.pValues.zip(pValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } } + + test("read/write") { + def checkModelData(model: LinearRegressionModel, model2: LinearRegressionModel): Unit = { + assert(model.intercept === model2.intercept) + assert(model.coefficients === model2.coefficients) + } + val lr = new LinearRegression() + testEstimatorAndModelReadWrite(lr, datasetWithWeight, LinearRegressionSuite.allParamSettings, + checkModelData) + } +} + +object LinearRegressionSuite { + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "predictionCol" -> "myPrediction", + "regParam" -> 0.01, + "elasticNetParam" -> 0.1, + "maxIter" -> 2, // intentionally small + "fitIntercept" -> true, + "tol" -> 0.8, + "standardization" -> false, + "solver" -> "l-bfgs" + ) } From 2acdf10b1f3bb1242dba64efa798c672fde9f0d2 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 18 Nov 2015 13:16:31 -0800 Subject: [PATCH 110/149] [SPARK-6789][ML] Add Readable, Writable support for spark.ml ALS, ALSModel Also modifies DefaultParamsWriter.saveMetadata to take optional extra metadata. CC: mengxr yanboliang Author: Joseph K. Bradley Closes #9786 from jkbradley/als-io. 
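For reference, a minimal sketch of the persistence API this patch adds for ALS and ALSModel, following the same Writer/Reader pattern as LinearRegression above; the paths and the training DataFrame are illustrative, not part of the diff:

    import org.apache.spark.ml.recommendation.{ALS, ALSModel}

    // training: a DataFrame with user, item and rating columns (illustrative).
    val als = new ALS().setMaxIter(1).setRank(1)
    val model = als.fit(training)

    // The estimator uses DefaultParamsWriter; the model's ALSModelWriter also stores
    // rank (as extra metadata) plus the userFactors/itemFactors DataFrames as Parquet.
    als.write.save("/tmp/als-estimator")
    model.write.save("/tmp/als-model")

    // The companion objects expose read/load, restoring params, rank and both factor tables.
    val sameAls = ALS.load("/tmp/als-estimator")
    val sameModel = ALSModel.load("/tmp/als-model")
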
--- .../apache/spark/ml/recommendation/ALS.scala | 75 ++++++++++++++++-- .../org/apache/spark/ml/util/ReadWrite.scala | 14 +++- .../spark/ml/recommendation/ALSSuite.scala | 78 ++++++++++++++++--- 3 files changed, 150 insertions(+), 17 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 535f266b9a944..d92514d2e239e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -27,13 +27,16 @@ import scala.util.hashing.byteswap64 import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.hadoop.fs.{FileSystem, Path} +import org.json4s.{DefaultFormats, JValue} +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ import org.apache.spark.{Logging, Partitioner} -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{Since, DeveloperApi, Experimental} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.CholeskyDecomposition import org.apache.spark.mllib.optimization.NNLS import org.apache.spark.rdd.RDD @@ -182,7 +185,7 @@ class ALSModel private[ml] ( val rank: Int, @transient val userFactors: DataFrame, @transient val itemFactors: DataFrame) - extends Model[ALSModel] with ALSModelParams { + extends Model[ALSModel] with ALSModelParams with Writable { /** @group setParam */ def setUserCol(value: String): this.type = set(userCol, value) @@ -220,8 +223,60 @@ class ALSModel private[ml] ( val copied = new ALSModel(uid, rank, userFactors, itemFactors) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: Writer = new ALSModel.ALSModelWriter(this) } +@Since("1.6.0") +object ALSModel extends Readable[ALSModel] { + + @Since("1.6.0") + override def read: Reader[ALSModel] = new ALSModelReader + + @Since("1.6.0") + override def load(path: String): ALSModel = read.load(path) + + private[recommendation] class ALSModelWriter(instance: ALSModel) extends Writer { + + override protected def saveImpl(path: String): Unit = { + val extraMetadata = render("rank" -> instance.rank) + DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata)) + val userPath = new Path(path, "userFactors").toString + instance.userFactors.write.format("parquet").save(userPath) + val itemPath = new Path(path, "itemFactors").toString + instance.itemFactors.write.format("parquet").save(itemPath) + } + } + + private[recommendation] class ALSModelReader extends Reader[ALSModel] { + + /** Checked against metadata when loading model */ + private val className = "org.apache.spark.ml.recommendation.ALSModel" + + override def load(path: String): ALSModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + implicit val format = DefaultFormats + val rank: Int = metadata.extraMetadata match { + case Some(m: JValue) => + (m \ "rank").extract[Int] + case None => + throw new RuntimeException(s"ALSModel loader could not read rank from JSON metadata:" + + s" ${metadata.metadataStr}") + } + + val userPath = new Path(path, "userFactors").toString + val userFactors = sqlContext.read.format("parquet").load(userPath) + val itemPath = new Path(path, "itemFactors").toString + val itemFactors = 
sqlContext.read.format("parquet").load(itemPath) + + val model = new ALSModel(metadata.uid, rank, userFactors, itemFactors) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } +} /** * :: Experimental :: @@ -254,7 +309,7 @@ class ALSModel private[ml] ( * preferences rather than explicit ratings given to items. */ @Experimental -class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { +class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams with Writable { import org.apache.spark.ml.recommendation.ALS.Rating @@ -336,8 +391,12 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { } override def copy(extra: ParamMap): ALS = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) } + /** * :: DeveloperApi :: * An implementation of ALS that supports generic ID types, specialized for Int and Long. This is @@ -347,7 +406,7 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { * than 2 billion. */ @DeveloperApi -object ALS extends Logging { +object ALS extends Readable[ALS] with Logging { /** * :: DeveloperApi :: @@ -356,6 +415,12 @@ object ALS extends Logging { @DeveloperApi case class Rating[@specialized(Int, Long) ID](user: ID, item: ID, rating: Float) + @Since("1.6.0") + override def read: Reader[ALS] = new DefaultParamsReader[ALS] + + @Since("1.6.0") + override def load(path: String): ALS = read.load(path) + /** Trait for least squares solvers applied to the normal equation. */ private[recommendation] trait LeastSquaresNESolver extends Serializable { /** Solves a least squares problem with regularization (possibly with other constraints). */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index dddb72af5ba78..d8ce907af5323 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -194,7 +194,11 @@ private[ml] object DefaultParamsWriter { * - uid * - paramMap: These must be encodable using [[org.apache.spark.ml.param.Param.jsonEncode()]]. */ - def saveMetadata(instance: Params, path: String, sc: SparkContext): Unit = { + def saveMetadata( + instance: Params, + path: String, + sc: SparkContext, + extraMetadata: Option[JValue] = None): Unit = { val uid = instance.uid val cls = instance.getClass.getName val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]] @@ -205,7 +209,8 @@ private[ml] object DefaultParamsWriter { ("timestamp" -> System.currentTimeMillis()) ~ ("sparkVersion" -> sc.version) ~ ("uid" -> uid) ~ - ("paramMap" -> jsonParams) + ("paramMap" -> jsonParams) ~ + ("extraMetadata" -> extraMetadata) val metadataPath = new Path(path, "metadata").toString val metadataJson = compact(render(metadata)) sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath) @@ -236,6 +241,7 @@ private[ml] object DefaultParamsReader { /** * All info from metadata file. 
* @param params paramMap, as a [[JValue]] + * @param extraMetadata Extra metadata saved by [[DefaultParamsWriter.saveMetadata()]] * @param metadataStr Full metadata file String (for debugging) */ case class Metadata( @@ -244,6 +250,7 @@ private[ml] object DefaultParamsReader { timestamp: Long, sparkVersion: String, params: JValue, + extraMetadata: Option[JValue], metadataStr: String) /** @@ -262,12 +269,13 @@ private[ml] object DefaultParamsReader { val timestamp = (metadata \ "timestamp").extract[Long] val sparkVersion = (metadata \ "sparkVersion").extract[String] val params = metadata \ "paramMap" + val extraMetadata = (metadata \ "extraMetadata").extract[Option[JValue]] if (expectedClassName.nonEmpty) { require(className == expectedClassName, s"Error loading metadata: Expected class name" + s" $expectedClassName but found class name $className") } - Metadata(className, uid, timestamp, sparkVersion, params, metadataStr) + Metadata(className, uid, timestamp, sparkVersion, params, extraMetadata, metadataStr) } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index eadc80e0e62b1..2c3fb84160dcb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.ml.recommendation -import java.io.File import java.util.Random import scala.collection.mutable @@ -26,28 +25,26 @@ import scala.language.existentials import com.github.fommil.netlib.BLAS.{getInstance => blas} +import org.apache.spark.util.Utils import org.apache.spark.{Logging, SparkException, SparkFunSuite} import org.apache.spark.ml.recommendation.ALS._ -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.util.Utils +import org.apache.spark.sql.{DataFrame, Row} -class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { - private var tempDir: File = _ +class ALSSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest with Logging { override def beforeAll(): Unit = { super.beforeAll() - tempDir = Utils.createTempDir() sc.setCheckpointDir(tempDir.getAbsolutePath) } override def afterAll(): Unit = { - Utils.deleteRecursively(tempDir) super.afterAll() } @@ -186,7 +183,7 @@ class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { assert(compressed.dstPtrs.toSeq === Seq(0, 2, 3, 4, 5)) var decompressed = ArrayBuffer.empty[(Int, Int, Int, Float)] var i = 0 - while (i < compressed.srcIds.size) { + while (i < compressed.srcIds.length) { var j = compressed.dstPtrs(i) while (j < compressed.dstPtrs(i + 1)) { val dstEncodedIndex = compressed.dstEncodedIndices(j) @@ -483,4 +480,67 @@ class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { ALS.train(ratings, rank = 1, maxIter = 50, numUserBlocks = 2, numItemBlocks = 2, implicitPrefs = true, seed = 0) } + + test("read/write") { + import ALSSuite._ + val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1) + val als = new ALS() + allEstimatorParamSettings.foreach { case (p, v) => + als.set(als.getParam(p), v) + } + val sqlContext 
= this.sqlContext + import sqlContext.implicits._ + val model = als.fit(ratings.toDF()) + + // Test Estimator save/load + val als2 = testDefaultReadWrite(als) + allEstimatorParamSettings.foreach { case (p, v) => + val param = als.getParam(p) + assert(als.get(param).get === als2.get(param).get) + } + + // Test Model save/load + val model2 = testDefaultReadWrite(model) + allModelParamSettings.foreach { case (p, v) => + val param = model.getParam(p) + assert(model.get(param).get === model2.get(param).get) + } + assert(model.rank === model2.rank) + def getFactors(df: DataFrame): Set[(Int, Array[Float])] = { + df.select("id", "features").collect().map { case r => + (r.getInt(0), r.getAs[Array[Float]](1)) + }.toSet + } + assert(getFactors(model.userFactors) === getFactors(model2.userFactors)) + assert(getFactors(model.itemFactors) === getFactors(model2.itemFactors)) + } +} + +object ALSSuite { + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allModelParamSettings: Map[String, Any] = Map( + "predictionCol" -> "myPredictionCol" + ) + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allEstimatorParamSettings: Map[String, Any] = allModelParamSettings ++ Map( + "maxIter" -> 1, + "rank" -> 1, + "regParam" -> 0.01, + "numUserBlocks" -> 2, + "numItemBlocks" -> 2, + "implicitPrefs" -> true, + "alpha" -> 0.9, + "nonnegative" -> true, + "checkpointInterval" -> 20 + ) } From e391abdf2cb6098a35347bd123b815ee9ac5b689 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Wed, 18 Nov 2015 13:25:15 -0800 Subject: [PATCH 111/149] [SPARK-11813][MLLIB] Avoid serialization of vocab in Word2Vec jira: https://issues.apache.org/jira/browse/SPARK-11813 I found the problem during training a large corpus. Avoid serialization of vocab in Word2Vec has 2 benefits. 1. Performance improvement for less serialization. 2. Increase the capacity of Word2Vec a lot. Currently in the fit of word2vec, the closure mainly includes serialization of Word2Vec and 2 global table. the main part of Word2vec is the vocab of size: vocab * 40 * 2 * 4 = 320 vocab 2 global table: vocab * vectorSize * 8. If vectorSize = 20, that's 160 vocab. Their sum cannot exceed Int.max due to the restriction of ByteArrayOutputStream. In any case, avoiding serialization of vocab helps decrease the size of the closure serialization, especially when vectorSize is small, thus to allow larger vocabulary. Actually there's another possible fix, make local copy of fields to avoid including Word2Vec in the closure. Let me know if that's preferred. Author: Yuhao Yang Closes #9803 from hhbyyh/w2vVocab. 
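Reading the sizes above as bytes per vocabulary entry, the pre-patch limit works out roughly as follows (illustrative arithmetic only; the constants are taken from the message):

    // Closure size before this patch, assuming the quoted constants are bytes per entry.
    val vocabArrayBytes  = 40 * 2 * 4   // ~320 bytes per VocabWord kept in the vocab array
    val globalTableBytes = 20 * 8       // ~160 bytes per entry across both tables at vectorSize = 20
    val maxVocab = Int.MaxValue / (vocabArrayBytes + globalTableBytes)
    // maxVocab is roughly 4.4 million words before serialization overflows ByteArrayOutputStream;
    // marking vocab and vocabHash @transient drops the 320 * vocab term from the closure.
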
--- .../main/scala/org/apache/spark/mllib/feature/Word2Vec.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index f3e4d346e358a..7ab0d89d23a3f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -145,8 +145,8 @@ class Word2Vec extends Serializable with Logging { private var trainWordsCount = 0 private var vocabSize = 0 - private var vocab: Array[VocabWord] = null - private var vocabHash = mutable.HashMap.empty[String, Int] + @transient private var vocab: Array[VocabWord] = null + @transient private var vocabHash = mutable.HashMap.empty[String, Int] private def learnVocab(words: RDD[String]): Unit = { vocab = words.map(w => (w, 1)) From e222d758499ad2609046cc1a2cc8afb45c5bccbb Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 18 Nov 2015 13:30:29 -0800 Subject: [PATCH 112/149] [SPARK-11684][R][ML][DOC] Update SparkR glm API doc, user guide and example codes This PR includes: * Update SparkR:::glm, SparkR:::summary API docs. * Update SparkR machine learning user guide and example codes to show: * supporting feature interaction in R formula. * summary for gaussian GLM model. * coefficients for binomial GLM model. mengxr Author: Yanbo Liang Closes #9727 from yanboliang/spark-11684. --- R/pkg/R/mllib.R | 18 +++++-- docs/sparkr.md | 50 ++++++++++++++++--- .../ml/regression/LinearRegression.scala | 3 ++ 3 files changed, 60 insertions(+), 11 deletions(-) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index f23e1c7f1fce4..8d3b4388ae575 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -32,6 +32,12 @@ setClass("PipelineModel", representation(model = "jobj")) #' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg. #' @param lambda Regularization parameter #' @param alpha Elastic-net mixing parameter (see glmnet's documentation for details) +#' @param standardize Whether to standardize features before training +#' @param solver The solver algorithm used for optimization, this can be "l-bfgs", "normal" and +#' "auto". "l-bfgs" denotes Limited-memory BFGS which is a limited-memory +#' quasi-Newton optimization method. "normal" denotes using Normal Equation as an +#' analytical solution to the linear regression problem. The default value is "auto" +#' which means that the solver algorithm is selected automatically. #' @return a fitted MLlib model #' @rdname glm #' @export @@ -79,9 +85,15 @@ setMethod("predict", signature(object = "PipelineModel"), #' #' Returns the summary of a model produced by glm(), similarly to R's summary(). #' -#' @param x A fitted MLlib model -#' @return a list with a 'coefficient' component, which is the matrix of coefficients. See -#' summary.glm for more information. +#' @param object A fitted MLlib model +#' @return a list with 'devianceResiduals' and 'coefficients' components for gaussian family +#' or a list with 'coefficients' component for binomial family. \cr +#' For gaussian family: the 'devianceResiduals' gives the min/max deviance residuals +#' of the estimation, the 'coefficients' gives the estimated coefficients and their +#' estimated standard errors, t values and p-values. (It only available when model +#' fitted by normal solver.) \cr +#' For binomial family: the 'coefficients' gives the estimated coefficients. +#' See summary.glm for more information. 
\cr #' @rdname summary #' @export #' @examples diff --git a/docs/sparkr.md b/docs/sparkr.md index 437bd4756c276..a744b76be7466 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -286,24 +286,37 @@ head(teenagers) # Machine Learning -SparkR allows the fitting of generalized linear models over DataFrames using the [glm()](api/R/glm.html) function. Under the hood, SparkR uses MLlib to train a model of the specified family. Currently the gaussian and binomial families are supported. We support a subset of the available R formula operators for model fitting, including '~', '.', '+', and '-'. The example below shows the use of building a gaussian GLM model using SparkR. +SparkR allows the fitting of generalized linear models over DataFrames using the [glm()](api/R/glm.html) function. Under the hood, SparkR uses MLlib to train a model of the specified family. Currently the gaussian and binomial families are supported. We support a subset of the available R formula operators for model fitting, including '~', '.', ':', '+', and '-'. + +The [summary()](api/R/summary.html) function gives the summary of a model produced by [glm()](api/R/glm.html). + +* For gaussian GLM model, it returns a list with 'devianceResiduals' and 'coefficients' components. The 'devianceResiduals' gives the min/max deviance residuals of the estimation; the 'coefficients' gives the estimated coefficients and their estimated standard errors, t values and p-values. (It only available when model fitted by normal solver.) +* For binomial GLM model, it returns a list with 'coefficients' component which gives the estimated coefficients. + +The examples below show the use of building gaussian GLM model and binomial GLM model using SparkR. + +## Gaussian GLM model
    {% highlight r %} # Create the DataFrame df <- createDataFrame(sqlContext, iris) -# Fit a linear model over the dataset. +# Fit a gaussian GLM model over the dataset. model <- glm(Sepal_Length ~ Sepal_Width + Species, data = df, family = "gaussian") -# Model coefficients are returned in a similar format to R's native glm(). +# Model summary are returned in a similar format to R's native glm(). summary(model) +##$devianceResiduals +## Min Max +## -1.307112 1.412532 +## ##$coefficients -## Estimate -##(Intercept) 2.2513930 -##Sepal_Width 0.8035609 -##Species_versicolor 1.4587432 -##Species_virginica 1.9468169 +## Estimate Std. Error t value Pr(>|t|) +##(Intercept) 2.251393 0.3697543 6.08889 9.568102e-09 +##Sepal_Width 0.8035609 0.106339 7.556598 4.187317e-12 +##Species_versicolor 1.458743 0.1121079 13.01195 0 +##Species_virginica 1.946817 0.100015 19.46525 0 # Make predictions based on the model. predictions <- predict(model, newData = df) @@ -317,3 +330,24 @@ head(select(predictions, "Sepal_Length", "prediction")) ##6 5.4 5.385281 {% endhighlight %}
    + +## Binomial GLM model + +
    +{% highlight r %} +# Create the DataFrame +df <- createDataFrame(sqlContext, iris) +training <- filter(df, df$Species != "setosa") + +# Fit a binomial GLM model over the dataset. +model <- glm(Species ~ Sepal_Length + Sepal_Width, data = training, family = "binomial") + +# Model coefficients are returned in a similar format to R's native glm(). +summary(model) +##$coefficients +## Estimate +##(Intercept) -13.046005 +##Sepal_Length 1.902373 +##Sepal_Width 0.404655 +{% endhighlight %} +
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index ca55d5915e688..f7c44f0a51b8a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -145,6 +145,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String /** * Set the solver algorithm used for optimization. * In case of linear regression, this can be "l-bfgs", "normal" and "auto". + * "l-bfgs" denotes Limited-memory BFGS which is a limited-memory quasi-Newton + * optimization method. "normal" denotes using Normal Equation as an analytical + * solution to the linear regression problem. * The default value is "auto" which means that the solver algorithm is * selected automatically. * @group setParam From 603a721c21488e17c15c45ce1de893e6b3d02274 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 18 Nov 2015 13:32:06 -0800 Subject: [PATCH 113/149] [SPARK-11820][ML][PYSPARK] PySpark LiR & LoR should support weightCol [SPARK-7685](https://issues.apache.org/jira/browse/SPARK-7685) and [SPARK-9642](https://issues.apache.org/jira/browse/SPARK-9642) have already supported setting weight column for ```LogisticRegression``` and ```LinearRegression```. It's a very important feature, PySpark should also support. mengxr Author: Yanbo Liang Closes #9811 from yanboliang/spark-11820. --- python/pyspark/ml/classification.py | 17 +++++++++-------- python/pyspark/ml/regression.py | 16 ++++++++-------- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 603f2c7f798dc..4a2982e2047ff 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -36,7 +36,8 @@ @inherit_doc class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol, - HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds): + HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds, + HasWeightCol): """ Logistic regression. Currently, this class only supports binary classification. @@ -44,9 +45,9 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti >>> from pyspark.sql import Row >>> from pyspark.mllib.linalg import Vectors >>> df = sc.parallelize([ - ... Row(label=1.0, features=Vectors.dense(1.0)), - ... Row(label=0.0, features=Vectors.sparse(1, [], []))]).toDF() - >>> lr = LogisticRegression(maxIter=5, regParam=0.01) + ... Row(label=1.0, weight=2.0, features=Vectors.dense(1.0)), + ... 
Row(label=0.0, weight=2.0, features=Vectors.sparse(1, [], []))]).toDF() + >>> lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight") >>> model = lr.fit(df) >>> model.weights DenseVector([5.5...]) @@ -80,12 +81,12 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, threshold=0.5, thresholds=None, probabilityCol="probability", - rawPredictionCol="rawPrediction", standardization=True): + rawPredictionCol="rawPrediction", standardization=True, weightCol=None): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ threshold=0.5, thresholds=None, probabilityCol="probability", \ - rawPredictionCol="rawPrediction", standardization=True) + rawPredictionCol="rawPrediction", standardization=True, weightCol=None) If the threshold and thresholds Params are both set, they must be equivalent. """ super(LogisticRegression, self).__init__() @@ -105,12 +106,12 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, threshold=0.5, thresholds=None, probabilityCol="probability", - rawPredictionCol="rawPrediction", standardization=True): + rawPredictionCol="rawPrediction", standardization=True, weightCol=None): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ threshold=0.5, thresholds=None, probabilityCol="probability", \ - rawPredictionCol="rawPrediction", standardization=True) + rawPredictionCol="rawPrediction", standardization=True, weightCol=None) Sets params for logistic regression. If the threshold and thresholds Params are both set, they must be equivalent. """ diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 7648bf13266bf..944e648ec8801 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -35,7 +35,7 @@ @inherit_doc class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept, - HasStandardization, HasSolver): + HasStandardization, HasSolver, HasWeightCol): """ Linear regression. @@ -50,9 +50,9 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction >>> from pyspark.mllib.linalg import Vectors >>> df = sqlContext.createDataFrame([ - ... (1.0, Vectors.dense(1.0)), - ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) - >>> lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal") + ... (1.0, 2.0, Vectors.dense(1.0)), + ... 
(0.0, 2.0, Vectors.sparse(1, [], []))], ["label", "weight", "features"]) + >>> lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal", weightCol="weight") >>> model = lr.fit(df) >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> abs(model.transform(test0).head().prediction - (-1.0)) < 0.001 @@ -75,11 +75,11 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, - standardization=True, solver="auto"): + standardization=True, solver="auto", weightCol=None): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ - standardization=True, solver="auto") + standardization=True, solver="auto", weightCol=None) """ super(LinearRegression, self).__init__() self._java_obj = self._new_java_obj( @@ -92,11 +92,11 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred @since("1.4.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, - standardization=True, solver="auto"): + standardization=True, solver="auto", weightCol=None): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ - standardization=True, solver="auto") + standardization=True, solver="auto", weightCol=None) Sets params for linear regression. """ kwargs = self.setParams._input_kwargs From 54db79702513e11335c33bcf3a03c59e965e6f16 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Wed, 18 Nov 2015 14:05:18 -0800 Subject: [PATCH 114/149] [SPARK-11544][SQL] sqlContext doesn't use PathFilter Apply the user supplied pathfilter while retrieving the files from fs. Author: Dilip Biswal Closes #9652 from dilipbiswal/spark-11544. 
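A sketch of how a caller exercises the newly honored filter, mirroring the JsonSuite test added below; TestFileFilter stands for any PathFilter implementation and the path is illustrative:

    import org.apache.hadoop.fs.PathFilter

    // Register the filter through the standard Hadoop property; with this patch,
    // HadoopFsRelation applies it while listing input files for the relation.
    sqlContext.sparkContext.hadoopConfiguration.setClass(
      "mapreduce.input.pathFilter.class",
      classOf[TestFileFilter],   // e.g. a filter that rejects the p=2 directory
      classOf[PathFilter])
    val df = sqlContext.read.json(path)   // only files accepted by the filter are read
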
--- .../apache/spark/sql/sources/interfaces.scala | 25 ++++++++++--- .../datasources/json/JsonSuite.scala | 36 +++++++++++++++++-- 2 files changed, 54 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index b3d3bdf50df63..f9465157c936d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -21,7 +21,8 @@ import scala.collection.mutable import scala.util.Try import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.apache.hadoop.fs.{PathFilter, FileStatus, FileSystem, Path} +import org.apache.hadoop.mapred.{JobConf, FileInputFormat} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.{Logging, SparkContext} @@ -447,9 +448,15 @@ abstract class HadoopFsRelation private[sql]( val hdfsPath = new Path(path) val fs = hdfsPath.getFileSystem(hadoopConf) val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - logInfo(s"Listing $qualified on driver") - Try(fs.listStatus(qualified)).getOrElse(Array.empty) + // Dummy jobconf to get to the pathFilter defined in configuration + val jobConf = new JobConf(hadoopConf, this.getClass()) + val pathFilter = FileInputFormat.getInputPathFilter(jobConf) + if (pathFilter != null) { + Try(fs.listStatus(qualified, pathFilter)).getOrElse(Array.empty) + } else { + Try(fs.listStatus(qualified)).getOrElse(Array.empty) + } }.filterNot { status => val name = status.getPath.getName name.toLowerCase == "_temporary" || name.startsWith(".") @@ -847,8 +854,16 @@ private[sql] object HadoopFsRelation extends Logging { if (name == "_temporary" || name.startsWith(".")) { Array.empty } else { - val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDir) - files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) + // Dummy jobconf to get to the pathFilter defined in configuration + val jobConf = new JobConf(fs.getConf, this.getClass()) + val pathFilter = FileInputFormat.getInputPathFilter(jobConf) + if (pathFilter != null) { + val (dirs, files) = fs.listStatus(status.getPath, pathFilter).partition(_.isDir) + files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) + } else { + val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDir) + files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 6042b1178affe..f09b61e838159 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -19,19 +19,27 @@ package org.apache.spark.sql.execution.datasources.json import java.io.{File, StringWriter} import java.sql.{Date, Timestamp} +import scala.collection.JavaConverters._ import com.fasterxml.jackson.core.JsonFactory -import org.apache.spark.rdd.RDD +import org.apache.commons.io.FileUtils +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{Path, PathFilter} import org.scalactic.Tolerance._ +import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.datasources.{ResolvedDataSource, 
LogicalRelation} +import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} import org.apache.spark.sql.execution.datasources.json.InferSchema.compatibleType import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils +class TestFileFilter extends PathFilter { + override def accept(path: Path): Boolean = path.getParent.getName != "p=2" +} + class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { import testImplicits._ @@ -1390,4 +1398,28 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) } } + + test("SPARK-11544 test pathfilter") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df = sqlContext.range(2) + df.write.json(path + "/p=1") + df.write.json(path + "/p=2") + assert(sqlContext.read.json(path).count() === 4) + + val clonedConf = new Configuration(hadoopConfiguration) + try { + hadoopConfiguration.setClass( + "mapreduce.input.pathFilter.class", + classOf[TestFileFilter], + classOf[PathFilter]) + assert(sqlContext.read.json(path).count() === 2) + } finally { + // Hadoop 1 doesn't have `Configuration.unset` + hadoopConfiguration.clear() + clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) + } + } + } } From 5df08949f5d9e5b4b0e9c2db50c1b4eb93383de3 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 18 Nov 2015 15:42:07 -0800 Subject: [PATCH 115/149] [SPARK-11810][SQL] Java-based encoder for opaque types in Datasets. This patch refactors the existing Kryo encoder expressions and adds support for Java serialization. Author: Reynold Xin Closes #9802 from rxin/SPARK-11810. --- .../scala/org/apache/spark/sql/Encoder.scala | 41 +++++++++--- .../sql/catalyst/expressions/objects.scala | 67 ++++++++++++------- .../catalyst/encoders/FlatEncoderSuite.scala | 27 ++++++-- .../org/apache/spark/sql/DatasetSuite.scala | 36 +++++++++- 4 files changed, 130 insertions(+), 41 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index 79c2255641c06..1ed5111440c80 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import scala.reflect.{ClassTag, classTag} import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor} -import org.apache.spark.sql.catalyst.expressions.{DeserializeWithKryo, BoundReference, SerializeWithKryo} +import org.apache.spark.sql.catalyst.expressions.{DecodeUsingSerializer, BoundReference, EncodeUsingSerializer} import org.apache.spark.sql.types._ /** @@ -43,28 +43,49 @@ trait Encoder[T] extends Serializable { */ object Encoders { - /** - * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo. - * This encoder maps T into a single byte array (binary) field. - */ - def kryo[T: ClassTag]: Encoder[T] = { - val ser = SerializeWithKryo(BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true)) - val deser = DeserializeWithKryo[T](BoundReference(0, BinaryType, nullable = true), classTag[T]) + /** A way to construct encoders using generic serializers. 
*/ + private def genericSerializer[T: ClassTag](useKryo: Boolean): Encoder[T] = { ExpressionEncoder[T]( schema = new StructType().add("value", BinaryType), flat = true, - toRowExpressions = Seq(ser), - fromRowExpression = deser, + toRowExpressions = Seq( + EncodeUsingSerializer( + BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo)), + fromRowExpression = + DecodeUsingSerializer[T]( + BoundReference(0, BinaryType, nullable = true), classTag[T], kryo = useKryo), clsTag = classTag[T] ) } + /** + * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo. + * This encoder maps T into a single byte array (binary) field. + */ + def kryo[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = true) + /** * Creates an encoder that serializes objects of type T using Kryo. * This encoder maps T into a single byte array (binary) field. */ def kryo[T](clazz: Class[T]): Encoder[T] = kryo(ClassTag[T](clazz)) + /** + * (Scala-specific) Creates an encoder that serializes objects of type T using generic Java + * serialization. This encoder maps T into a single byte array (binary) field. + * + * Note that this is extremely inefficient and should only be used as the last resort. + */ + def javaSerialization[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = false) + + /** + * Creates an encoder that serializes objects of type T using generic Java serialization. + * This encoder maps T into a single byte array (binary) field. + * + * Note that this is extremely inefficient and should only be used as the last resort. + */ + def javaSerialization[T](clazz: Class[T]): Encoder[T] = javaSerialization(ClassTag[T](clazz)) + def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true) def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true) def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 489c6126f8cd3..acf0da240051e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -21,7 +21,7 @@ import scala.language.existentials import scala.reflect.ClassTag import org.apache.spark.SparkConf -import org.apache.spark.serializer.{KryoSerializerInstance, KryoSerializer} +import org.apache.spark.serializer._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation} @@ -517,29 +517,39 @@ case class GetInternalRowField(child: Expression, ordinal: Int, dataType: DataTy } } -/** Serializes an input object using Kryo serializer. */ -case class SerializeWithKryo(child: Expression) extends UnaryExpression { +/** + * Serializes an input object using a generic serializer (Kryo or Java). + * @param kryo if true, use Kryo. Otherwise, use Java. 
+ */ +case class EncodeUsingSerializer(child: Expression, kryo: Boolean) extends UnaryExpression { override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported") override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val input = child.gen(ctx) - val kryo = ctx.freshName("kryoSerializer") - val kryoClass = classOf[KryoSerializer].getName - val kryoInstanceClass = classOf[KryoSerializerInstance].getName - val sparkConfClass = classOf[SparkConf].getName + // Code to initialize the serializer. + val serializer = ctx.freshName("serializer") + val (serializerClass, serializerInstanceClass) = { + if (kryo) { + (classOf[KryoSerializer].getName, classOf[KryoSerializerInstance].getName) + } else { + (classOf[JavaSerializer].getName, classOf[JavaSerializerInstance].getName) + } + } + val sparkConf = s"new ${classOf[SparkConf].getName}()" ctx.addMutableState( - kryoInstanceClass, - kryo, - s"$kryo = ($kryoInstanceClass) new $kryoClass(new $sparkConfClass()).newInstance();") + serializerInstanceClass, + serializer, + s"$serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();") + // Code to serialize. + val input = child.gen(ctx) s""" ${input.code} final boolean ${ev.isNull} = ${input.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { - ${ev.value} = $kryo.serialize(${input.value}, null).array(); + ${ev.value} = $serializer.serialize(${input.value}, null).array(); } """ } @@ -548,29 +558,38 @@ case class SerializeWithKryo(child: Expression) extends UnaryExpression { } /** - * Deserializes an input object using Kryo serializer. Note that the ClassTag is not an implicit - * parameter because TreeNode cannot copy implicit parameters. + * Serializes an input object using a generic serializer (Kryo or Java). Note that the ClassTag + * is not an implicit parameter because TreeNode cannot copy implicit parameters. + * @param kryo if true, use Kryo. Otherwise, use Java. */ -case class DeserializeWithKryo[T](child: Expression, tag: ClassTag[T]) extends UnaryExpression { +case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: Boolean) + extends UnaryExpression { override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val input = child.gen(ctx) - val kryo = ctx.freshName("kryoSerializer") - val kryoClass = classOf[KryoSerializer].getName - val kryoInstanceClass = classOf[KryoSerializerInstance].getName - val sparkConfClass = classOf[SparkConf].getName + // Code to initialize the serializer. + val serializer = ctx.freshName("serializer") + val (serializerClass, serializerInstanceClass) = { + if (kryo) { + (classOf[KryoSerializer].getName, classOf[KryoSerializerInstance].getName) + } else { + (classOf[JavaSerializer].getName, classOf[JavaSerializerInstance].getName) + } + } + val sparkConf = s"new ${classOf[SparkConf].getName}()" ctx.addMutableState( - kryoInstanceClass, - kryo, - s"$kryo = ($kryoInstanceClass) new $kryoClass(new $sparkConfClass()).newInstance();") + serializerInstanceClass, + serializer, + s"$serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();") + // Code to serialize. 
+ val input = child.gen(ctx) s""" ${input.code} final boolean ${ev.isNull} = ${input.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = (${ctx.javaType(dataType)}) - $kryo.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null); + $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null); } """ } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala index 2729db84897a2..6e0322fb6e019 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala @@ -76,17 +76,34 @@ class FlatEncoderSuite extends ExpressionEncoderSuite { // Kryo encoders encodeDecodeTest( "hello", - Encoders.kryo[String].asInstanceOf[ExpressionEncoder[String]], + encoderFor(Encoders.kryo[String]), "kryo string") encodeDecodeTest( - new NotJavaSerializable(15), - Encoders.kryo[NotJavaSerializable].asInstanceOf[ExpressionEncoder[NotJavaSerializable]], + new KryoSerializable(15), + encoderFor(Encoders.kryo[KryoSerializable]), "kryo object serialization") + + // Java encoders + encodeDecodeTest( + "hello", + encoderFor(Encoders.javaSerialization[String]), + "java string") + encodeDecodeTest( + new JavaSerializable(15), + encoderFor(Encoders.javaSerialization[JavaSerializable]), + "java object serialization") } +/** For testing Kryo serialization based encoder. */ +class KryoSerializable(val value: Int) { + override def equals(other: Any): Boolean = { + this.value == other.asInstanceOf[KryoSerializable].value + } +} -class NotJavaSerializable(val value: Int) { +/** For testing Java serialization based encoder. 
*/ +class JavaSerializable(val value: Int) extends Serializable { override def equals(other: Any): Boolean = { - this.value == other.asInstanceOf[NotJavaSerializable].value + this.value == other.asInstanceOf[JavaSerializable].value } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index b6db583dfe01f..89d964aa3e469 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -357,7 +357,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(ds.toString == "[_1: int, _2: int]") } - test("kryo encoder") { + test("Kryo encoder") { implicit val kryoEncoder = Encoders.kryo[KryoData] val ds = Seq(KryoData(1), KryoData(2)).toDS() @@ -365,7 +365,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { Seq((KryoData(1), 1L), (KryoData(2), 1L))) } - test("kryo encoder self join") { + test("Kryo encoder self join") { implicit val kryoEncoder = Encoders.kryo[KryoData] val ds = Seq(KryoData(1), KryoData(2)).toDS() assert(ds.joinWith(ds, lit(true)).collect().toSet == @@ -375,6 +375,25 @@ class DatasetSuite extends QueryTest with SharedSQLContext { (KryoData(2), KryoData(1)), (KryoData(2), KryoData(2)))) } + + test("Java encoder") { + implicit val kryoEncoder = Encoders.javaSerialization[JavaData] + val ds = Seq(JavaData(1), JavaData(2)).toDS() + + assert(ds.groupBy(p => p).count().collect().toSeq == + Seq((JavaData(1), 1L), (JavaData(2), 1L))) + } + + ignore("Java encoder self join") { + implicit val kryoEncoder = Encoders.javaSerialization[JavaData] + val ds = Seq(JavaData(1), JavaData(2)).toDS() + assert(ds.joinWith(ds, lit(true)).collect().toSet == + Set( + (JavaData(1), JavaData(1)), + (JavaData(1), JavaData(2)), + (JavaData(2), JavaData(1)), + (JavaData(2), JavaData(2)))) + } } @@ -406,3 +425,16 @@ class KryoData(val a: Int) { object KryoData { def apply(a: Int): KryoData = new KryoData(a) } + +/** Used to test Java encoder. */ +class JavaData(val a: Int) extends Serializable { + override def equals(other: Any): Boolean = { + a == other.asInstanceOf[JavaData].a + } + override def hashCode: Int = a + override def toString: String = s"JavaData($a)" +} + +object JavaData { + def apply(a: Int): JavaData = new JavaData(a) +} From 7e987de1770f4ab3d54bc05db8de0a1ef035941d Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 18 Nov 2015 15:47:49 -0800 Subject: [PATCH 116/149] [SPARK-6787][ML] add read/write to estimators under ml.feature (1) Add read/write support to the following estimators under spark.ml: * CountVectorizer * IDF * MinMaxScaler * StandardScaler (a little awkward because we store some params in spark.mllib model) * StringIndexer Added some necessary method for read/write. Maybe we should add `private[ml] trait DefaultParamsReadable` and `DefaultParamsWritable` to save some boilerplate code, though we still need to override `load` for Java compatibility. jkbradley Author: Xiangrui Meng Closes #9798 from mengxr/SPARK-6787. 
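For reference, a minimal end-to-end sketch of the new persistence calls (assuming a `SQLContext` named `sqlContext` is in scope, that the `Writer` returned by `write` exposes a public `save(path)`, and using placeholder `/tmp` paths):

```
import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel}

val df = sqlContext.createDataFrame(Seq(
  (0, Seq("a", "b", "c")),
  (1, Seq("a", "b", "b", "c", "a"))
)).toDF("id", "words")

val cv = new CountVectorizer()
  .setInputCol("words")
  .setOutputCol("features")
  .setVocabSize(3)

// Estimators persist only their params; models additionally persist data
// (here, the learned vocabulary, stored as Parquet under <path>/data).
cv.write.save("/tmp/cv")
val cvCopy = CountVectorizer.load("/tmp/cv")

val model = cv.fit(df)
model.write.save("/tmp/cvModel")
val modelCopy = CountVectorizerModel.load("/tmp/cvModel")
assert(modelCopy.vocabulary.sameElements(model.vocabulary))
```

The same `write`/`load` pattern applies to IDF, MinMaxScaler, StandardScaler and StringIndexer.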
--- .../spark/ml/feature/CountVectorizer.scala | 72 +++++++++++++++-- .../org/apache/spark/ml/feature/IDF.scala | 71 ++++++++++++++++- .../spark/ml/feature/MinMaxScaler.scala | 72 +++++++++++++++-- .../spark/ml/feature/StandardScaler.scala | 78 ++++++++++++++++++- .../spark/ml/feature/StringIndexer.scala | 70 +++++++++++++++-- .../ml/feature/CountVectorizerSuite.scala | 24 +++++- .../apache/spark/ml/feature/IDFSuite.scala | 19 ++++- .../spark/ml/feature/MinMaxScalerSuite.scala | 25 +++++- .../ml/feature/StandardScalerSuite.scala | 64 +++++++++++---- .../spark/ml/feature/StringIndexerSuite.scala | 19 ++++- 10 files changed, 467 insertions(+), 47 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 49028e4b85064..5ff9bfb7d1119 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -16,17 +16,19 @@ */ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.broadcast.Broadcast +import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} -import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{VectorUDT, Vectors} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.DataFrame import org.apache.spark.util.collection.OpenHashMap /** @@ -105,7 +107,7 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit */ @Experimental class CountVectorizer(override val uid: String) - extends Estimator[CountVectorizerModel] with CountVectorizerParams { + extends Estimator[CountVectorizerModel] with CountVectorizerParams with Writable { def this() = this(Identifiable.randomUID("cntVec")) @@ -169,6 +171,19 @@ class CountVectorizer(override val uid: String) } override def copy(extra: ParamMap): CountVectorizer = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object CountVectorizer extends Readable[CountVectorizer] { + + @Since("1.6.0") + override def read: Reader[CountVectorizer] = new DefaultParamsReader + + @Since("1.6.0") + override def load(path: String): CountVectorizer = super.load(path) } /** @@ -178,7 +193,9 @@ class CountVectorizer(override val uid: String) */ @Experimental class CountVectorizerModel(override val uid: String, val vocabulary: Array[String]) - extends Model[CountVectorizerModel] with CountVectorizerParams { + extends Model[CountVectorizerModel] with CountVectorizerParams with Writable { + + import CountVectorizerModel._ def this(vocabulary: Array[String]) = { this(Identifiable.randomUID("cntVecModel"), vocabulary) @@ -232,4 +249,47 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin val copied = new CountVectorizerModel(uid, vocabulary).setParent(parent) copyValues(copied, extra) } + + @Since("1.6.0") + override def write: Writer = new CountVectorizerModelWriter(this) +} + +@Since("1.6.0") +object CountVectorizerModel extends Readable[CountVectorizerModel] 
{ + + private[CountVectorizerModel] + class CountVectorizerModelWriter(instance: CountVectorizerModel) extends Writer { + + private case class Data(vocabulary: Seq[String]) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.vocabulary) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class CountVectorizerModelReader extends Reader[CountVectorizerModel] { + + private val className = "org.apache.spark.ml.feature.CountVectorizerModel" + + override def load(path: String): CountVectorizerModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("vocabulary") + .head() + val vocabulary = data.getAs[Seq[String]](0).toArray + val model = new CountVectorizerModel(metadata.uid, vocabulary) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: Reader[CountVectorizerModel] = new CountVectorizerModelReader + + @Since("1.6.0") + override def load(path: String): CountVectorizerModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index 4c36df75d8aa0..53ad34ef12646 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -17,11 +17,13 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql._ @@ -60,7 +62,7 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol * Compute the Inverse Document Frequency (IDF) given a collection of documents. 
*/ @Experimental -final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase { +final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase with Writable { def this() = this(Identifiable.randomUID("idf")) @@ -85,6 +87,19 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa } override def copy(extra: ParamMap): IDF = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object IDF extends Readable[IDF] { + + @Since("1.6.0") + override def read: Reader[IDF] = new DefaultParamsReader + + @Since("1.6.0") + override def load(path: String): IDF = super.load(path) } /** @@ -95,7 +110,9 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa class IDFModel private[ml] ( override val uid: String, idfModel: feature.IDFModel) - extends Model[IDFModel] with IDFBase { + extends Model[IDFModel] with IDFBase with Writable { + + import IDFModel._ /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -117,4 +134,50 @@ class IDFModel private[ml] ( val copied = new IDFModel(uid, idfModel) copyValues(copied, extra).setParent(parent) } + + /** Returns the IDF vector. */ + @Since("1.6.0") + def idf: Vector = idfModel.idf + + @Since("1.6.0") + override def write: Writer = new IDFModelWriter(this) +} + +@Since("1.6.0") +object IDFModel extends Readable[IDFModel] { + + private[IDFModel] class IDFModelWriter(instance: IDFModel) extends Writer { + + private case class Data(idf: Vector) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.idf) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class IDFModelReader extends Reader[IDFModel] { + + private val className = "org.apache.spark.ml.feature.IDFModel" + + override def load(path: String): IDFModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("idf") + .head() + val idf = data.getAs[Vector](0) + val model = new IDFModel(metadata.uid, new feature.IDFModel(idf)) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: Reader[IDFModel] = new IDFModelReader + + @Since("1.6.0") + override def load(path: String): IDFModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index 1b494ec8b1727..24d964fae834e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -17,11 +17,14 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.param.{ParamMap, DoubleParam, Params} -import org.apache.spark.ml.util.Identifiable + +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param.{DoubleParam, ParamMap, Params} +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} +import org.apache.spark.ml.util._ import 
org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} import org.apache.spark.mllib.stat.Statistics import org.apache.spark.sql._ @@ -85,7 +88,7 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H */ @Experimental class MinMaxScaler(override val uid: String) - extends Estimator[MinMaxScalerModel] with MinMaxScalerParams { + extends Estimator[MinMaxScalerModel] with MinMaxScalerParams with Writable { def this() = this(Identifiable.randomUID("minMaxScal")) @@ -115,6 +118,19 @@ class MinMaxScaler(override val uid: String) } override def copy(extra: ParamMap): MinMaxScaler = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object MinMaxScaler extends Readable[MinMaxScaler] { + + @Since("1.6.0") + override def read: Reader[MinMaxScaler] = new DefaultParamsReader + + @Since("1.6.0") + override def load(path: String): MinMaxScaler = super.load(path) } /** @@ -131,7 +147,9 @@ class MinMaxScalerModel private[ml] ( override val uid: String, val originalMin: Vector, val originalMax: Vector) - extends Model[MinMaxScalerModel] with MinMaxScalerParams { + extends Model[MinMaxScalerModel] with MinMaxScalerParams with Writable { + + import MinMaxScalerModel._ /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -175,4 +193,46 @@ class MinMaxScalerModel private[ml] ( val copied = new MinMaxScalerModel(uid, originalMin, originalMax) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: Writer = new MinMaxScalerModelWriter(this) +} + +@Since("1.6.0") +object MinMaxScalerModel extends Readable[MinMaxScalerModel] { + + private[MinMaxScalerModel] + class MinMaxScalerModelWriter(instance: MinMaxScalerModel) extends Writer { + + private case class Data(originalMin: Vector, originalMax: Vector) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = new Data(instance.originalMin, instance.originalMax) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class MinMaxScalerModelReader extends Reader[MinMaxScalerModel] { + + private val className = "org.apache.spark.ml.feature.MinMaxScalerModel" + + override def load(path: String): MinMaxScalerModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val Row(originalMin: Vector, originalMax: Vector) = sqlContext.read.parquet(dataPath) + .select("originalMin", "originalMax") + .head() + val model = new MinMaxScalerModel(metadata.uid, originalMin, originalMax) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: Reader[MinMaxScalerModel] = new MinMaxScalerModelReader + + @Since("1.6.0") + override def load(path: String): MinMaxScalerModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index f6d0b0c0e9e75..ab04e5418dd4f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -17,11 +17,13 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, 
Since} import org.apache.spark.ml._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql._ @@ -57,7 +59,7 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with */ @Experimental class StandardScaler(override val uid: String) extends Estimator[StandardScalerModel] - with StandardScalerParams { + with StandardScalerParams with Writable { def this() = this(Identifiable.randomUID("stdScal")) @@ -94,6 +96,19 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM } override def copy(extra: ParamMap): StandardScaler = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object StandardScaler extends Readable[StandardScaler] { + + @Since("1.6.0") + override def read: Reader[StandardScaler] = new DefaultParamsReader + + @Since("1.6.0") + override def load(path: String): StandardScaler = super.load(path) } /** @@ -104,7 +119,9 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM class StandardScalerModel private[ml] ( override val uid: String, scaler: feature.StandardScalerModel) - extends Model[StandardScalerModel] with StandardScalerParams { + extends Model[StandardScalerModel] with StandardScalerParams with Writable { + + import StandardScalerModel._ /** Standard deviation of the StandardScalerModel */ val std: Vector = scaler.std @@ -112,6 +129,14 @@ class StandardScalerModel private[ml] ( /** Mean of the StandardScalerModel */ val mean: Vector = scaler.mean + /** Whether to scale to unit standard deviation. */ + @Since("1.6.0") + def getWithStd: Boolean = scaler.withStd + + /** Whether to center data with mean. 
*/ + @Since("1.6.0") + def getWithMean: Boolean = scaler.withMean + /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -138,4 +163,49 @@ class StandardScalerModel private[ml] ( val copied = new StandardScalerModel(uid, scaler) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: Writer = new StandardScalerModelWriter(this) +} + +@Since("1.6.0") +object StandardScalerModel extends Readable[StandardScalerModel] { + + private[StandardScalerModel] + class StandardScalerModelWriter(instance: StandardScalerModel) extends Writer { + + private case class Data(std: Vector, mean: Vector, withStd: Boolean, withMean: Boolean) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.std, instance.mean, instance.getWithStd, instance.getWithMean) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class StandardScalerModelReader extends Reader[StandardScalerModel] { + + private val className = "org.apache.spark.ml.feature.StandardScalerModel" + + override def load(path: String): StandardScalerModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val Row(std: Vector, mean: Vector, withStd: Boolean, withMean: Boolean) = + sqlContext.read.parquet(dataPath) + .select("std", "mean", "withStd", "withMean") + .head() + // This is very likely to change in the future because withStd and withMean should be params. + val oldModel = new feature.StandardScalerModel(std, mean, withStd, withMean) + val model = new StandardScalerModel(metadata.uid, oldModel) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: Reader[StandardScalerModel] = new StandardScalerModelReader + + @Since("1.6.0") + override def load(path: String): StandardScalerModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index f782a272d11db..f16f6afc002d8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -17,13 +17,14 @@ package org.apache.spark.ml.feature +import org.apache.hadoop.fs.Path + import org.apache.spark.SparkException -import org.apache.spark.annotation.{Since, Experimental} -import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.{Estimator, Model, Transformer} import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.Transformer import org.apache.spark.ml.util._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ @@ -64,7 +65,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha */ @Experimental class StringIndexer(override val uid: String) extends Estimator[StringIndexerModel] - with StringIndexerBase { + with StringIndexerBase with Writable { def this() = this(Identifiable.randomUID("strIdx")) @@ -92,6 +93,19 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod } override def copy(extra: ParamMap): StringIndexer = defaultCopy(extra) + + 
@Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object StringIndexer extends Readable[StringIndexer] { + + @Since("1.6.0") + override def read: Reader[StringIndexer] = new DefaultParamsReader + + @Since("1.6.0") + override def load(path: String): StringIndexer = super.load(path) } /** @@ -107,7 +121,10 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod @Experimental class StringIndexerModel ( override val uid: String, - val labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase { + val labels: Array[String]) + extends Model[StringIndexerModel] with StringIndexerBase with Writable { + + import StringIndexerModel._ def this(labels: Array[String]) = this(Identifiable.randomUID("strIdx"), labels) @@ -176,6 +193,49 @@ class StringIndexerModel ( val copied = new StringIndexerModel(uid, labels) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: StringIndexModelWriter = new StringIndexModelWriter(this) +} + +@Since("1.6.0") +object StringIndexerModel extends Readable[StringIndexerModel] { + + private[StringIndexerModel] + class StringIndexModelWriter(instance: StringIndexerModel) extends Writer { + + private case class Data(labels: Array[String]) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.labels) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class StringIndexerModelReader extends Reader[StringIndexerModel] { + + private val className = "org.apache.spark.ml.feature.StringIndexerModel" + + override def load(path: String): StringIndexerModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("labels") + .head() + val labels = data.getAs[Seq[String]](0).toArray + val model = new StringIndexerModel(metadata.uid, labels) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: Reader[StringIndexerModel] = new StringIndexerModelReader + + @Since("1.6.0") + override def load(path: String): StringIndexerModel = super.load(path) } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala index e192fa4850af0..9c9999017317d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala @@ -18,14 +18,17 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.Row -class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext { +class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest { test("params") { + ParamsSuite.checkParams(new CountVectorizer) ParamsSuite.checkParams(new CountVectorizerModel(Array("empty"))) } @@ -164,4 +167,23 @@ class CountVectorizerSuite extends SparkFunSuite with 
MLlibTestSparkContext { assert(features ~== expected absTol 1e-14) } } + + test("CountVectorizer read/write") { + val t = new CountVectorizer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMinDF(0.5) + .setMinTF(3.0) + .setVocabSize(10) + testDefaultReadWrite(t) + } + + test("CountVectorizerModel read/write") { + val instance = new CountVectorizerModel("myCountVectorizerModel", Array("a", "b", "c")) + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMinTF(3.0) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.vocabulary === instance.vocabulary) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala index 08f80af03429b..bc958c15857ba 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala @@ -19,13 +19,14 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.feature.{IDFModel => OldIDFModel} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.Row -class IDFSuite extends SparkFunSuite with MLlibTestSparkContext { +class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { def scaleDataWithIDF(dataSet: Array[Vector], model: Vector): Array[Vector] = { dataSet.map { @@ -98,4 +99,20 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext { assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.") } } + + test("IDF read/write") { + val t = new IDF() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMinDocFreq(5) + testDefaultReadWrite(t) + } + + test("IDFModel read/write") { + val instance = new IDFModel("myIDFModel", new OldIDFModel(Vectors.dense(1.0, 2.0))) + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.idf === instance.idf) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala index c04dda41eea34..09183fe65b722 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala @@ -18,12 +18,12 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Row, SQLContext} -class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext { +class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("MinMaxScaler fit basic case") { val sqlContext = new SQLContext(sc) @@ -69,4 +69,25 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext { } } } + + test("MinMaxScaler read/write") { + val t = new MinMaxScaler() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMax(1.0) + .setMin(-1.0) + testDefaultReadWrite(t) + 
} + + test("MinMaxScalerModel read/write") { + val instance = new MinMaxScalerModel( + "myMinMaxScalerModel", Vectors.dense(-1.0, 0.0), Vectors.dense(1.0, 10.0)) + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMin(-1.0) + .setMax(1.0) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.originalMin === instance.originalMin) + assert(newInstance.originalMax === instance.originalMax) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala index 879a3ae875004..49a4b2efe0c29 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala @@ -19,12 +19,16 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.feature +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, Row} -class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext{ +class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest { @transient var data: Array[Vector] = _ @transient var resWithStd: Array[Vector] = _ @@ -56,23 +60,29 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext{ ) } - def assertResult(dataframe: DataFrame): Unit = { - dataframe.select("standarded_features", "expected").collect().foreach { + def assertResult(df: DataFrame): Unit = { + df.select("standardized_features", "expected").collect().foreach { case Row(vector1: Vector, vector2: Vector) => assert(vector1 ~== vector2 absTol 1E-5, "The vector value is not correct after standardization.") } } + test("params") { + ParamsSuite.checkParams(new StandardScaler) + val oldModel = new feature.StandardScalerModel(Vectors.dense(1.0), Vectors.dense(2.0)) + ParamsSuite.checkParams(new StandardScalerModel("empty", oldModel)) + } + test("Standardization with default parameter") { val df0 = sqlContext.createDataFrame(data.zip(resWithStd)).toDF("features", "expected") - val standardscaler0 = new StandardScaler() + val standardScaler0 = new StandardScaler() .setInputCol("features") - .setOutputCol("standarded_features") + .setOutputCol("standardized_features") .fit(df0) - assertResult(standardscaler0.transform(df0)) + assertResult(standardScaler0.transform(df0)) } test("Standardization with setter") { @@ -80,29 +90,49 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext{ val df2 = sqlContext.createDataFrame(data.zip(resWithMean)).toDF("features", "expected") val df3 = sqlContext.createDataFrame(data.zip(data)).toDF("features", "expected") - val standardscaler1 = new StandardScaler() + val standardScaler1 = new StandardScaler() .setInputCol("features") - .setOutputCol("standarded_features") + .setOutputCol("standardized_features") .setWithMean(true) .setWithStd(true) .fit(df1) - val standardscaler2 = new StandardScaler() + val standardScaler2 = new StandardScaler() .setInputCol("features") - .setOutputCol("standarded_features") + .setOutputCol("standardized_features") 
.setWithMean(true) .setWithStd(false) .fit(df2) - val standardscaler3 = new StandardScaler() + val standardScaler3 = new StandardScaler() .setInputCol("features") - .setOutputCol("standarded_features") + .setOutputCol("standardized_features") .setWithMean(false) .setWithStd(false) .fit(df3) - assertResult(standardscaler1.transform(df1)) - assertResult(standardscaler2.transform(df2)) - assertResult(standardscaler3.transform(df3)) + assertResult(standardScaler1.transform(df1)) + assertResult(standardScaler2.transform(df2)) + assertResult(standardScaler3.transform(df3)) + } + + test("StandardScaler read/write") { + val t = new StandardScaler() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setWithStd(false) + .setWithMean(true) + testDefaultReadWrite(t) + } + + test("StandardScalerModel read/write") { + val oldModel = new feature.StandardScalerModel( + Vectors.dense(1.0, 2.0), Vectors.dense(3.0, 4.0), false, true) + val instance = new StandardScalerModel("myStandardScalerModel", oldModel) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.std === instance.std) + assert(newInstance.mean === instance.mean) + assert(newInstance.getWithStd === instance.getWithStd) + assert(newInstance.getWithMean === instance.getWithMean) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index be37bfb438833..749bfac747826 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -118,6 +118,23 @@ class StringIndexerSuite assert(indexerModel.transform(df).eq(df)) } + test("StringIndexer read/write") { + val t = new StringIndexer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setHandleInvalid("skip") + testDefaultReadWrite(t) + } + + test("StringIndexerModel read/write") { + val instance = new StringIndexerModel("myStringIndexerModel", Array("a", "b", "c")) + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setHandleInvalid("skip") + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.labels === instance.labels) + } + test("IndexToString params") { val idxToStr = new IndexToString() ParamsSuite.checkParams(idxToStr) @@ -175,7 +192,7 @@ class StringIndexerSuite assert(outSchema("output").dataType === StringType) } - test("read/write") { + test("IndexToString read/write") { val t = new IndexToString() .setInputCol("myInputCol") .setOutputCol("myOutputCol") From 3a9851936ddfe5bcb6a7f364d535fac977551f5d Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 18 Nov 2015 15:55:41 -0800 Subject: [PATCH 117/149] [SPARK-11649] Properly set Akka frame size in SparkListenerSuite test SparkListenerSuite's _"onTaskGettingResult() called when result fetched remotely"_ test was extremely slow (1 to 4 minutes to run) and recently became extremely flaky, frequently failing with OutOfMemoryError. The root cause was the fact that this was using `System.setProperty` to set the Akka frame size, which was not actually modifying the frame size. As a result, this test would allocate much more data than necessary. The fix here is to simply use SparkConf in order to configure the frame size. Author: Josh Rosen Closes #9822 from JoshRosen/SPARK-11649. 
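The shape of the fix, mirroring the updated test below (the suite lives under `org.apache.spark.scheduler`, which is why it can read the package-private `sc.env`; `spark.akka.frameSize` is in megabytes):

```
import org.apache.spark.{SparkConf, SparkContext}

// Old, broken approach: the context (and its Akka actor system) is created first,
// so a system property set afterwards never reaches the Akka config.
//   sc = new SparkContext("local", "SparkListenerSuite")
//   System.setProperty("spark.akka.frameSize", "1")

// Fixed approach: pass the setting through SparkConf so it is applied when the
// actor system is built.
val conf = new SparkConf().set("spark.akka.frameSize", "1")
val sc = new SparkContext("local", "SparkListenerSuite", conf)

// Sanity check that the 1 MB frame size actually took effect.
val akkaFrameSize = sc.env.actorSystem.settings.config
  .getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt
assert(akkaFrameSize == 1024 * 1024)
```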
--- .../org/apache/spark/scheduler/SparkListenerSuite.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 53102b9f1c936..84e545851f49e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -269,14 +269,15 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match } test("onTaskGettingResult() called when result fetched remotely") { - sc = new SparkContext("local", "SparkListenerSuite") + val conf = new SparkConf().set("spark.akka.frameSize", "1") + sc = new SparkContext("local", "SparkListenerSuite", conf) val listener = new SaveTaskEvents sc.addSparkListener(listener) // Make a task whose result is larger than the akka frame size - System.setProperty("spark.akka.frameSize", "1") val akkaFrameSize = sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt + assert(akkaFrameSize === 1024 * 1024) val result = sc.parallelize(Seq(1), 1) .map { x => 1.to(akkaFrameSize).toArray } .reduce { case (x, y) => x } From c07a50b86254578625be777b1890ff95e832ac6e Mon Sep 17 00:00:00 2001 From: Derek Dagit Date: Wed, 18 Nov 2015 15:56:54 -0800 Subject: [PATCH 118/149] [SPARK-10930] History "Stages" page "duration" can be confusing Author: Derek Dagit Closes #9051 from d2r/spark-10930-ui-max-task-dur. --- .../org/apache/spark/ui/jobs/StageTable.scala | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index ea806d09b6009..2a1c3c1a50ec9 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -145,9 +145,22 @@ private[ui] class StageTableBase( case None => "Unknown" } val finishTime = s.completionTime.getOrElse(System.currentTimeMillis) - val duration = s.submissionTime.map { t => - if (finishTime > t) finishTime - t else System.currentTimeMillis - t - } + + // The submission time for a stage is misleading because it counts the time + // the stage waits to be launched. (SPARK-10930) + val taskLaunchTimes = + stageData.taskData.values.map(_.taskInfo.launchTime).filter(_ > 0) + val duration: Option[Long] = + if (taskLaunchTimes.nonEmpty) { + val startTime = taskLaunchTimes.min + if (finishTime > startTime) { + Some(finishTime - startTime) + } else { + Some(System.currentTimeMillis() - startTime) + } + } else { + None + } val formattedDuration = duration.map(d => UIUtils.formatDuration(d)).getOrElse("Unknown") val inputRead = stageData.inputBytes From 4b117121900e5f242e7c8f46a69164385f0da7cc Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 18 Nov 2015 16:00:35 -0800 Subject: [PATCH 119/149] [SPARK-11495] Fix potential socket / file handle leaks that were found via static analysis The HP Fortify Opens Source Review team (https://www.hpfod.com/open-source-review-project) reported a handful of potential resource leaks that were discovered using their static analysis tool. We should fix the issues identified by their scan. Author: Josh Rosen Closes #9455 from JoshRosen/fix-potential-resource-leaks. 
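The patched files are Java, but the recurring fix is the same close-in-finally shape everywhere, using Guava's `Closeables` as in the hunks below. A small Scala sketch of the before/after pattern (the `firstLine*` helpers and the file path are made up for illustration):

```
import java.io.{BufferedReader, FileReader}
import com.google.common.io.Closeables

// Leak-prone: if readLine() throws, the reader (and its file handle) is never closed.
def firstLineLeaky(path: String): String = {
  val reader = new BufferedReader(new FileReader(path))
  val line = reader.readLine()
  reader.close()
  line
}

// Leak-free: the handle is always released, and swallowing IOExceptions on the
// cleanup path keeps a failed close() from masking an exception thrown while reading.
def firstLine(path: String): String = {
  val reader = new BufferedReader(new FileReader(path))
  try {
    reader.readLine()
  } finally {
    Closeables.close(reader, /* swallowIOException = */ true)
  }
}
```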
--- .../spark/unsafe/map/BytesToBytesMap.java | 7 ++++ .../unsafe/sort/UnsafeSorterSpillReader.java | 38 +++++++++++-------- .../streaming/JavaCustomReceiver.java | 31 +++++++-------- .../network/ChunkFetchIntegrationSuite.java | 15 ++++++-- .../shuffle/TestShuffleDataContext.java | 32 ++++++++++------ .../spark/streaming/JavaReceiverAPISuite.java | 20 ++++++---- 6 files changed, 90 insertions(+), 53 deletions(-) diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 04694dc54418c..3387f9a4177ce 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -24,6 +24,7 @@ import java.util.LinkedList; import com.google.common.annotations.VisibleForTesting; +import com.google.common.io.Closeables; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -272,6 +273,7 @@ private void advanceToNextPage() { } } try { + Closeables.close(reader, /* swallowIOException = */ false); reader = spillWriters.getFirst().getReader(blockManager); recordsInPage = -1; } catch (IOException e) { @@ -318,6 +320,11 @@ public Location next() { try { reader.loadNext(); } catch (IOException e) { + try { + reader.close(); + } catch(IOException e2) { + logger.error("Error while closing spill reader", e2); + } // Scala iterator does not handle exception Platform.throwException(e); } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java index 039e940a357ea..dcb13e6581e54 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -20,8 +20,7 @@ import java.io.*; import com.google.common.io.ByteStreams; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; +import com.google.common.io.Closeables; import org.apache.spark.storage.BlockId; import org.apache.spark.storage.BlockManager; @@ -31,10 +30,8 @@ * Reads spill files written by {@link UnsafeSorterSpillWriter} (see that class for a description * of the file format). 
*/ -public final class UnsafeSorterSpillReader extends UnsafeSorterIterator { - private static final Logger logger = LoggerFactory.getLogger(UnsafeSorterSpillReader.class); +public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implements Closeable { - private final File file; private InputStream in; private DataInputStream din; @@ -52,11 +49,15 @@ public UnsafeSorterSpillReader( File file, BlockId blockId) throws IOException { assert (file.length() > 0); - this.file = file; final BufferedInputStream bs = new BufferedInputStream(new FileInputStream(file)); - this.in = blockManager.wrapForCompression(blockId, bs); - this.din = new DataInputStream(this.in); - numRecordsRemaining = din.readInt(); + try { + this.in = blockManager.wrapForCompression(blockId, bs); + this.din = new DataInputStream(this.in); + numRecordsRemaining = din.readInt(); + } catch (IOException e) { + Closeables.close(bs, /* swallowIOException = */ true); + throw e; + } } @Override @@ -75,12 +76,7 @@ public void loadNext() throws IOException { ByteStreams.readFully(in, arr, 0, recordLength); numRecordsRemaining--; if (numRecordsRemaining == 0) { - in.close(); - if (!file.delete() && file.exists()) { - logger.warn("Unable to delete spill file {}", file.getPath()); - } - in = null; - din = null; + close(); } } @@ -103,4 +99,16 @@ public int getRecordLength() { public long getKeyPrefix() { return keyPrefix; } + + @Override + public void close() throws IOException { + if (in != null) { + try { + in.close(); + } finally { + in = null; + din = null; + } + } + } } diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java index 99df259b4e8e6..4b50fbf59f80e 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java @@ -18,6 +18,7 @@ package org.apache.spark.examples.streaming; import com.google.common.collect.Lists; +import com.google.common.io.Closeables; import org.apache.spark.SparkConf; import org.apache.spark.api.java.function.FlatMapFunction; @@ -121,23 +122,23 @@ public void onStop() { /** Create a socket connection and receive data until receiver is stopped */ private void receive() { - Socket socket = null; - String userInput = null; - try { - // connect to the server - socket = new Socket(host, port); - - BufferedReader reader = new BufferedReader(new InputStreamReader(socket.getInputStream())); - - // Until stopped or connection broken continue reading - while (!isStopped() && (userInput = reader.readLine()) != null) { - System.out.println("Received data '" + userInput + "'"); - store(userInput); + Socket socket = null; + BufferedReader reader = null; + String userInput = null; + try { + // connect to the server + socket = new Socket(host, port); + reader = new BufferedReader(new InputStreamReader(socket.getInputStream())); + // Until stopped or connection broken continue reading + while (!isStopped() && (userInput = reader.readLine()) != null) { + System.out.println("Received data '" + userInput + "'"); + store(userInput); + } + } finally { + Closeables.close(reader, /* swallowIOException = */ true); + Closeables.close(socket, /* swallowIOException = */ true); } - reader.close(); - socket.close(); - // Restart in an attempt to connect again when server is active again restart("Trying to connect again"); } catch(ConnectException ce) { diff --git 
a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java index dc5fa1cee69bc..50a324e293386 100644 --- a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java @@ -31,6 +31,7 @@ import com.google.common.collect.Lists; import com.google.common.collect.Sets; +import com.google.common.io.Closeables; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; @@ -78,10 +79,15 @@ public static void setUp() throws Exception { testFile = File.createTempFile("shuffle-test-file", "txt"); testFile.deleteOnExit(); RandomAccessFile fp = new RandomAccessFile(testFile, "rw"); - byte[] fileContent = new byte[1024]; - new Random().nextBytes(fileContent); - fp.write(fileContent); - fp.close(); + boolean shouldSuppressIOException = true; + try { + byte[] fileContent = new byte[1024]; + new Random().nextBytes(fileContent); + fp.write(fileContent); + shouldSuppressIOException = false; + } finally { + Closeables.close(fp, shouldSuppressIOException); + } final TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); fileChunk = new FileSegmentManagedBuffer(conf, testFile, 10, testFile.length() - 25); @@ -117,6 +123,7 @@ public StreamManager getStreamManager() { @AfterClass public static void tearDown() { + bufferChunk.release(); server.close(); clientFactory.close(); testFile.delete(); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java index 3fdde054ab6c7..7ac1ca128aed0 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java @@ -23,6 +23,7 @@ import java.io.IOException; import java.io.OutputStream; +import com.google.common.io.Closeables; import com.google.common.io.Files; import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; @@ -60,21 +61,28 @@ public void cleanup() { public void insertSortShuffleData(int shuffleId, int mapId, byte[][] blocks) throws IOException { String blockId = "shuffle_" + shuffleId + "_" + mapId + "_0"; - OutputStream dataStream = new FileOutputStream( - ExternalShuffleBlockResolver.getFile(localDirs, subDirsPerLocalDir, blockId + ".data")); - DataOutputStream indexStream = new DataOutputStream(new FileOutputStream( - ExternalShuffleBlockResolver.getFile(localDirs, subDirsPerLocalDir, blockId + ".index"))); + OutputStream dataStream = null; + DataOutputStream indexStream = null; + boolean suppressExceptionsDuringClose = true; - long offset = 0; - indexStream.writeLong(offset); - for (byte[] block : blocks) { - offset += block.length; - dataStream.write(block); + try { + dataStream = new FileOutputStream( + ExternalShuffleBlockResolver.getFile(localDirs, subDirsPerLocalDir, blockId + ".data")); + indexStream = new DataOutputStream(new FileOutputStream( + ExternalShuffleBlockResolver.getFile(localDirs, subDirsPerLocalDir, blockId + ".index"))); + + long offset = 0; indexStream.writeLong(offset); + for (byte[] block : blocks) { + offset += block.length; + dataStream.write(block); + indexStream.writeLong(offset); + } + suppressExceptionsDuringClose = false; + } finally { + Closeables.close(dataStream, 
suppressExceptionsDuringClose); + Closeables.close(indexStream, suppressExceptionsDuringClose); } - - dataStream.close(); - indexStream.close(); } /** Creates reducer blocks in a hash-based data format within our local dirs. */ diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java index ec2bffd6a5b97..7a8ef9d14784c 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java @@ -23,6 +23,7 @@ import org.apache.spark.streaming.api.java.JavaStreamingContext; import static org.junit.Assert.*; +import com.google.common.io.Closeables; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -121,14 +122,19 @@ public void onStop() { private void receive() { try { - Socket socket = new Socket(host, port); - BufferedReader in = new BufferedReader(new InputStreamReader(socket.getInputStream())); - String userInput; - while ((userInput = in.readLine()) != null) { - store(userInput); + Socket socket = null; + BufferedReader in = null; + try { + socket = new Socket(host, port); + in = new BufferedReader(new InputStreamReader(socket.getInputStream())); + String userInput; + while ((userInput = in.readLine()) != null) { + store(userInput); + } + } finally { + Closeables.close(in, /* swallowIOException = */ true); + Closeables.close(socket, /* swallowIOException = */ true); } - in.close(); - socket.close(); } catch(ConnectException ce) { ce.printStackTrace(); restart("Could not connect", ce); From a402c92c92b2e1c85d264f6077aec8f6d6a08270 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 18 Nov 2015 16:08:06 -0800 Subject: [PATCH 120/149] [SPARK-11814][STREAMING] Add better default checkpoint duration DStream checkpoint interval is by default set at max(10 second, batch interval). That's bad for large batch intervals where the checkpoint interval = batch interval, and RDDs get checkpointed every batch. This PR is to set the checkpoint interval of trackStateByKey to 10 * batch duration. Author: Tathagata Das Closes #9805 from tdas/SPARK-11814. 
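A usage sketch of how the new default plays out (the socket source, app name and paths are placeholders; the tracking function just keeps a running count per key):

```
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, State, StateSpec, StreamingContext}

val conf = new SparkConf().setMaster("local[2]").setAppName("TrackStateCheckpointDemo")
val ssc = new StreamingContext(conf, Seconds(10))   // 10-second batches
ssc.checkpoint("/tmp/track-state-demo")             // required by trackStateByKey

val wordCounts = ssc.socketTextStream("localhost", 9999)
  .flatMap(_.split(" "))
  .map(word => (word, 1))

val trackingFunc = (value: Option[Int], state: State[Int]) => {
  val sum = value.getOrElse(0) + state.getOption.getOrElse(0)
  state.update(sum)
  sum
}

// With this change the internal state stream checkpoints every 10 * batchDuration
// by default (every 100 seconds here) instead of on every single batch.
val stateStream = wordCounts.trackStateByKey(StateSpec.function(trackingFunc))

// An explicit interval still wins over the default, e.g. every other batch:
stateStream.checkpoint(Seconds(20))
```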
--- .../streaming/dstream/TrackStateDStream.scala | 13 ++++++ .../streaming/TrackStateByKeySuite.scala | 44 ++++++++++++++++++- 2 files changed, 56 insertions(+), 1 deletion(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala index 98e881e6ae115..0ada1111ce30a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala @@ -25,6 +25,7 @@ import org.apache.spark.rdd.{EmptyRDD, RDD} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ import org.apache.spark.streaming.rdd.{TrackStateRDD, TrackStateRDDRecord} +import org.apache.spark.streaming.dstream.InternalTrackStateDStream._ /** * :: Experimental :: @@ -120,6 +121,14 @@ class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassT /** Enable automatic checkpointing */ override val mustCheckpoint = true + /** Override the default checkpoint duration */ + override def initialize(time: Time): Unit = { + if (checkpointDuration == null) { + checkpointDuration = slideDuration * DEFAULT_CHECKPOINT_DURATION_MULTIPLIER + } + super.initialize(time) + } + /** Method that generates a RDD for the given time */ override def compute(validTime: Time): Option[RDD[TrackStateRDDRecord[K, S, E]]] = { // Get the previous state or create a new empty state RDD @@ -141,3 +150,7 @@ class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassT } } } + +private[streaming] object InternalTrackStateDStream { + private val DEFAULT_CHECKPOINT_DURATION_MULTIPLIER = 10 +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala index e3072b4442840..58aef74c0040f 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala @@ -22,9 +22,10 @@ import java.io.File import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.reflect.ClassTag +import org.scalatest.PrivateMethodTester._ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} -import org.apache.spark.streaming.dstream.{TrackStateDStream, TrackStateDStreamImpl} +import org.apache.spark.streaming.dstream.{InternalTrackStateDStream, TrackStateDStream, TrackStateDStreamImpl} import org.apache.spark.util.{ManualClock, Utils} import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} @@ -57,6 +58,12 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef sc = new SparkContext(conf) } + override def afterAll(): Unit = { + if (sc != null) { + sc.stop() + } + } + test("state - get, exists, update, remove, ") { var state: StateImpl[Int] = null @@ -436,6 +443,41 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef assert(collectedStateSnapshots.last.toSet === Set(("a", 1))) } + test("trackStateByKey - checkpoint durations") { + val privateMethod = PrivateMethod[InternalTrackStateDStream[_, _, _, _]]('internalStream) + + def testCheckpointDuration( + batchDuration: Duration, + expectedCheckpointDuration: Duration, + explicitCheckpointDuration: Option[Duration] = None + ): Unit = { + try { + ssc = new StreamingContext(sc, batchDuration) + val inputStream = new TestInputStream(ssc, 
Seq.empty[Seq[Int]], 2).map(_ -> 1) + val dummyFunc = (value: Option[Int], state: State[Int]) => 0 + val trackStateStream = inputStream.trackStateByKey(StateSpec.function(dummyFunc)) + val internalTrackStateStream = trackStateStream invokePrivate privateMethod() + + explicitCheckpointDuration.foreach { d => + trackStateStream.checkpoint(d) + } + trackStateStream.register() + ssc.start() // should initialize all the checkpoint durations + assert(trackStateStream.checkpointDuration === null) + assert(internalTrackStateStream.checkpointDuration === expectedCheckpointDuration) + } finally { + StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) } + } + } + + testCheckpointDuration(Milliseconds(100), Seconds(1)) + testCheckpointDuration(Seconds(1), Seconds(10)) + testCheckpointDuration(Seconds(10), Seconds(100)) + + testCheckpointDuration(Milliseconds(100), Seconds(2), Some(Seconds(2))) + testCheckpointDuration(Seconds(1), Seconds(2), Some(Seconds(2))) + testCheckpointDuration(Seconds(10), Seconds(20), Some(Seconds(20))) + } private def testOperation[K: ClassTag, S: ClassTag, T: ClassTag]( input: Seq[Seq[K]], From 921900fd06362474f8caac675803d526a0986d70 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Wed, 18 Nov 2015 16:19:00 -0800 Subject: [PATCH 121/149] [SPARK-11791] Fix flaky test in BatchedWriteAheadLogSuite stack trace of failure: ``` org.scalatest.exceptions.TestFailedDueToTimeoutException: The code passed to eventually never returned normally. Attempted 62 times over 1.006322071 seconds. Last failure message: Argument(s) are different! Wanted: writeAheadLog.write( java.nio.HeapByteBuffer[pos=0 lim=124 cap=124], 10 ); -> at org.apache.spark.streaming.util.BatchedWriteAheadLogSuite$$anonfun$23$$anonfun$apply$mcV$sp$15.apply(WriteAheadLogSuite.scala:518) Actual invocation has different arguments: writeAheadLog.write( java.nio.HeapByteBuffer[pos=0 lim=124 cap=124], 10 ); -> at org.apache.spark.streaming.util.WriteAheadLogSuite$BlockingWriteAheadLog.write(WriteAheadLogSuite.scala:756) ``` I believe the issue was that due to a race condition, the ordering of the events could be messed up in the final ByteBuffer, therefore the comparison fails. By adding eventually between the requests, we make sure the ordering is preserved. Note that in real life situations, the ordering across threads will not matter. Another solution would be to implement a custom mockito matcher that sorts and then compares the results, but that kind of sounds like overkill to me. Let me know what you think tdas zsxwing Author: Burak Yavuz Closes #9790 from brkyvz/fix-flaky-2. 
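The same order-insensitive verification pattern, stripped of the WAL specifics (`Sink` is a made-up stand-in for the mocked `WriteAheadLog`; in the real test the captured `ByteBuffer` is split back into records with `BatchedWriteAheadLog.deaggregate` before the set comparison):

```
import org.mockito.ArgumentCaptor
import org.mockito.Matchers.{eq => meq}
import org.mockito.Mockito.{mock, times, verify}

// Stand-in collaborator: receives one batch of records plus the batch time.
trait Sink { def write(records: Seq[String], time: Long): Unit }

val sink = mock(classOf[Sink])
// Simulate the production code draining a queue that several threads filled
// concurrently, so the relative order of the queued records is not deterministic.
sink.write(Seq("event3", "event2", "event5", "event4"), 10L)

// Order-sensitive verification (flaky): fails whenever threads interleave differently.
//   verify(sink, times(1)).write(meq(Seq("event2", "event3", "event4", "event5")), meq(10L))

// Order-insensitive verification: capture whatever was written and compare as a set.
val captor = ArgumentCaptor.forClass(classOf[Seq[String]])
verify(sink, times(1)).write(captor.capture(), meq(10L))
assert(captor.getValue.toSet == Set("event2", "event3", "event4", "event5"))
```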
--- .../spark/streaming/util/WriteAheadLogSuite.scala | 12 ++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index 7f80d6ecdbbb5..eaa88ea3cd380 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -30,6 +30,7 @@ import scala.language.{implicitConversions, postfixOps} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.mockito.ArgumentCaptor import org.mockito.Matchers.{eq => meq} import org.mockito.Matchers._ import org.mockito.Mockito._ @@ -507,15 +508,18 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( } blockingWal.allowWrite() - val buffer1 = wrapArrayArrayByte(Array(event1)) - val buffer2 = wrapArrayArrayByte(Array(event2, event3, event4, event5)) + val buffer = wrapArrayArrayByte(Array(event1)) + val queuedEvents = Set(event2, event3, event4, event5) eventually(timeout(1 second)) { assert(batchedWal.invokePrivate(queueLength()) === 0) - verify(wal, times(1)).write(meq(buffer1), meq(3L)) + verify(wal, times(1)).write(meq(buffer), meq(3L)) // the file name should be the timestamp of the last record, as events should be naturally // in order of timestamp, and we need the last element. - verify(wal, times(1)).write(meq(buffer2), meq(10L)) + val bufferCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer]) + verify(wal, times(1)).write(bufferCaptor.capture(), meq(10L)) + val records = BatchedWriteAheadLog.deaggregate(bufferCaptor.getValue).map(byteBufferToString) + assert(records.toSet === queuedEvents) } } From 59a501359a267fbdb7689058693aa788703e54b1 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 18 Nov 2015 16:48:09 -0800 Subject: [PATCH 122/149] [SPARK-11636][SQL] Support classes defined in the REPL with Encoders Before this PR there were two things that would blow up if you called `df.as[MyClass]` when `MyClass` was defined in the REPL: - [x] Because `classForName` doesn't work on the munged names returned by `tpe.erasure.typeSymbol.asClass.fullName` - [x] Because we don't have anything to pass into the constructor for the `$outer` pointer. Note that this PR is just adding the infrastructure for working with inner classes in encoders and is not yet sufficient to make them work in the REPL. Currently, the implementation shown in https://github.com/marmbrus/spark/commit/95cec7d413b930b36420724fafd829bef8c732ab is causing a bug that breaks code gen due to some interaction between janino and the `ExecutorClassLoader`. This will be addressed in a follow-up PR. Author: Michael Armbrust Closes #9602 from marmbrus/dataset-replClasses.
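To make the `$outer` problem concrete, here is a standalone sketch in plain Scala reflection (not the Spark implementation in the diff below) of why a class declared inside another class cannot be instantiated without an enclosing instance, and how a registry keyed by the declaring class's name, like the `OuterScopes` object added in this patch, can supply one. `Repl`, `Record`, and `OuterScopeSketch` are illustrative names only.

```scala
import java.util.concurrent.ConcurrentHashMap

// `Repl` plays the role of a REPL line object; `Record`, declared inside it, compiles to an
// inner class whose constructor takes a hidden outer pointer as its first argument.
class Repl {
  case class Record(i: Int)
}

object OuterScopeSketch {
  // Declaring-class name -> an instance of it, analogous to the weak-valued map in OuterScopes.
  private val registry = new ConcurrentHashMap[String, AnyRef]()

  def addOuterScope(outer: AnyRef): Unit = {
    registry.putIfAbsent(outer.getClass.getName, outer)
  }

  // Reflectively construct an inner class, supplying the registered outer instance first.
  def newInner(cls: Class[_], args: AnyRef*): Any = {
    val outer = registry.get(cls.getDeclaringClass.getName)
    require(outer != null, s"no outer scope registered for ${cls.getName}")
    cls.getConstructors.head.newInstance((outer +: args): _*)
  }

  def main(args: Array[String]): Unit = {
    val session = new Repl
    addOuterScope(session)  // what a REPL would register for its current line object
    val recordClass = Class.forName(classOf[Repl].getName + "$Record")
    println(newInner(recordClass, Int.box(1)))  // prints Record(1)
  }
}
```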
--- .../spark/sql/catalyst/ScalaReflection.scala | 81 ++++++++++--------- .../catalyst/encoders/ExpressionEncoder.scala | 26 +++++- .../sql/catalyst/encoders/OuterScopes.scala | 42 ++++++++++ .../catalyst/encoders/ProductEncoder.scala | 6 +- .../expressions/codegen/CodegenFallback.scala | 2 +- .../codegen/GenerateMutableProjection.scala | 4 +- .../codegen/GenerateProjection.scala | 10 +-- .../codegen/GenerateSafeProjection.scala | 4 +- .../codegen/GenerateUnsafeProjection.scala | 4 +- .../codegen/GenerateUnsafeRowJoiner.scala | 6 +- .../sql/catalyst/expressions/literals.scala | 6 ++ .../sql/catalyst/expressions/objects.scala | 42 ++++++++-- .../encoders/ExpressionEncoderSuite.scala | 7 +- .../encoders/ProductEncoderSuite.scala | 4 + .../scala/org/apache/spark/sql/Dataset.scala | 4 +- .../org/apache/spark/sql/GroupedDataset.scala | 8 +- .../aggregate/TypedAggregateExpression.scala | 19 ++--- 17 files changed, 193 insertions(+), 82 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 38828e59a2152..59ccf356f2c48 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -35,17 +35,6 @@ object ScalaReflection extends ScalaReflection { // class loader of the current thread. override def mirror: universe.Mirror = universe.runtimeMirror(Thread.currentThread().getContextClassLoader) -} - -/** - * Support for generating catalyst schemas for scala objects. - */ -trait ScalaReflection { - /** The universe we work in (runtime or macro) */ - val universe: scala.reflect.api.Universe - - /** The mirror used to access types in the universe */ - def mirror: universe.Mirror import universe._ @@ -53,30 +42,6 @@ trait ScalaReflection { // Since the map values can be mutable, we explicitly import scala.collection.Map at here. import scala.collection.Map - case class Schema(dataType: DataType, nullable: Boolean) - - /** Returns a Sequence of attributes for the given case class type. */ - def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match { - case Schema(s: StructType, _) => - s.toAttributes - } - - /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ - def schemaFor[T: TypeTag]: Schema = - ScalaReflectionLock.synchronized { schemaFor(localTypeOf[T]) } - - /** - * Return the Scala Type for `T` in the current classloader mirror. - * - * Use this method instead of the convenience method `universe.typeOf`, which - * assumes that all types can be found in the classloader that loaded scala-reflect classes. - * That's not necessarily the case when running using Eclipse launchers or even - * Sbt console or test (without `fork := true`). - * - * @see SPARK-5281 - */ - def localTypeOf[T: TypeTag]: `Type` = typeTag[T].in(mirror).tpe - /** * Returns the Spark SQL DataType for a given scala type. Where this is not an exact mapping * to a native type, an ObjectType is returned. 
Special handling is also used for Arrays including @@ -114,7 +79,9 @@ trait ScalaReflection { } ObjectType(cls) - case other => ObjectType(Utils.classForName(className)) + case other => + val clazz = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) + ObjectType(clazz) } } @@ -640,6 +607,48 @@ trait ScalaReflection { } } } +} + +/** + * Support for generating catalyst schemas for scala objects. Note that unlike its companion + * object, this trait able to work in both the runtime and the compile time (macro) universe. + */ +trait ScalaReflection { + /** The universe we work in (runtime or macro) */ + val universe: scala.reflect.api.Universe + + /** The mirror used to access types in the universe */ + def mirror: universe.Mirror + + import universe._ + + // The Predef.Map is scala.collection.immutable.Map. + // Since the map values can be mutable, we explicitly import scala.collection.Map at here. + import scala.collection.Map + + case class Schema(dataType: DataType, nullable: Boolean) + + /** Returns a Sequence of attributes for the given case class type. */ + def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match { + case Schema(s: StructType, _) => + s.toAttributes + } + + /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ + def schemaFor[T: TypeTag]: Schema = + ScalaReflectionLock.synchronized { schemaFor(localTypeOf[T]) } + + /** + * Return the Scala Type for `T` in the current classloader mirror. + * + * Use this method instead of the convenience method `universe.typeOf`, which + * assumes that all types can be found in the classloader that loaded scala-reflect classes. + * That's not necessarily the case when running using Eclipse launchers or even + * Sbt console or test (without `fork := true`). + * + * @see SPARK-5281 + */ + def localTypeOf[T: TypeTag]: `Type` = typeTag[T].in(mirror).tpe /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index b977f278c5b5c..456b595008479 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.catalyst.encoders +import java.util.concurrent.ConcurrentMap + import scala.reflect.ClassTag import scala.reflect.runtime.universe.{typeTag, TypeTag} import org.apache.spark.util.Utils -import org.apache.spark.sql.Encoder +import org.apache.spark.sql.{AnalysisException, Encoder} import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtractValue, UnresolvedAttribute} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} import org.apache.spark.sql.catalyst.expressions._ @@ -211,7 +213,9 @@ case class ExpressionEncoder[T]( * Returns a new copy of this encoder, where the expressions used by `fromRow` are resolved to the * given schema. 
*/ - def resolve(schema: Seq[Attribute]): ExpressionEncoder[T] = { + def resolve( + schema: Seq[Attribute], + outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = { val positionToAttribute = AttributeMap.toIndex(schema) val unbound = fromRowExpression transform { case b: BoundReference => positionToAttribute(b.ordinal) @@ -219,7 +223,23 @@ case class ExpressionEncoder[T]( val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema)) val analyzedPlan = SimpleAnalyzer.execute(plan) - copy(fromRowExpression = analyzedPlan.expressions.head.children.head) + + // In order to construct instances of inner classes (for example those declared in a REPL cell), + // we need an instance of the outer scope. This rule substitues those outer objects into + // expressions that are missing them by looking up the name in the SQLContexts `outerScopes` + // registry. + copy(fromRowExpression = analyzedPlan.expressions.head.children.head transform { + case n: NewInstance if n.outerPointer.isEmpty && n.cls.isMemberClass => + val outer = outerScopes.get(n.cls.getDeclaringClass.getName) + if (outer == null) { + throw new AnalysisException( + s"Unable to generate an encoder for inner class `${n.cls.getName}` without access " + + s"to the scope that this class was defined in. " + "" + + "Try moving this class out of its parent class.") + } + + n.copy(outerPointer = Some(Literal.fromObject(outer))) + }) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala new file mode 100644 index 0000000000000..a753b187bcd32 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.encoders + +import java.util.concurrent.ConcurrentMap + +import com.google.common.collect.MapMaker + +object OuterScopes { + @transient + lazy val outerScopes: ConcurrentMap[String, AnyRef] = + new MapMaker().weakValues().makeMap() + + /** + * Adds a new outer scope to this context that can be used when instantiating an `inner class` + * during deserialialization. Inner classes are created when a case class is defined in the + * Spark REPL and registering the outer scope that this class was defined in allows us to create + * new instances on the spark executors. In normal use, users should not need to call this + * function. + * + * Warning: this function operates on the assumption that there is only ever one instance of any + * given wrapper class. 
+ */ + def addOuterScope(outer: AnyRef): Unit = { + outerScopes.putIfAbsent(outer.getClass.getName, outer) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala index 55c4ee11b20f4..2914c6ee790ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala @@ -31,6 +31,7 @@ import scala.reflect.ClassTag object ProductEncoder { import ScalaReflection.universe._ + import ScalaReflection.mirror import ScalaReflection.localTypeOf import ScalaReflection.dataTypeFor import ScalaReflection.Schema @@ -420,8 +421,7 @@ object ProductEncoder { } } - val className: String = t.erasure.typeSymbol.asClass.fullName - val cls = Utils.classForName(className) + val cls = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) val arguments = params.head.zipWithIndex.map { case (p, i) => val fieldName = p.name.toString @@ -429,7 +429,7 @@ object ProductEncoder { val dataType = schemaFor(fieldType).dataType // For tuples, we based grab the inner fields by ordinal instead of name. - if (className startsWith "scala.Tuple") { + if (cls.getName startsWith "scala.Tuple") { constructorFor(fieldType, Some(addToPathOrdinal(i, dataType))) } else { constructorFor(fieldType, Some(addToPath(fieldName))) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index d51a8dede7f34..a31574c251af5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -34,7 +34,7 @@ trait CodegenFallback extends Expression { val objectTerm = ctx.freshName("obj") s""" /* expression: ${this} */ - Object $objectTerm = expressions[${ctx.references.size - 1}].eval(${ctx.INPUT_ROW}); + java.lang.Object $objectTerm = expressions[${ctx.references.size - 1}].eval(${ctx.INPUT_ROW}); boolean ${ev.isNull} = $objectTerm == null; ${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)}; if (!${ev.isNull}) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 4b66069b5f55a..40189f0877764 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -82,7 +82,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu val allUpdates = ctx.splitExpressions(ctx.INPUT_ROW, updates) val code = s""" - public Object generate($exprType[] expr) { + public java.lang.Object generate($exprType[] expr) { return new SpecificMutableProjection(expr); } @@ -109,7 +109,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu return (InternalRow) mutableRow; } - public Object apply(Object _i) { + public java.lang.Object apply(java.lang.Object _i) { InternalRow ${ctx.INPUT_ROW} = (InternalRow) _i; $allProjections // copy all the results into 
MutableRow diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index c0d313b2e1301..f229f2000d8e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -167,7 +167,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { ${initMutableStates(ctx)} } - public Object apply(Object r) { + public java.lang.Object apply(java.lang.Object r) { // GenerateProjection does not work with UnsafeRows. assert(!(r instanceof ${classOf[UnsafeRow].getName})); return new SpecificRow((InternalRow) r); @@ -186,14 +186,14 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { public void setNullAt(int i) { nullBits[i] = true; } public boolean isNullAt(int i) { return nullBits[i]; } - public Object genericGet(int i) { + public java.lang.Object genericGet(int i) { if (isNullAt(i)) return null; switch (i) { $getCases } return null; } - public void update(int i, Object value) { + public void update(int i, java.lang.Object value) { if (value == null) { setNullAt(i); return; @@ -212,7 +212,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { return result; } - public boolean equals(Object other) { + public boolean equals(java.lang.Object other) { if (other instanceof SpecificRow) { SpecificRow row = (SpecificRow) other; $columnChecks @@ -222,7 +222,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { } public InternalRow copy() { - Object[] arr = new Object[${expressions.length}]; + java.lang.Object[] arr = new java.lang.Object[${expressions.length}]; ${copyColumns} return new ${classOf[GenericInternalRow].getName}(arr); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index f0ed8645d923f..b7926bda3de19 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -148,7 +148,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] } val allExpressions = ctx.splitExpressions(ctx.INPUT_ROW, expressionCodes) val code = s""" - public Object generate($exprType[] expr) { + public java.lang.Object generate($exprType[] expr) { return new SpecificSafeProjection(expr); } @@ -165,7 +165,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] ${initMutableStates(ctx)} } - public Object apply(Object _i) { + public java.lang.Object apply(java.lang.Object _i) { InternalRow ${ctx.INPUT_ROW} = (InternalRow) _i; $allExpressions return mutableRow; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 4c17d02a23725..7b6c9373ebe30 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -324,7 +324,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val eval = createCode(ctx, expressions, subexpressionEliminationEnabled) val code = s""" - public Object generate($exprType[] exprs) { + public java.lang.Object generate($exprType[] exprs) { return new SpecificUnsafeProjection(exprs); } @@ -342,7 +342,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } // Scala.Function1 need this - public Object apply(Object row) { + public java.lang.Object apply(java.lang.Object row) { return apply((InternalRow) row); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala index da91ff29537b3..da602d9b4bce1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala @@ -159,7 +159,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U // ------------------------ Finally, put everything together --------------------------- // val code = s""" - |public Object generate($exprType[] exprs) { + |public java.lang.Object generate($exprType[] exprs) { | return new SpecificUnsafeRowJoiner(); |} | @@ -176,9 +176,9 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U | buf = new byte[sizeInBytes]; | } | - | final Object obj1 = row1.getBaseObject(); + | final java.lang.Object obj1 = row1.getBaseObject(); | final long offset1 = row1.getBaseOffset(); - | final Object obj2 = row2.getBaseObject(); + | final java.lang.Object obj2 = row2.getBaseObject(); | final long offset2 = row2.getBaseOffset(); | | $copyBitset diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 455fa2427c26d..e34fd49be8389 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -48,6 +48,12 @@ object Literal { throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v) } + /** + * Constructs a [[Literal]] of [[ObjectType]], for example when you need to pass an object + * into code generation. 
+ */ + def fromObject(obj: AnyRef): Literal = new Literal(obj, ObjectType(obj.getClass)) + def create(v: Any, dataType: DataType): Literal = { Literal(CatalystTypeConverters.convertToCatalyst(v), dataType) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index acf0da240051e..f865a9408ef4e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -24,6 +24,7 @@ import org.apache.spark.SparkConf import org.apache.spark.serializer._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer +import org.apache.spark.sql.catalyst.encoders.ProductEncoder import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation} import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.catalyst.InternalRow @@ -178,6 +179,15 @@ case class Invoke( } } +object NewInstance { + def apply( + cls: Class[_], + arguments: Seq[Expression], + propagateNull: Boolean = false, + dataType: DataType): NewInstance = + new NewInstance(cls, arguments, propagateNull, dataType, None) +} + /** * Constructs a new instance of the given class, using the result of evaluating the specified * expressions as arguments. @@ -189,12 +199,15 @@ case class Invoke( * @param dataType The type of object being constructed, as a Spark SQL datatype. This allows you * to manually specify the type when the object in question is a valid internal * representation (i.e. ArrayData) instead of an object. + * @param outerPointer If the object being constructed is an inner class the outerPointer must + * for the containing class must be specified. 
*/ case class NewInstance( cls: Class[_], arguments: Seq[Expression], - propagateNull: Boolean = true, - dataType: DataType) extends Expression { + propagateNull: Boolean, + dataType: DataType, + outerPointer: Option[Literal]) extends Expression { private val className = cls.getName override def nullable: Boolean = propagateNull @@ -209,30 +222,43 @@ case class NewInstance( val argGen = arguments.map(_.gen(ctx)) val argString = argGen.map(_.value).mkString(", ") + val outer = outerPointer.map(_.gen(ctx)) + + val setup = + s""" + ${argGen.map(_.code).mkString("\n")} + ${outer.map(_.code.mkString("")).getOrElse("")} + """.stripMargin + + val constructorCall = outer.map { gen => + s"""${gen.value}.new ${cls.getSimpleName}($argString)""" + }.getOrElse { + s"new $className($argString)" + } + if (propagateNull) { val objNullCheck = if (ctx.defaultValue(dataType) == "null") { s"${ev.isNull} = ${ev.value} == null;" } else { "" } - val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})" + s""" - ${argGen.map(_.code).mkString("\n")} + $setup boolean ${ev.isNull} = true; $javaType ${ev.value} = ${ctx.defaultValue(dataType)}; - if ($argsNonNull) { - ${ev.value} = new $className($argString); + ${ev.value} = $constructorCall; ${ev.isNull} = false; } """ } else { s""" - ${argGen.map(_.code).mkString("\n")} + $setup - $javaType ${ev.value} = new $className($argString); + $javaType ${ev.value} = $constructorCall; final boolean ${ev.isNull} = ${ev.value} == null; """ } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 9fe64b4cf10e4..cde0364f3dd9d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -18,6 +18,9 @@ package org.apache.spark.sql.catalyst.encoders import java.util.Arrays +import java.util.concurrent.ConcurrentMap + +import com.google.common.collect.MapMaker import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.AttributeReference @@ -25,6 +28,8 @@ import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.types.ArrayType abstract class ExpressionEncoderSuite extends SparkFunSuite { + val outers: ConcurrentMap[String, AnyRef] = new MapMaker().weakValues().makeMap() + protected def encodeDecodeTest[T]( input: T, encoder: ExpressionEncoder[T], @@ -32,7 +37,7 @@ abstract class ExpressionEncoderSuite extends SparkFunSuite { test(s"encode/decode for $testName: $input") { val row = encoder.toRow(input) val schema = encoder.schema.toAttributes - val boundEncoder = encoder.resolve(schema).bind(schema) + val boundEncoder = encoder.resolve(schema, outers).bind(schema) val convertedBack = try boundEncoder.fromRow(row) catch { case e: Exception => fail( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala index bc539d62c537d..1798514c5c38b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala @@ -53,6 +53,10 @@ case class RepeatedData( case class SpecificCollection(l: List[Int]) class ProductEncoderSuite extends ExpressionEncoderSuite { 
+ outers.put(getClass.getName, this) + + case class InnerClass(i: Int) + productTest(InnerClass(1)) productTest(PrimitiveData(1, 1, 1, 1, 1, 1, true)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index b644f6ad3096d..bdcdc5d47cbae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -74,7 +74,7 @@ class Dataset[T] private[sql]( /** The encoder for this [[Dataset]] that has been resolved to its output schema. */ private[sql] val resolvedTEncoder: ExpressionEncoder[T] = - unresolvedTEncoder.resolve(queryExecution.analyzed.output) + unresolvedTEncoder.resolve(queryExecution.analyzed.output, OuterScopes.outerScopes) private implicit def classTag = resolvedTEncoder.clsTag @@ -375,7 +375,7 @@ class Dataset[T] private[sql]( sqlContext, Project( c1.withInputType( - resolvedTEncoder, + resolvedTEncoder.bind(queryExecution.analyzed.output), queryExecution.analyzed.output).named :: Nil, logicalPlan)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 3f84e22a1025b..7e5acbe8517d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -21,7 +21,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.function._ -import org.apache.spark.sql.catalyst.encoders.{FlatEncoder, ExpressionEncoder, encoderFor} +import org.apache.spark.sql.catalyst.encoders.{FlatEncoder, ExpressionEncoder, encoderFor, OuterScopes} import org.apache.spark.sql.catalyst.expressions.{Alias, CreateStruct, Attribute} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution @@ -52,8 +52,10 @@ class GroupedDataset[K, T] private[sql]( private implicit val unresolvedKEncoder = encoderFor(kEncoder) private implicit val unresolvedTEncoder = encoderFor(tEncoder) - private val resolvedKEncoder = unresolvedKEncoder.resolve(groupingAttributes) - private val resolvedTEncoder = unresolvedTEncoder.resolve(dataAttributes) + private val resolvedKEncoder = + unresolvedKEncoder.resolve(groupingAttributes, OuterScopes.outerScopes) + private val resolvedTEncoder = + unresolvedTEncoder.resolve(dataAttributes, OuterScopes.outerScopes) private def logicalPlan = queryExecution.analyzed private def sqlContext = queryExecution.sqlContext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index 3f2775896bb8c..6ce41aaf01e27 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -52,8 +52,8 @@ object TypedAggregateExpression { */ case class TypedAggregateExpression( aggregator: Aggregator[Any, Any, Any], - aEncoder: Option[ExpressionEncoder[Any]], - bEncoder: ExpressionEncoder[Any], + aEncoder: Option[ExpressionEncoder[Any]], // Should be bound. + bEncoder: ExpressionEncoder[Any], // Should be bound. 
cEncoder: ExpressionEncoder[Any], children: Seq[Attribute], mutableAggBufferOffset: Int, @@ -92,9 +92,6 @@ case class TypedAggregateExpression( // We let the dataset do the binding for us. lazy val boundA = aEncoder.get - val bAttributes = bEncoder.schema.toAttributes - lazy val boundB = bEncoder.resolve(bAttributes).bind(bAttributes) - private def updateBuffer(buffer: MutableRow, value: InternalRow): Unit = { // todo: need a more neat way to assign the value. var i = 0 @@ -114,24 +111,24 @@ case class TypedAggregateExpression( override def update(buffer: MutableRow, input: InternalRow): Unit = { val inputA = boundA.fromRow(input) - val currentB = boundB.shift(mutableAggBufferOffset).fromRow(buffer) + val currentB = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer) val merged = aggregator.reduce(currentB, inputA) - val returned = boundB.toRow(merged) + val returned = bEncoder.toRow(merged) updateBuffer(buffer, returned) } override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { - val b1 = boundB.shift(mutableAggBufferOffset).fromRow(buffer1) - val b2 = boundB.shift(inputAggBufferOffset).fromRow(buffer2) + val b1 = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer1) + val b2 = bEncoder.shift(inputAggBufferOffset).fromRow(buffer2) val merged = aggregator.merge(b1, b2) - val returned = boundB.toRow(merged) + val returned = bEncoder.toRow(merged) updateBuffer(buffer1, returned) } override def eval(buffer: InternalRow): Any = { - val b = boundB.shift(mutableAggBufferOffset).fromRow(buffer) + val b = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer) val result = cEncoder.toRow(aggregator.finish(b)) dataType match { case _: StructType => result From e99d3392068bc929c900a4cc7b50e9e2b437a23a Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 18 Nov 2015 18:34:01 -0800 Subject: [PATCH 123/149] [SPARK-11839][ML] refactor save/write traits * add "ML" prefix to reader/writer/readable/writable to avoid name collision with java.util.* * define `DefaultParamsReadable/Writable` and use them to save some code * use `super.load` instead so people can jump directly to the doc of `Readable.load`, which documents the Java compatibility issues jkbradley Author: Xiangrui Meng Closes #9827 from mengxr/SPARK-11839. 
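As a rough illustration of the resulting pattern, condensed from the `Binarizer`/`DCT`-style changes below, a stage whose state lives entirely in its `Params` now only mixes in `DefaultParamsWritable` and gives its companion object `DefaultParamsReadable`. `MyTransformer` is a hypothetical stage used only for this sketch, and in this version the `Default*` traits may be visible only inside `org.apache.spark.ml`.

```scala
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType

class MyTransformer(override val uid: String) extends Transformer with DefaultParamsWritable {
  def this() = this(Identifiable.randomUID("myTransformer"))
  override def transform(dataset: DataFrame): DataFrame = dataset  // no-op for the sketch
  override def transformSchema(schema: StructType): StructType = schema
  override def copy(extra: ParamMap): MyTransformer = defaultCopy(extra)
  // No `write` override needed: DefaultParamsWritable supplies a DefaultParamsWriter.
}

object MyTransformer extends DefaultParamsReadable[MyTransformer] {
  // Overridden only so callers get MyTransformer (not a supertype) and can jump straight to
  // the Java-compatibility notes on the inherited load, as the commit message describes.
  override def load(path: String): MyTransformer = super.load(path)
}

// Saving and loading then go through the renamed MLWriter/MLReader machinery, e.g.:
//   new MyTransformer().write.overwrite().save("/tmp/my-transformer")
//   val restored = MyTransformer.load("/tmp/my-transformer")
```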
--- .../scala/org/apache/spark/ml/Pipeline.scala | 40 +++++++++---------- .../classification/LogisticRegression.scala | 29 +++++++------- .../apache/spark/ml/feature/Binarizer.scala | 12 ++---- .../apache/spark/ml/feature/Bucketizer.scala | 12 ++---- .../spark/ml/feature/CountVectorizer.scala | 22 ++++------ .../org/apache/spark/ml/feature/DCT.scala | 12 ++---- .../apache/spark/ml/feature/HashingTF.scala | 12 ++---- .../org/apache/spark/ml/feature/IDF.scala | 23 +++++------ .../apache/spark/ml/feature/Interaction.scala | 12 ++---- .../spark/ml/feature/MinMaxScaler.scala | 22 ++++------ .../org/apache/spark/ml/feature/NGram.scala | 12 ++---- .../apache/spark/ml/feature/Normalizer.scala | 12 ++---- .../spark/ml/feature/OneHotEncoder.scala | 12 ++---- .../ml/feature/PolynomialExpansion.scala | 12 ++---- .../ml/feature/QuantileDiscretizer.scala | 12 ++---- .../spark/ml/feature/SQLTransformer.scala | 13 ++---- .../spark/ml/feature/StandardScaler.scala | 22 ++++------ .../spark/ml/feature/StopWordsRemover.scala | 12 ++---- .../spark/ml/feature/StringIndexer.scala | 32 +++++---------- .../apache/spark/ml/feature/Tokenizer.scala | 24 +++-------- .../spark/ml/feature/VectorAssembler.scala | 12 ++---- .../spark/ml/feature/VectorSlicer.scala | 12 ++---- .../apache/spark/ml/recommendation/ALS.scala | 27 +++++-------- .../ml/regression/LinearRegression.scala | 30 ++++++-------- .../org/apache/spark/ml/util/ReadWrite.scala | 40 ++++++++++++------- .../org/apache/spark/ml/PipelineSuite.scala | 14 +++---- .../spark/ml/util/DefaultReadWriteTest.scala | 17 ++++---- 27 files changed, 190 insertions(+), 321 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 25f0c696f42be..b0f22e042ec56 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -29,8 +29,8 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{SparkContext, Logging} import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.ml.param.{Param, ParamMap, Params} -import org.apache.spark.ml.util.Reader -import org.apache.spark.ml.util.Writer +import org.apache.spark.ml.util.MLReader +import org.apache.spark.ml.util.MLWriter import org.apache.spark.ml.util._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType @@ -89,7 +89,7 @@ abstract class PipelineStage extends Params with Logging { * an identity transformer. 
*/ @Experimental -class Pipeline(override val uid: String) extends Estimator[PipelineModel] with Writable { +class Pipeline(override val uid: String) extends Estimator[PipelineModel] with MLWritable { def this() = this(Identifiable.randomUID("pipeline")) @@ -174,16 +174,16 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] with W theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur)) } - override def write: Writer = new Pipeline.PipelineWriter(this) + override def write: MLWriter = new Pipeline.PipelineWriter(this) } -object Pipeline extends Readable[Pipeline] { +object Pipeline extends MLReadable[Pipeline] { - override def read: Reader[Pipeline] = new PipelineReader + override def read: MLReader[Pipeline] = new PipelineReader - override def load(path: String): Pipeline = read.load(path) + override def load(path: String): Pipeline = super.load(path) - private[ml] class PipelineWriter(instance: Pipeline) extends Writer { + private[ml] class PipelineWriter(instance: Pipeline) extends MLWriter { SharedReadWrite.validateStages(instance.getStages) @@ -191,7 +191,7 @@ object Pipeline extends Readable[Pipeline] { SharedReadWrite.saveImpl(instance, instance.getStages, sc, path) } - private[ml] class PipelineReader extends Reader[Pipeline] { + private[ml] class PipelineReader extends MLReader[Pipeline] { /** Checked against metadata when loading model */ private val className = "org.apache.spark.ml.Pipeline" @@ -202,7 +202,7 @@ object Pipeline extends Readable[Pipeline] { } } - /** Methods for [[Reader]] and [[Writer]] shared between [[Pipeline]] and [[PipelineModel]] */ + /** Methods for [[MLReader]] and [[MLWriter]] shared between [[Pipeline]] and [[PipelineModel]] */ private[ml] object SharedReadWrite { import org.json4s.JsonDSL._ @@ -210,7 +210,7 @@ object Pipeline extends Readable[Pipeline] { /** Check that all stages are Writable */ def validateStages(stages: Array[PipelineStage]): Unit = { stages.foreach { - case stage: Writable => // good + case stage: MLWritable => // good case other => throw new UnsupportedOperationException("Pipeline write will fail on this Pipeline" + s" because it contains a stage which does not implement Writable. Non-Writable stage:" + @@ -245,7 +245,7 @@ object Pipeline extends Readable[Pipeline] { // Save stages val stagesDir = new Path(path, "stages").toString - stages.zipWithIndex.foreach { case (stage: Writable, idx: Int) => + stages.zipWithIndex.foreach { case (stage: MLWritable, idx: Int) => stage.write.save(getStagePath(stage.uid, idx, stages.length, stagesDir)) } } @@ -285,7 +285,7 @@ object Pipeline extends Readable[Pipeline] { val stagePath = SharedReadWrite.getStagePath(stageUid, idx, stageUids.length, stagesDir) val stageMetadata = DefaultParamsReader.loadMetadata(stagePath, sc) val cls = Utils.classForName(stageMetadata.className) - cls.getMethod("read").invoke(null).asInstanceOf[Reader[PipelineStage]].load(stagePath) + cls.getMethod("read").invoke(null).asInstanceOf[MLReader[PipelineStage]].load(stagePath) } (metadata.uid, stages) } @@ -308,7 +308,7 @@ object Pipeline extends Readable[Pipeline] { class PipelineModel private[ml] ( override val uid: String, val stages: Array[Transformer]) - extends Model[PipelineModel] with Writable with Logging { + extends Model[PipelineModel] with MLWritable with Logging { /** A Java/Python-friendly auxiliary constructor. 
*/ private[ml] def this(uid: String, stages: ju.List[Transformer]) = { @@ -333,18 +333,18 @@ class PipelineModel private[ml] ( new PipelineModel(uid, stages.map(_.copy(extra))).setParent(parent) } - override def write: Writer = new PipelineModel.PipelineModelWriter(this) + override def write: MLWriter = new PipelineModel.PipelineModelWriter(this) } -object PipelineModel extends Readable[PipelineModel] { +object PipelineModel extends MLReadable[PipelineModel] { import Pipeline.SharedReadWrite - override def read: Reader[PipelineModel] = new PipelineModelReader + override def read: MLReader[PipelineModel] = new PipelineModelReader - override def load(path: String): PipelineModel = read.load(path) + override def load(path: String): PipelineModel = super.load(path) - private[ml] class PipelineModelWriter(instance: PipelineModel) extends Writer { + private[ml] class PipelineModelWriter(instance: PipelineModel) extends MLWriter { SharedReadWrite.validateStages(instance.stages.asInstanceOf[Array[PipelineStage]]) @@ -352,7 +352,7 @@ object PipelineModel extends Readable[PipelineModel] { instance.stages.asInstanceOf[Array[PipelineStage]], sc, path) } - private[ml] class PipelineModelReader extends Reader[PipelineModel] { + private[ml] class PipelineModelReader extends MLReader[PipelineModel] { /** Checked against metadata when loading model */ private val className = "org.apache.spark.ml.PipelineModel" diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 71c2533bcbf47..a3cc49f7f018c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -29,9 +29,9 @@ import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.linalg.BLAS._ -import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD @@ -157,7 +157,7 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas @Experimental class LogisticRegression(override val uid: String) extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel] - with LogisticRegressionParams with Writable with Logging { + with LogisticRegressionParams with DefaultParamsWritable with Logging { def this() = this(Identifiable.randomUID("logreg")) @@ -385,12 +385,11 @@ class LogisticRegression(override val uid: String) } override def copy(extra: ParamMap): LogisticRegression = defaultCopy(extra) - - override def write: Writer = new DefaultParamsWriter(this) } -object LogisticRegression extends Readable[LogisticRegression] { - override def read: Reader[LogisticRegression] = new DefaultParamsReader[LogisticRegression] +object LogisticRegression extends DefaultParamsReadable[LogisticRegression] { + + override def load(path: String): LogisticRegression = super.load(path) } /** @@ -403,7 +402,7 @@ class LogisticRegressionModel private[ml] ( val coefficients: Vector, val intercept: Double) extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel] - with LogisticRegressionParams with Writable { 
+ with LogisticRegressionParams with MLWritable { @deprecated("Use coefficients instead.", "1.6.0") def weights: Vector = coefficients @@ -519,26 +518,26 @@ class LogisticRegressionModel private[ml] ( } /** - * Returns a [[Writer]] instance for this ML instance. + * Returns a [[MLWriter]] instance for this ML instance. * * For [[LogisticRegressionModel]], this does NOT currently save the training [[summary]]. * An option to save [[summary]] may be added in the future. * * This also does not save the [[parent]] currently. */ - override def write: Writer = new LogisticRegressionModel.LogisticRegressionModelWriter(this) + override def write: MLWriter = new LogisticRegressionModel.LogisticRegressionModelWriter(this) } -object LogisticRegressionModel extends Readable[LogisticRegressionModel] { +object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] { - override def read: Reader[LogisticRegressionModel] = new LogisticRegressionModelReader + override def read: MLReader[LogisticRegressionModel] = new LogisticRegressionModelReader - override def load(path: String): LogisticRegressionModel = read.load(path) + override def load(path: String): LogisticRegressionModel = super.load(path) - /** [[Writer]] instance for [[LogisticRegressionModel]] */ + /** [[MLWriter]] instance for [[LogisticRegressionModel]] */ private[classification] class LogisticRegressionModelWriter(instance: LogisticRegressionModel) - extends Writer with Logging { + extends MLWriter with Logging { private case class Data( numClasses: Int, @@ -558,7 +557,7 @@ object LogisticRegressionModel extends Readable[LogisticRegressionModel] { } private[classification] class LogisticRegressionModelReader - extends Reader[LogisticRegressionModel] { + extends MLReader[LogisticRegressionModel] { /** Checked against metadata when loading model */ private val className = "org.apache.spark.ml.classification.LogisticRegressionModel" diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala index e2be6547d8f00..63c06581482ed 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.types.{DoubleType, StructType} */ @Experimental final class Binarizer(override val uid: String) - extends Transformer with Writable with HasInputCol with HasOutputCol { + extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("binarizer")) @@ -86,17 +86,11 @@ final class Binarizer(override val uid: String) } override def copy(extra: ParamMap): Binarizer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object Binarizer extends Readable[Binarizer] { - - @Since("1.6.0") - override def read: Reader[Binarizer] = new DefaultParamsReader[Binarizer] +object Binarizer extends DefaultParamsReadable[Binarizer] { @Since("1.6.0") - override def load(path: String): Binarizer = read.load(path) + override def load(path: String): Binarizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index 7095fbd70aa07..324353a96afb3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -36,7 +36,7 @@ import 
org.apache.spark.sql.types.{DoubleType, StructField, StructType} */ @Experimental final class Bucketizer(override val uid: String) - extends Model[Bucketizer] with HasInputCol with HasOutputCol with Writable { + extends Model[Bucketizer] with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("bucketizer")) @@ -93,12 +93,9 @@ final class Bucketizer(override val uid: String) override def copy(extra: ParamMap): Bucketizer = { defaultCopy[Bucketizer](extra).setParent(parent) } - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } -object Bucketizer extends Readable[Bucketizer] { +object Bucketizer extends DefaultParamsReadable[Bucketizer] { /** We require splits to be of length >= 3 and to be in strictly increasing order. */ private[feature] def checkSplits(splits: Array[Double]): Boolean = { @@ -140,8 +137,5 @@ object Bucketizer extends Readable[Bucketizer] { } @Since("1.6.0") - override def read: Reader[Bucketizer] = new DefaultParamsReader[Bucketizer] - - @Since("1.6.0") - override def load(path: String): Bucketizer = read.load(path) + override def load(path: String): Bucketizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 5ff9bfb7d1119..4969cf42450d2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -107,7 +107,7 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit */ @Experimental class CountVectorizer(override val uid: String) - extends Estimator[CountVectorizerModel] with CountVectorizerParams with Writable { + extends Estimator[CountVectorizerModel] with CountVectorizerParams with DefaultParamsWritable { def this() = this(Identifiable.randomUID("cntVec")) @@ -171,16 +171,10 @@ class CountVectorizer(override val uid: String) } override def copy(extra: ParamMap): CountVectorizer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object CountVectorizer extends Readable[CountVectorizer] { - - @Since("1.6.0") - override def read: Reader[CountVectorizer] = new DefaultParamsReader +object CountVectorizer extends DefaultParamsReadable[CountVectorizer] { @Since("1.6.0") override def load(path: String): CountVectorizer = super.load(path) @@ -193,7 +187,7 @@ object CountVectorizer extends Readable[CountVectorizer] { */ @Experimental class CountVectorizerModel(override val uid: String, val vocabulary: Array[String]) - extends Model[CountVectorizerModel] with CountVectorizerParams with Writable { + extends Model[CountVectorizerModel] with CountVectorizerParams with MLWritable { import CountVectorizerModel._ @@ -251,14 +245,14 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin } @Since("1.6.0") - override def write: Writer = new CountVectorizerModelWriter(this) + override def write: MLWriter = new CountVectorizerModelWriter(this) } @Since("1.6.0") -object CountVectorizerModel extends Readable[CountVectorizerModel] { +object CountVectorizerModel extends MLReadable[CountVectorizerModel] { private[CountVectorizerModel] - class CountVectorizerModelWriter(instance: CountVectorizerModel) extends Writer { + class CountVectorizerModelWriter(instance: CountVectorizerModel) extends MLWriter { private case class Data(vocabulary: Seq[String]) @@ -270,7 
+264,7 @@ object CountVectorizerModel extends Readable[CountVectorizerModel] { } } - private class CountVectorizerModelReader extends Reader[CountVectorizerModel] { + private class CountVectorizerModelReader extends MLReader[CountVectorizerModel] { private val className = "org.apache.spark.ml.feature.CountVectorizerModel" @@ -288,7 +282,7 @@ object CountVectorizerModel extends Readable[CountVectorizerModel] { } @Since("1.6.0") - override def read: Reader[CountVectorizerModel] = new CountVectorizerModelReader + override def read: MLReader[CountVectorizerModel] = new CountVectorizerModelReader @Since("1.6.0") override def load(path: String): CountVectorizerModel = super.load(path) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala index 6ea5a616173ee..6bed72164a1da 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.types.DataType */ @Experimental class DCT(override val uid: String) - extends UnaryTransformer[Vector, Vector, DCT] with Writable { + extends UnaryTransformer[Vector, Vector, DCT] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("dct")) @@ -69,17 +69,11 @@ class DCT(override val uid: String) } override protected def outputDataType: DataType = new VectorUDT - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object DCT extends Readable[DCT] { - - @Since("1.6.0") - override def read: Reader[DCT] = new DefaultParamsReader[DCT] +object DCT extends DefaultParamsReadable[DCT] { @Since("1.6.0") - override def load(path: String): DCT = read.load(path) + override def load(path: String): DCT = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index 6d2ea675f5617..9e15835429a38 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.types.{ArrayType, StructType} */ @Experimental class HashingTF(override val uid: String) - extends Transformer with HasInputCol with HasOutputCol with Writable { + extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("hashingTF")) @@ -77,17 +77,11 @@ class HashingTF(override val uid: String) } override def copy(extra: ParamMap): HashingTF = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object HashingTF extends Readable[HashingTF] { - - @Since("1.6.0") - override def read: Reader[HashingTF] = new DefaultParamsReader[HashingTF] +object HashingTF extends DefaultParamsReadable[HashingTF] { @Since("1.6.0") - override def load(path: String): HashingTF = read.load(path) + override def load(path: String): HashingTF = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index 53ad34ef12646..0e00ef6f2ee20 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -62,7 +62,8 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol * Compute the Inverse Document Frequency (IDF) given a collection of 
documents. */ @Experimental -final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase with Writable { +final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase + with DefaultParamsWritable { def this() = this(Identifiable.randomUID("idf")) @@ -87,16 +88,10 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa } override def copy(extra: ParamMap): IDF = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object IDF extends Readable[IDF] { - - @Since("1.6.0") - override def read: Reader[IDF] = new DefaultParamsReader +object IDF extends DefaultParamsReadable[IDF] { @Since("1.6.0") override def load(path: String): IDF = super.load(path) @@ -110,7 +105,7 @@ object IDF extends Readable[IDF] { class IDFModel private[ml] ( override val uid: String, idfModel: feature.IDFModel) - extends Model[IDFModel] with IDFBase with Writable { + extends Model[IDFModel] with IDFBase with MLWritable { import IDFModel._ @@ -140,13 +135,13 @@ class IDFModel private[ml] ( def idf: Vector = idfModel.idf @Since("1.6.0") - override def write: Writer = new IDFModelWriter(this) + override def write: MLWriter = new IDFModelWriter(this) } @Since("1.6.0") -object IDFModel extends Readable[IDFModel] { +object IDFModel extends MLReadable[IDFModel] { - private[IDFModel] class IDFModelWriter(instance: IDFModel) extends Writer { + private[IDFModel] class IDFModelWriter(instance: IDFModel) extends MLWriter { private case class Data(idf: Vector) @@ -158,7 +153,7 @@ object IDFModel extends Readable[IDFModel] { } } - private class IDFModelReader extends Reader[IDFModel] { + private class IDFModelReader extends MLReader[IDFModel] { private val className = "org.apache.spark.ml.feature.IDFModel" @@ -176,7 +171,7 @@ object IDFModel extends Readable[IDFModel] { } @Since("1.6.0") - override def read: Reader[IDFModel] = new IDFModelReader + override def read: MLReader[IDFModel] = new IDFModelReader @Since("1.6.0") override def load(path: String): IDFModel = super.load(path) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala index 9df6b311cc9da..2181119f04a5d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -45,7 +45,7 @@ import org.apache.spark.sql.types._ @Since("1.6.0") @Experimental class Interaction @Since("1.6.0") (override val uid: String) extends Transformer - with HasInputCols with HasOutputCol with Writable { + with HasInputCols with HasOutputCol with DefaultParamsWritable { @Since("1.6.0") def this() = this(Identifiable.randomUID("interaction")) @@ -224,19 +224,13 @@ class Interaction @Since("1.6.0") (override val uid: String) extends Transformer require($(inputCols).length > 0, "Input cols must have non-zero length.") require($(inputCols).distinct.length == $(inputCols).length, "Input cols must be distinct.") } - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object Interaction extends Readable[Interaction] { - - @Since("1.6.0") - override def read: Reader[Interaction] = new DefaultParamsReader[Interaction] +object Interaction extends DefaultParamsReadable[Interaction] { @Since("1.6.0") - override def load(path: String): Interaction = read.load(path) + override def load(path: String): Interaction = super.load(path) } /** diff 
--git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index 24d964fae834e..ed24eabb50444 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -88,7 +88,7 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H */ @Experimental class MinMaxScaler(override val uid: String) - extends Estimator[MinMaxScalerModel] with MinMaxScalerParams with Writable { + extends Estimator[MinMaxScalerModel] with MinMaxScalerParams with DefaultParamsWritable { def this() = this(Identifiable.randomUID("minMaxScal")) @@ -118,16 +118,10 @@ class MinMaxScaler(override val uid: String) } override def copy(extra: ParamMap): MinMaxScaler = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object MinMaxScaler extends Readable[MinMaxScaler] { - - @Since("1.6.0") - override def read: Reader[MinMaxScaler] = new DefaultParamsReader +object MinMaxScaler extends DefaultParamsReadable[MinMaxScaler] { @Since("1.6.0") override def load(path: String): MinMaxScaler = super.load(path) @@ -147,7 +141,7 @@ class MinMaxScalerModel private[ml] ( override val uid: String, val originalMin: Vector, val originalMax: Vector) - extends Model[MinMaxScalerModel] with MinMaxScalerParams with Writable { + extends Model[MinMaxScalerModel] with MinMaxScalerParams with MLWritable { import MinMaxScalerModel._ @@ -195,14 +189,14 @@ class MinMaxScalerModel private[ml] ( } @Since("1.6.0") - override def write: Writer = new MinMaxScalerModelWriter(this) + override def write: MLWriter = new MinMaxScalerModelWriter(this) } @Since("1.6.0") -object MinMaxScalerModel extends Readable[MinMaxScalerModel] { +object MinMaxScalerModel extends MLReadable[MinMaxScalerModel] { private[MinMaxScalerModel] - class MinMaxScalerModelWriter(instance: MinMaxScalerModel) extends Writer { + class MinMaxScalerModelWriter(instance: MinMaxScalerModel) extends MLWriter { private case class Data(originalMin: Vector, originalMax: Vector) @@ -214,7 +208,7 @@ object MinMaxScalerModel extends Readable[MinMaxScalerModel] { } } - private class MinMaxScalerModelReader extends Reader[MinMaxScalerModel] { + private class MinMaxScalerModelReader extends MLReader[MinMaxScalerModel] { private val className = "org.apache.spark.ml.feature.MinMaxScalerModel" @@ -231,7 +225,7 @@ object MinMaxScalerModel extends Readable[MinMaxScalerModel] { } @Since("1.6.0") - override def read: Reader[MinMaxScalerModel] = new MinMaxScalerModelReader + override def read: MLReader[MinMaxScalerModel] = new MinMaxScalerModelReader @Since("1.6.0") override def load(path: String): MinMaxScalerModel = super.load(path) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala index 4a17acd95199f..65414ecbefbbd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.types.{ArrayType, DataType, StringType} */ @Experimental class NGram(override val uid: String) - extends UnaryTransformer[Seq[String], Seq[String], NGram] with Writable { + extends UnaryTransformer[Seq[String], Seq[String], NGram] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("ngram")) @@ -66,17 +66,11 @@ class NGram(override val uid: 
String) } override protected def outputDataType: DataType = new ArrayType(StringType, false) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object NGram extends Readable[NGram] { - - @Since("1.6.0") - override def read: Reader[NGram] = new DefaultParamsReader[NGram] +object NGram extends DefaultParamsReadable[NGram] { @Since("1.6.0") - override def load(path: String): NGram = read.load(path) + override def load(path: String): NGram = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala index 9df6a091d5058..c2d514fd9629e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.types.DataType */ @Experimental class Normalizer(override val uid: String) - extends UnaryTransformer[Vector, Vector, Normalizer] with Writable { + extends UnaryTransformer[Vector, Vector, Normalizer] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("normalizer")) @@ -56,17 +56,11 @@ class Normalizer(override val uid: String) } override protected def outputDataType: DataType = new VectorUDT() - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object Normalizer extends Readable[Normalizer] { - - @Since("1.6.0") - override def read: Reader[Normalizer] = new DefaultParamsReader[Normalizer] +object Normalizer extends DefaultParamsReadable[Normalizer] { @Since("1.6.0") - override def load(path: String): Normalizer = read.load(path) + override def load(path: String): Normalizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index 4e2adfaafa21e..d70164eaf0224 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -44,7 +44,7 @@ import org.apache.spark.sql.types.{DoubleType, StructType} */ @Experimental class OneHotEncoder(override val uid: String) extends Transformer - with HasInputCol with HasOutputCol with Writable { + with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("oneHot")) @@ -165,17 +165,11 @@ class OneHotEncoder(override val uid: String) extends Transformer } override def copy(extra: ParamMap): OneHotEncoder = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object OneHotEncoder extends Readable[OneHotEncoder] { - - @Since("1.6.0") - override def read: Reader[OneHotEncoder] = new DefaultParamsReader[OneHotEncoder] +object OneHotEncoder extends DefaultParamsReadable[OneHotEncoder] { @Since("1.6.0") - override def load(path: String): OneHotEncoder = read.load(path) + override def load(path: String): OneHotEncoder = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala index 49415398325fd..08610593fadda 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.types.DataType */ @Experimental class 
PolynomialExpansion(override val uid: String) - extends UnaryTransformer[Vector, Vector, PolynomialExpansion] with Writable { + extends UnaryTransformer[Vector, Vector, PolynomialExpansion] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("poly")) @@ -63,9 +63,6 @@ class PolynomialExpansion(override val uid: String) override protected def outputDataType: DataType = new VectorUDT() override def copy(extra: ParamMap): PolynomialExpansion = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } /** @@ -81,7 +78,7 @@ class PolynomialExpansion(override val uid: String) * current index and increment it properly for sparse input. */ @Since("1.6.0") -object PolynomialExpansion extends Readable[PolynomialExpansion] { +object PolynomialExpansion extends DefaultParamsReadable[PolynomialExpansion] { private def choose(n: Int, k: Int): Int = { Range(n, n - k, -1).product / Range(k, 1, -1).product @@ -182,8 +179,5 @@ object PolynomialExpansion extends Readable[PolynomialExpansion] { } @Since("1.6.0") - override def read: Reader[PolynomialExpansion] = new DefaultParamsReader[PolynomialExpansion] - - @Since("1.6.0") - override def load(path: String): PolynomialExpansion = read.load(path) + override def load(path: String): PolynomialExpansion = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 2da5c966d2967..7bf67c6325a35 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -60,7 +60,7 @@ private[feature] trait QuantileDiscretizerBase extends Params with HasInputCol w */ @Experimental final class QuantileDiscretizer(override val uid: String) - extends Estimator[Bucketizer] with QuantileDiscretizerBase with Writable { + extends Estimator[Bucketizer] with QuantileDiscretizerBase with DefaultParamsWritable { def this() = this(Identifiable.randomUID("quantileDiscretizer")) @@ -93,13 +93,10 @@ final class QuantileDiscretizer(override val uid: String) } override def copy(extra: ParamMap): QuantileDiscretizer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object QuantileDiscretizer extends Readable[QuantileDiscretizer] with Logging { +object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] with Logging { /** * Sampling from the given dataset to collect quantile statistics. 
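// Editor's illustrative sketch, not part of the patch: the conversions above only swap the
// hand-written write/read members for the DefaultParamsWritable / DefaultParamsReadable mix-ins,
// so the user-facing persistence calls are unchanged. Assuming a spark-shell style session and a
// writable "path":
import org.apache.spark.ml.feature.QuantileDiscretizer

val discretizer = new QuantileDiscretizer()            // write comes from DefaultParamsWritable
discretizer.write.overwrite().save(path)               // save(path) is the non-overwriting shortcut
val reloaded = QuantileDiscretizer.load(path)          // load comes from DefaultParamsReadable
assert(reloaded.uid == discretizer.uid)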
*/ @@ -179,8 +176,5 @@ object QuantileDiscretizer extends Readable[QuantileDiscretizer] with Logging { } @Since("1.6.0") - override def read: Reader[QuantileDiscretizer] = new DefaultParamsReader[QuantileDiscretizer] - - @Since("1.6.0") - override def load(path: String): QuantileDiscretizer = read.load(path) + override def load(path: String): QuantileDiscretizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala index c115064ff301a..3a735017ba836 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala @@ -33,7 +33,8 @@ import org.apache.spark.sql.types.StructType */ @Experimental @Since("1.6.0") -class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transformer with Writable { +class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transformer + with DefaultParamsWritable { @Since("1.6.0") def this() = this(Identifiable.randomUID("sql")) @@ -77,17 +78,11 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor @Since("1.6.0") override def copy(extra: ParamMap): SQLTransformer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object SQLTransformer extends Readable[SQLTransformer] { - - @Since("1.6.0") - override def read: Reader[SQLTransformer] = new DefaultParamsReader[SQLTransformer] +object SQLTransformer extends DefaultParamsReadable[SQLTransformer] { @Since("1.6.0") - override def load(path: String): SQLTransformer = read.load(path) + override def load(path: String): SQLTransformer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index ab04e5418dd4f..1f689c1da1ba9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -59,7 +59,7 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with */ @Experimental class StandardScaler(override val uid: String) extends Estimator[StandardScalerModel] - with StandardScalerParams with Writable { + with StandardScalerParams with DefaultParamsWritable { def this() = this(Identifiable.randomUID("stdScal")) @@ -96,16 +96,10 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM } override def copy(extra: ParamMap): StandardScaler = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object StandardScaler extends Readable[StandardScaler] { - - @Since("1.6.0") - override def read: Reader[StandardScaler] = new DefaultParamsReader +object StandardScaler extends DefaultParamsReadable[StandardScaler] { @Since("1.6.0") override def load(path: String): StandardScaler = super.load(path) @@ -119,7 +113,7 @@ object StandardScaler extends Readable[StandardScaler] { class StandardScalerModel private[ml] ( override val uid: String, scaler: feature.StandardScalerModel) - extends Model[StandardScalerModel] with StandardScalerParams with Writable { + extends Model[StandardScalerModel] with StandardScalerParams with MLWritable { import StandardScalerModel._ @@ -165,14 +159,14 @@ class StandardScalerModel private[ml] ( } @Since("1.6.0") - override def write: Writer = 
new StandardScalerModelWriter(this) + override def write: MLWriter = new StandardScalerModelWriter(this) } @Since("1.6.0") -object StandardScalerModel extends Readable[StandardScalerModel] { +object StandardScalerModel extends MLReadable[StandardScalerModel] { private[StandardScalerModel] - class StandardScalerModelWriter(instance: StandardScalerModel) extends Writer { + class StandardScalerModelWriter(instance: StandardScalerModel) extends MLWriter { private case class Data(std: Vector, mean: Vector, withStd: Boolean, withMean: Boolean) @@ -184,7 +178,7 @@ object StandardScalerModel extends Readable[StandardScalerModel] { } } - private class StandardScalerModelReader extends Reader[StandardScalerModel] { + private class StandardScalerModelReader extends MLReader[StandardScalerModel] { private val className = "org.apache.spark.ml.feature.StandardScalerModel" @@ -204,7 +198,7 @@ object StandardScalerModel extends Readable[StandardScalerModel] { } @Since("1.6.0") - override def read: Reader[StandardScalerModel] = new StandardScalerModelReader + override def read: MLReader[StandardScalerModel] = new StandardScalerModelReader @Since("1.6.0") override def load(path: String): StandardScalerModel = super.load(path) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index f1146988dcc7c..318808596dc6a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -86,7 +86,7 @@ private[spark] object StopWords { */ @Experimental class StopWordsRemover(override val uid: String) - extends Transformer with HasInputCol with HasOutputCol with Writable { + extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("stopWords")) @@ -154,17 +154,11 @@ class StopWordsRemover(override val uid: String) } override def copy(extra: ParamMap): StopWordsRemover = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object StopWordsRemover extends Readable[StopWordsRemover] { - - @Since("1.6.0") - override def read: Reader[StopWordsRemover] = new DefaultParamsReader[StopWordsRemover] +object StopWordsRemover extends DefaultParamsReadable[StopWordsRemover] { @Since("1.6.0") - override def load(path: String): StopWordsRemover = read.load(path) + override def load(path: String): StopWordsRemover = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index f16f6afc002d8..97a2e4f6d6ca4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -65,7 +65,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha */ @Experimental class StringIndexer(override val uid: String) extends Estimator[StringIndexerModel] - with StringIndexerBase with Writable { + with StringIndexerBase with DefaultParamsWritable { def this() = this(Identifiable.randomUID("strIdx")) @@ -93,16 +93,10 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod } override def copy(extra: ParamMap): StringIndexer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object 
StringIndexer extends Readable[StringIndexer] { - - @Since("1.6.0") - override def read: Reader[StringIndexer] = new DefaultParamsReader +object StringIndexer extends DefaultParamsReadable[StringIndexer] { @Since("1.6.0") override def load(path: String): StringIndexer = super.load(path) @@ -122,7 +116,7 @@ object StringIndexer extends Readable[StringIndexer] { class StringIndexerModel ( override val uid: String, val labels: Array[String]) - extends Model[StringIndexerModel] with StringIndexerBase with Writable { + extends Model[StringIndexerModel] with StringIndexerBase with MLWritable { import StringIndexerModel._ @@ -199,10 +193,10 @@ class StringIndexerModel ( } @Since("1.6.0") -object StringIndexerModel extends Readable[StringIndexerModel] { +object StringIndexerModel extends MLReadable[StringIndexerModel] { private[StringIndexerModel] - class StringIndexModelWriter(instance: StringIndexerModel) extends Writer { + class StringIndexModelWriter(instance: StringIndexerModel) extends MLWriter { private case class Data(labels: Array[String]) @@ -214,7 +208,7 @@ object StringIndexerModel extends Readable[StringIndexerModel] { } } - private class StringIndexerModelReader extends Reader[StringIndexerModel] { + private class StringIndexerModelReader extends MLReader[StringIndexerModel] { private val className = "org.apache.spark.ml.feature.StringIndexerModel" @@ -232,7 +226,7 @@ object StringIndexerModel extends Readable[StringIndexerModel] { } @Since("1.6.0") - override def read: Reader[StringIndexerModel] = new StringIndexerModelReader + override def read: MLReader[StringIndexerModel] = new StringIndexerModelReader @Since("1.6.0") override def load(path: String): StringIndexerModel = super.load(path) @@ -249,7 +243,7 @@ object StringIndexerModel extends Readable[StringIndexerModel] { */ @Experimental class IndexToString private[ml] (override val uid: String) - extends Transformer with HasInputCol with HasOutputCol with Writable { + extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("idxToStr")) @@ -316,17 +310,11 @@ class IndexToString private[ml] (override val uid: String) override def copy(extra: ParamMap): IndexToString = { defaultCopy(extra) } - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object IndexToString extends Readable[IndexToString] { - - @Since("1.6.0") - override def read: Reader[IndexToString] = new DefaultParamsReader[IndexToString] +object IndexToString extends DefaultParamsReadable[IndexToString] { @Since("1.6.0") - override def load(path: String): IndexToString = read.load(path) + override def load(path: String): IndexToString = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index 0e4445d1e2fa7..8ad7bbedaab5c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.types.{ArrayType, DataType, StringType} */ @Experimental class Tokenizer(override val uid: String) - extends UnaryTransformer[String, Seq[String], Tokenizer] with Writable { + extends UnaryTransformer[String, Seq[String], Tokenizer] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("tok")) @@ -46,19 +46,13 @@ class Tokenizer(override val uid: String) override protected def outputDataType: DataType = new 
ArrayType(StringType, true) override def copy(extra: ParamMap): Tokenizer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object Tokenizer extends Readable[Tokenizer] { - - @Since("1.6.0") - override def read: Reader[Tokenizer] = new DefaultParamsReader[Tokenizer] +object Tokenizer extends DefaultParamsReadable[Tokenizer] { @Since("1.6.0") - override def load(path: String): Tokenizer = read.load(path) + override def load(path: String): Tokenizer = super.load(path) } /** @@ -70,7 +64,7 @@ object Tokenizer extends Readable[Tokenizer] { */ @Experimental class RegexTokenizer(override val uid: String) - extends UnaryTransformer[String, Seq[String], RegexTokenizer] with Writable { + extends UnaryTransformer[String, Seq[String], RegexTokenizer] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("regexTok")) @@ -145,17 +139,11 @@ class RegexTokenizer(override val uid: String) override protected def outputDataType: DataType = new ArrayType(StringType, true) override def copy(extra: ParamMap): RegexTokenizer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object RegexTokenizer extends Readable[RegexTokenizer] { - - @Since("1.6.0") - override def read: Reader[RegexTokenizer] = new DefaultParamsReader[RegexTokenizer] +object RegexTokenizer extends DefaultParamsReadable[RegexTokenizer] { @Since("1.6.0") - override def load(path: String): RegexTokenizer = read.load(path) + override def load(path: String): RegexTokenizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 7e54205292ca2..0feec0549852b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.types._ */ @Experimental class VectorAssembler(override val uid: String) - extends Transformer with HasInputCols with HasOutputCol with Writable { + extends Transformer with HasInputCols with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("vecAssembler")) @@ -120,19 +120,13 @@ class VectorAssembler(override val uid: String) } override def copy(extra: ParamMap): VectorAssembler = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object VectorAssembler extends Readable[VectorAssembler] { - - @Since("1.6.0") - override def read: Reader[VectorAssembler] = new DefaultParamsReader[VectorAssembler] +object VectorAssembler extends DefaultParamsReadable[VectorAssembler] { @Since("1.6.0") - override def load(path: String): VectorAssembler = read.load(path) + override def load(path: String): VectorAssembler = super.load(path) private[feature] def assemble(vv: Any*): Vector = { val indices = ArrayBuilder.make[Int] diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala index 911582b55b574..5410a50bc2e47 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.types.StructType */ @Experimental final class VectorSlicer(override val uid: String) - extends Transformer with HasInputCol with 
HasOutputCol with Writable { + extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("vectorSlicer")) @@ -151,13 +151,10 @@ final class VectorSlicer(override val uid: String) } override def copy(extra: ParamMap): VectorSlicer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object VectorSlicer extends Readable[VectorSlicer] { +object VectorSlicer extends DefaultParamsReadable[VectorSlicer] { /** Return true if given feature indices are valid */ private[feature] def validIndices(indices: Array[Int]): Boolean = { @@ -174,8 +171,5 @@ object VectorSlicer extends Readable[VectorSlicer] { } @Since("1.6.0") - override def read: Reader[VectorSlicer] = new DefaultParamsReader[VectorSlicer] - - @Since("1.6.0") - override def load(path: String): VectorSlicer = read.load(path) + override def load(path: String): VectorSlicer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index d92514d2e239e..795b73c4c2121 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -185,7 +185,7 @@ class ALSModel private[ml] ( val rank: Int, @transient val userFactors: DataFrame, @transient val itemFactors: DataFrame) - extends Model[ALSModel] with ALSModelParams with Writable { + extends Model[ALSModel] with ALSModelParams with MLWritable { /** @group setParam */ def setUserCol(value: String): this.type = set(userCol, value) @@ -225,19 +225,19 @@ class ALSModel private[ml] ( } @Since("1.6.0") - override def write: Writer = new ALSModel.ALSModelWriter(this) + override def write: MLWriter = new ALSModel.ALSModelWriter(this) } @Since("1.6.0") -object ALSModel extends Readable[ALSModel] { +object ALSModel extends MLReadable[ALSModel] { @Since("1.6.0") - override def read: Reader[ALSModel] = new ALSModelReader + override def read: MLReader[ALSModel] = new ALSModelReader @Since("1.6.0") - override def load(path: String): ALSModel = read.load(path) + override def load(path: String): ALSModel = super.load(path) - private[recommendation] class ALSModelWriter(instance: ALSModel) extends Writer { + private[recommendation] class ALSModelWriter(instance: ALSModel) extends MLWriter { override protected def saveImpl(path: String): Unit = { val extraMetadata = render("rank" -> instance.rank) @@ -249,7 +249,7 @@ object ALSModel extends Readable[ALSModel] { } } - private[recommendation] class ALSModelReader extends Reader[ALSModel] { + private[recommendation] class ALSModelReader extends MLReader[ALSModel] { /** Checked against metadata when loading model */ private val className = "org.apache.spark.ml.recommendation.ALSModel" @@ -309,7 +309,8 @@ object ALSModel extends Readable[ALSModel] { * preferences rather than explicit ratings given to items. 
*/ @Experimental -class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams with Writable { +class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams + with DefaultParamsWritable { import org.apache.spark.ml.recommendation.ALS.Rating @@ -391,9 +392,6 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams w } override def copy(extra: ParamMap): ALS = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @@ -406,7 +404,7 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams w * than 2 billion. */ @DeveloperApi -object ALS extends Readable[ALS] with Logging { +object ALS extends DefaultParamsReadable[ALS] with Logging { /** * :: DeveloperApi :: @@ -416,10 +414,7 @@ object ALS extends Readable[ALS] with Logging { case class Rating[@specialized(Int, Long) ID](user: ID, item: ID, rating: Float) @Since("1.6.0") - override def read: Reader[ALS] = new DefaultParamsReader[ALS] - - @Since("1.6.0") - override def load(path: String): ALS = read.load(path) + override def load(path: String): ALS = super.load(path) /** Trait for least squares solvers applied to the normal equation. */ private[recommendation] trait LeastSquaresNESolver extends Serializable { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index f7c44f0a51b8a..7ba1a60edaf71 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -66,7 +66,7 @@ private[regression] trait LinearRegressionParams extends PredictorParams @Experimental class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String) extends Regressor[Vector, LinearRegression, LinearRegressionModel] - with LinearRegressionParams with Writable with Logging { + with LinearRegressionParams with DefaultParamsWritable with Logging { @Since("1.4.0") def this() = this(Identifiable.randomUID("linReg")) @@ -345,19 +345,13 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String @Since("1.4.0") override def copy(extra: ParamMap): LinearRegression = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object LinearRegression extends Readable[LinearRegression] { - - @Since("1.6.0") - override def read: Reader[LinearRegression] = new DefaultParamsReader[LinearRegression] +object LinearRegression extends DefaultParamsReadable[LinearRegression] { @Since("1.6.0") - override def load(path: String): LinearRegression = read.load(path) + override def load(path: String): LinearRegression = super.load(path) } /** @@ -371,7 +365,7 @@ class LinearRegressionModel private[ml] ( val coefficients: Vector, val intercept: Double) extends RegressionModel[Vector, LinearRegressionModel] - with LinearRegressionParams with Writable { + with LinearRegressionParams with MLWritable { private var trainingSummary: Option[LinearRegressionTrainingSummary] = None @@ -441,7 +435,7 @@ class LinearRegressionModel private[ml] ( } /** - * Returns a [[Writer]] instance for this ML instance. + * Returns a [[MLWriter]] instance for this ML instance. * * For [[LinearRegressionModel]], this does NOT currently save the training [[summary]]. * An option to save [[summary]] may be added in the future. 
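// Editor's hedged sketch, not part of the patch: with LinearRegressionModel now mixing in
// MLWritable and its companion MLReadable, a fitted model can be round-tripped. "training" (a
// DataFrame with "label" and "features" columns) and "path" are assumed to exist; as the comment
// above notes, the training summary and the parent estimator are not persisted.
import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}

val model = new LinearRegression().setMaxIter(5).fit(training)
model.write.overwrite().save(path)                     // handled by LinearRegressionModelWriter
val restored = LinearRegressionModel.load(path)        // handled by LinearRegressionModelReader
assert(restored.intercept == model.intercept && restored.coefficients == model.coefficients)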
@@ -449,21 +443,21 @@ class LinearRegressionModel private[ml] ( * This also does not save the [[parent]] currently. */ @Since("1.6.0") - override def write: Writer = new LinearRegressionModel.LinearRegressionModelWriter(this) + override def write: MLWriter = new LinearRegressionModel.LinearRegressionModelWriter(this) } @Since("1.6.0") -object LinearRegressionModel extends Readable[LinearRegressionModel] { +object LinearRegressionModel extends MLReadable[LinearRegressionModel] { @Since("1.6.0") - override def read: Reader[LinearRegressionModel] = new LinearRegressionModelReader + override def read: MLReader[LinearRegressionModel] = new LinearRegressionModelReader @Since("1.6.0") - override def load(path: String): LinearRegressionModel = read.load(path) + override def load(path: String): LinearRegressionModel = super.load(path) - /** [[Writer]] instance for [[LinearRegressionModel]] */ + /** [[MLWriter]] instance for [[LinearRegressionModel]] */ private[LinearRegressionModel] class LinearRegressionModelWriter(instance: LinearRegressionModel) - extends Writer with Logging { + extends MLWriter with Logging { private case class Data(intercept: Double, coefficients: Vector) @@ -477,7 +471,7 @@ object LinearRegressionModel extends Readable[LinearRegressionModel] { } } - private class LinearRegressionModelReader extends Reader[LinearRegressionModel] { + private class LinearRegressionModelReader extends MLReader[LinearRegressionModel] { /** Checked against metadata when loading model */ private val className = "org.apache.spark.ml.regression.LinearRegressionModel" diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index d8ce907af5323..ff9322dba122a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.SQLContext import org.apache.spark.util.Utils /** - * Trait for [[Writer]] and [[Reader]]. + * Trait for [[MLWriter]] and [[MLReader]]. */ private[util] sealed trait BaseReadWrite { private var optionSQLContext: Option[SQLContext] = None @@ -64,7 +64,7 @@ private[util] sealed trait BaseReadWrite { */ @Experimental @Since("1.6.0") -abstract class Writer extends BaseReadWrite with Logging { +abstract class MLWriter extends BaseReadWrite with Logging { protected var shouldOverwrite: Boolean = false @@ -111,16 +111,16 @@ abstract class Writer extends BaseReadWrite with Logging { } /** - * Trait for classes that provide [[Writer]]. + * Trait for classes that provide [[MLWriter]]. */ @Since("1.6.0") -trait Writable { +trait MLWritable { /** - * Returns a [[Writer]] instance for this ML instance. + * Returns an [[MLWriter]] instance for this ML instance. */ @Since("1.6.0") - def write: Writer + def write: MLWriter /** * Saves this ML instance to the input path, a shortcut of `write.save(path)`. @@ -130,13 +130,18 @@ trait Writable { def save(path: String): Unit = write.save(path) } +private[ml] trait DefaultParamsWritable extends MLWritable { self: Params => + + override def write: MLWriter = new DefaultParamsWriter(this) +} + /** * Abstract class for utility classes that can load ML instances. * @tparam T ML instance type */ @Experimental @Since("1.6.0") -abstract class Reader[T] extends BaseReadWrite { +abstract class MLReader[T] extends BaseReadWrite { /** * Loads the ML component from the input path. 
@@ -149,18 +154,18 @@ abstract class Reader[T] extends BaseReadWrite { } /** - * Trait for objects that provide [[Reader]]. + * Trait for objects that provide [[MLReader]]. * @tparam T ML instance type */ @Experimental @Since("1.6.0") -trait Readable[T] { +trait MLReadable[T] { /** - * Returns a [[Reader]] instance for this class. + * Returns an [[MLReader]] instance for this class. */ @Since("1.6.0") - def read: Reader[T] + def read: MLReader[T] /** * Reads an ML instance from the input path, a shortcut of `read.load(path)`. @@ -171,13 +176,18 @@ trait Readable[T] { def load(path: String): T = read.load(path) } +private[ml] trait DefaultParamsReadable[T] extends MLReadable[T] { + + override def read: MLReader[T] = new DefaultParamsReader +} + /** - * Default [[Writer]] implementation for transformers and estimators that contain basic + * Default [[MLWriter]] implementation for transformers and estimators that contain basic * (json4s-serializable) params and no data. This will not handle more complex params or types with * data (e.g., models with coefficients). * @param instance object to save */ -private[ml] class DefaultParamsWriter(instance: Params) extends Writer { +private[ml] class DefaultParamsWriter(instance: Params) extends MLWriter { override protected def saveImpl(path: String): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sc) @@ -218,13 +228,13 @@ private[ml] object DefaultParamsWriter { } /** - * Default [[Reader]] implementation for transformers and estimators that contain basic + * Default [[MLReader]] implementation for transformers and estimators that contain basic * (json4s-serializable) params and no data. This will not handle more complex params or types with * data (e.g., models with coefficients). * @tparam T ML instance type * TODO: Consider adding check for correct class name. 
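// Editor's sketch of the pattern ReadWrite.scala now enables (MyTransformer is hypothetical, not
// part of the patch): a Params-only stage gets save/load by mixing in the new traits instead of
// overriding write/read by hand. Because the traits are private[ml], this only compiles inside
// the org.apache.spark.ml package.
package org.apache.spark.ml.feature

import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType

class MyTransformer(override val uid: String) extends Transformer with DefaultParamsWritable {
  def this() = this(Identifiable.randomUID("myTransformer"))
  override def transform(dataset: DataFrame): DataFrame = dataset     // identity, for illustration
  override def transformSchema(schema: StructType): StructType = schema
  override def copy(extra: ParamMap): MyTransformer = defaultCopy(extra)
}

object MyTransformer extends DefaultParamsReadable[MyTransformer] {
  override def load(path: String): MyTransformer = super.load(path)
}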
*/ -private[ml] class DefaultParamsReader[T] extends Reader[T] { +private[ml] class DefaultParamsReader[T] extends MLReader[T] { override def load(path: String): T = { val metadata = DefaultParamsReader.loadMetadata(path, sc) diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 7f5c3895acb0c..12aba6bc6dbeb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -179,8 +179,8 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } -/** Used to test [[Pipeline]] with [[Writable]] stages */ -class WritableStage(override val uid: String) extends Transformer with Writable { +/** Used to test [[Pipeline]] with [[MLWritable]] stages */ +class WritableStage(override val uid: String) extends Transformer with MLWritable { final val intParam: IntParam = new IntParam(this, "intParam", "doc") @@ -192,21 +192,21 @@ class WritableStage(override val uid: String) extends Transformer with Writable override def copy(extra: ParamMap): WritableStage = defaultCopy(extra) - override def write: Writer = new DefaultParamsWriter(this) + override def write: MLWriter = new DefaultParamsWriter(this) override def transform(dataset: DataFrame): DataFrame = dataset override def transformSchema(schema: StructType): StructType = schema } -object WritableStage extends Readable[WritableStage] { +object WritableStage extends MLReadable[WritableStage] { - override def read: Reader[WritableStage] = new DefaultParamsReader[WritableStage] + override def read: MLReader[WritableStage] = new DefaultParamsReader[WritableStage] - override def load(path: String): WritableStage = read.load(path) + override def load(path: String): WritableStage = super.load(path) } -/** Used to test [[Pipeline]] with non-[[Writable]] stages */ +/** Used to test [[Pipeline]] with non-[[MLWritable]] stages */ class UnWritableStage(override val uid: String) extends Transformer { final val intParam: IntParam = new IntParam(this, "intParam", "doc") diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index dd1e8acce9418..84d06b43d6224 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -38,7 +38,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => * @tparam T ML instance type * @return Instance loaded from file */ - def testDefaultReadWrite[T <: Params with Writable]( + def testDefaultReadWrite[T <: Params with MLWritable]( instance: T, testParams: Boolean = true): T = { val uid = instance.uid @@ -52,7 +52,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => instance.save(path) } instance.write.overwrite().save(path) - val loader = instance.getClass.getMethod("read").invoke(null).asInstanceOf[Reader[T]] + val loader = instance.getClass.getMethod("read").invoke(null).asInstanceOf[MLReader[T]] val newInstance = loader.load(path) assert(newInstance.uid === instance.uid) @@ -92,7 +92,8 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => * @tparam E Type of [[Estimator]] * @tparam M Type of [[Model]] produced by estimator */ - def testEstimatorAndModelReadWrite[E <: Estimator[M] with Writable, M <: Model[M] with Writable]( + def testEstimatorAndModelReadWrite[ + E <: Estimator[M] with 
MLWritable, M <: Model[M] with MLWritable]( estimator: E, dataset: DataFrame, testParams: Map[String, Any], @@ -119,7 +120,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => } } -class MyParams(override val uid: String) extends Params with Writable { +class MyParams(override val uid: String) extends Params with MLWritable { final val intParamWithDefault: IntParam = new IntParam(this, "intParamWithDefault", "doc") final val intParam: IntParam = new IntParam(this, "intParam", "doc") @@ -145,14 +146,14 @@ class MyParams(override val uid: String) extends Params with Writable { override def copy(extra: ParamMap): Params = defaultCopy(extra) - override def write: Writer = new DefaultParamsWriter(this) + override def write: MLWriter = new DefaultParamsWriter(this) } -object MyParams extends Readable[MyParams] { +object MyParams extends MLReadable[MyParams] { - override def read: Reader[MyParams] = new DefaultParamsReader[MyParams] + override def read: MLReader[MyParams] = new DefaultParamsReader[MyParams] - override def load(path: String): MyParams = read.load(path) + override def load(path: String): MyParams = super.load(path) } class DefaultReadWriteSuite extends SparkFunSuite with MLlibTestSparkContext From e61367b9f9bfc8e123369d55d7ca5925568b98a7 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 18 Nov 2015 18:34:36 -0800 Subject: [PATCH 124/149] [SPARK-11833][SQL] Add Java tests for Kryo/Java Dataset encoders Also added some nicer error messages for incompatible types (private types and primitive types) for Kryo/Java encoder. Author: Reynold Xin Closes #9823 from rxin/SPARK-11833. --- .../scala/org/apache/spark/sql/Encoder.scala | 69 +++++++++++------ .../encoders/EncoderErrorMessageSuite.scala | 40 ++++++++++ .../catalyst/encoders/FlatEncoderSuite.scala | 22 ++---- .../apache/spark/sql/JavaDatasetSuite.java | 75 ++++++++++++++++++- 4 files changed, 166 insertions(+), 40 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index 1ed5111440c80..d54f2854fb33f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import java.lang.reflect.Modifier + import scala.reflect.{ClassTag, classTag} import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor} @@ -43,30 +45,28 @@ trait Encoder[T] extends Serializable { */ object Encoders { - /** A way to construct encoders using generic serializers. 
*/ - private def genericSerializer[T: ClassTag](useKryo: Boolean): Encoder[T] = { - ExpressionEncoder[T]( - schema = new StructType().add("value", BinaryType), - flat = true, - toRowExpressions = Seq( - EncodeUsingSerializer( - BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo)), - fromRowExpression = - DecodeUsingSerializer[T]( - BoundReference(0, BinaryType, nullable = true), classTag[T], kryo = useKryo), - clsTag = classTag[T] - ) - } + def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true) + def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true) + def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true) + def INT: Encoder[java.lang.Integer] = ExpressionEncoder(flat = true) + def LONG: Encoder[java.lang.Long] = ExpressionEncoder(flat = true) + def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder(flat = true) + def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder(flat = true) + def STRING: Encoder[java.lang.String] = ExpressionEncoder(flat = true) /** * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo. * This encoder maps T into a single byte array (binary) field. + * + * T must be publicly accessible. */ def kryo[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = true) /** * Creates an encoder that serializes objects of type T using Kryo. * This encoder maps T into a single byte array (binary) field. + * + * T must be publicly accessible. */ def kryo[T](clazz: Class[T]): Encoder[T] = kryo(ClassTag[T](clazz)) @@ -75,6 +75,8 @@ object Encoders { * serialization. This encoder maps T into a single byte array (binary) field. * * Note that this is extremely inefficient and should only be used as the last resort. + * + * T must be publicly accessible. */ def javaSerialization[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = false) @@ -83,17 +85,40 @@ object Encoders { * This encoder maps T into a single byte array (binary) field. * * Note that this is extremely inefficient and should only be used as the last resort. + * + * T must be publicly accessible. */ def javaSerialization[T](clazz: Class[T]): Encoder[T] = javaSerialization(ClassTag[T](clazz)) - def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true) - def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true) - def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true) - def INT: Encoder[java.lang.Integer] = ExpressionEncoder(flat = true) - def LONG: Encoder[java.lang.Long] = ExpressionEncoder(flat = true) - def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder(flat = true) - def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder(flat = true) - def STRING: Encoder[java.lang.String] = ExpressionEncoder(flat = true) + /** Throws an exception if T is not a public class. */ + private def validatePublicClass[T: ClassTag](): Unit = { + if (!Modifier.isPublic(classTag[T].runtimeClass.getModifiers)) { + throw new UnsupportedOperationException( + s"${classTag[T].runtimeClass.getName} is not a public class. " + + "Only public classes are supported.") + } + } + + /** A way to construct encoders using generic serializers. 
*/ + private def genericSerializer[T: ClassTag](useKryo: Boolean): Encoder[T] = { + if (classTag[T].runtimeClass.isPrimitive) { + throw new UnsupportedOperationException("Primitive types are not supported.") + } + + validatePublicClass[T]() + + ExpressionEncoder[T]( + schema = new StructType().add("value", BinaryType), + flat = true, + toRowExpressions = Seq( + EncodeUsingSerializer( + BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo)), + fromRowExpression = + DecodeUsingSerializer[T]( + BoundReference(0, BinaryType, nullable = true), classTag[T], kryo = useKryo), + clsTag = classTag[T] + ) + } def tuple[T1, T2]( e1: Encoder[T1], diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala new file mode 100644 index 0000000000000..0b2a10bb04c10 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.encoders + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Encoders + + +class EncoderErrorMessageSuite extends SparkFunSuite { + + // Note: we also test error messages for encoders for private classes in JavaDatasetSuite. + // That is done in Java because Scala cannot create truly private classes. 
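// Editor's hedged sketch, not in the patch: intended usage of the generic encoders and the new
// error behaviour added above. "sqlContext" is assumed to be in scope; Point is a hypothetical
// public class (Kryo does not require java.io.Serializable, the Java-serialization encoder does).
import org.apache.spark.sql.Encoders

class Point(val x: Int, val y: Int)

val kryoEncoder = Encoders.kryo[Point]                 // maps Point to a single binary column
val ds = sqlContext.createDataset(Seq(new Point(1, 2), new Point(3, 4)))(kryoEncoder)

// Primitive and private types are now rejected up front with a clear message instead of failing
// later: Encoders.kryo[Int] and Encoders.javaSerialization[Char] throw UnsupportedOperationException.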
+ + test("primitive types in encoders using Kryo serialization") { + intercept[UnsupportedOperationException] { Encoders.kryo[Int] } + intercept[UnsupportedOperationException] { Encoders.kryo[Long] } + intercept[UnsupportedOperationException] { Encoders.kryo[Char] } + } + + test("primitive types in encoders using Java serialization") { + intercept[UnsupportedOperationException] { Encoders.javaSerialization[Int] } + intercept[UnsupportedOperationException] { Encoders.javaSerialization[Long] } + intercept[UnsupportedOperationException] { Encoders.javaSerialization[Char] } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala index 6e0322fb6e019..07523d49f4266 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala @@ -74,24 +74,14 @@ class FlatEncoderSuite extends ExpressionEncoderSuite { FlatEncoder[Map[Int, Map[String, Int]]], "map of map") // Kryo encoders - encodeDecodeTest( - "hello", - encoderFor(Encoders.kryo[String]), - "kryo string") - encodeDecodeTest( - new KryoSerializable(15), - encoderFor(Encoders.kryo[KryoSerializable]), - "kryo object serialization") + encodeDecodeTest("hello", encoderFor(Encoders.kryo[String]), "kryo string") + encodeDecodeTest(new KryoSerializable(15), + encoderFor(Encoders.kryo[KryoSerializable]), "kryo object") // Java encoders - encodeDecodeTest( - "hello", - encoderFor(Encoders.javaSerialization[String]), - "java string") - encodeDecodeTest( - new JavaSerializable(15), - encoderFor(Encoders.javaSerialization[JavaSerializable]), - "java object serialization") + encodeDecodeTest("hello", encoderFor(Encoders.javaSerialization[String]), "java string") + encodeDecodeTest(new JavaSerializable(15), + encoderFor(Encoders.javaSerialization[JavaSerializable]), "java object") } /** For testing Kryo serialization based encoder. 
*/ diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index d9b22506fbd3b..ce40dd856f679 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -24,6 +24,7 @@ import scala.Tuple3; import scala.Tuple4; import scala.Tuple5; + import org.junit.*; import org.apache.spark.Accumulator; @@ -410,8 +411,8 @@ public String call(Tuple2 value) throws Exception { .as(Encoders.tuple(Encoders.STRING(), Encoders.INT(), Encoders.LONG(), Encoders.LONG())); Assert.assertEquals( Arrays.asList( - new Tuple4("a", 3, 3L, 2L), - new Tuple4("b", 3, 3L, 1L)), + new Tuple4<>("a", 3, 3L, 2L), + new Tuple4<>("b", 3, 3L, 1L)), agged2.collectAsList()); } @@ -437,4 +438,74 @@ public Integer finish(Integer reduction) { return reduction; } } + + public static class KryoSerializable { + String value; + + KryoSerializable(String value) { + this.value = value; + } + + @Override + public boolean equals(Object other) { + return this.value.equals(((KryoSerializable) other).value); + } + + @Override + public int hashCode() { + return this.value.hashCode(); + } + } + + public static class JavaSerializable implements Serializable { + String value; + + JavaSerializable(String value) { + this.value = value; + } + + @Override + public boolean equals(Object other) { + return this.value.equals(((JavaSerializable) other).value); + } + + @Override + public int hashCode() { + return this.value.hashCode(); + } + } + + @Test + public void testKryoEncoder() { + Encoder encoder = Encoders.kryo(KryoSerializable.class); + List data = Arrays.asList( + new KryoSerializable("hello"), new KryoSerializable("world")); + Dataset ds = context.createDataset(data, encoder); + Assert.assertEquals(data, ds.collectAsList()); + } + + @Test + public void testJavaEncoder() { + Encoder encoder = Encoders.javaSerialization(JavaSerializable.class); + List data = Arrays.asList( + new JavaSerializable("hello"), new JavaSerializable("world")); + Dataset ds = context.createDataset(data, encoder); + Assert.assertEquals(data, ds.collectAsList()); + } + + /** + * For testing error messages when creating an encoder on a private class. This is done + * here since we cannot create truly private classes in Scala. + */ + private static class PrivateClassTest { } + + @Test(expected = UnsupportedOperationException.class) + public void testJavaEncoderErrorMessageForPrivateClass() { + Encoders.javaSerialization(PrivateClassTest.class); + } + + @Test(expected = UnsupportedOperationException.class) + public void testKryoEncoderErrorMessageForPrivateClass() { + Encoders.kryo(PrivateClassTest.class); + } } From 6d0848b53bbe6c5acdcf5c033cd396b1ae6e293d Mon Sep 17 00:00:00 2001 From: Nong Li Date: Wed, 18 Nov 2015 18:38:45 -0800 Subject: [PATCH 125/149] [SPARK-11787][SQL] Improve Parquet scan performance when using flat schemas. This patch adds an alternate to the Parquet RecordReader from the parquet-mr project that is much faster for flat schemas. Instead of using the general converter mechanism from parquet-mr, this directly uses the lower level APIs from parquet-columnar and a customer RecordReader that directly assembles into UnsafeRows. This is optionally disabled and only used for supported schemas. 
Using the tpcds store sales table and doing a sum of increasingly more columns, the results are: For 1 Column: Before: 11.3M rows/second After: 18.2M rows/second For 2 Columns: Before: 7.2M rows/second After: 11.2M rows/second For 5 Columns: Before: 2.9M rows/second After: 4.5M rows/second Author: Nong Li Closes #9774 from nongli/parquet. --- .../apache/spark/rdd/SqlNewHadoopRDD.scala | 41 +- .../sql/catalyst/expressions/UnsafeRow.java | 9 + .../expressions/codegen/BufferHolder.java | 32 +- .../expressions/codegen/UnsafeRowWriter.java | 20 +- .../SpecificParquetRecordReaderBase.java | 240 +++++++ .../parquet/UnsafeRowParquetRecordReader.java | 593 ++++++++++++++++++ .../parquet/CatalystRowConverter.scala | 48 +- .../parquet/ParquetFilterSuite.scala | 4 +- 8 files changed, 944 insertions(+), 43 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java diff --git a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala index 264dae7f39085..4d176332b69ce 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala @@ -20,8 +20,6 @@ package org.apache.spark.rdd import java.text.SimpleDateFormat import java.util.Date -import scala.reflect.ClassTag - import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ @@ -30,10 +28,12 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.DataReadMethod import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil +import org.apache.spark.storage.StorageLevel import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.{Utils, SerializableConfiguration, ShutdownHookManager} import org.apache.spark.{Partition => SparkPartition, _} -import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager, Utils} + +import scala.reflect.ClassTag private[spark] class SqlNewHadoopPartition( @@ -96,6 +96,11 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( @transient protected val jobId = new JobID(jobTrackerId, id) + // If true, enable using the custom RecordReader for parquet. This only works for + // a subset of the types (no complex types). + protected val enableUnsafeRowParquetReader: Boolean = + sc.conf.getBoolean("spark.parquet.enableUnsafeRowRecordReader", true) + override def getPartitions: Array[SparkPartition] = { val conf = getConf(isDriverSide = true) val inputFormat = inputFormatClass.newInstance @@ -150,9 +155,31 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( configurable.setConf(conf) case _ => } - private[this] var reader = format.createRecordReader( - split.serializableHadoopSplit.value, hadoopAttemptContext) - reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) + private[this] var reader: RecordReader[Void, V] = null + + /** + * If the format is ParquetInputFormat, try to create the optimized RecordReader. If this + * fails (for example, unsupported schema), try with the normal reader. + * TODO: plumb this through a different way? 
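// Editor's sketch, not part of the patch: the fast path is on by default and SqlNewHadoopRDD
// falls back to the stock ParquetInputFormat reader if the custom one cannot be initialized
// (for example, an unsupported schema). It can also be disabled explicitly with the flag read
// above; "warehouse/events" is a placeholder path.
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

val conf = new SparkConf()
  .setAppName("parquet-scan")
  .set("spark.parquet.enableUnsafeRowRecordReader", "false")   // force the general code path
val sqlContext = new SQLContext(new SparkContext(conf))
sqlContext.read.parquet("warehouse/events").count()            // results are identical either way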
+ */ + if (enableUnsafeRowParquetReader && + format.getClass.getName == "org.apache.parquet.hadoop.ParquetInputFormat") { + // TODO: move this class to sql.execution and remove this. + reader = Utils.classForName( + "org.apache.spark.sql.execution.datasources.parquet.UnsafeRowParquetRecordReader") + .newInstance().asInstanceOf[RecordReader[Void, V]] + try { + reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) + } catch { + case e: Exception => reader = null + } + } + + if (reader == null) { + reader = format.createRecordReader( + split.serializableHadoopSplit.value, hadoopAttemptContext) + reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) + } // Register an on-task-completion callback to close the input stream. context.addTaskCompletionListener(context => close()) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 5ba14ebdb62a4..33769363a0ed5 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -178,6 +178,15 @@ public void pointTo(byte[] buf, int numFields, int sizeInBytes) { pointTo(buf, Platform.BYTE_ARRAY_OFFSET, numFields, sizeInBytes); } + /** + * Updates this UnsafeRow preserving the number of fields. + * @param buf byte array to point to + * @param sizeInBytes the number of bytes valid in the byte array + */ + public void pointTo(byte[] buf, int sizeInBytes) { + pointTo(buf, numFields, sizeInBytes); + } + @Override public void setNullAt(int i) { assertIndexIsValid(i); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java index 9c9468678065d..d26b1b187c27b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java @@ -17,19 +17,28 @@ package org.apache.spark.sql.catalyst.expressions.codegen; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.unsafe.Platform; /** - * A helper class to manage the row buffer used in `GenerateUnsafeProjection`. - * - * Note that it is only used in `GenerateUnsafeProjection`, so it's safe to mark member variables - * public for ease of use. + * A helper class to manage the row buffer when construct unsafe rows. */ public class BufferHolder { - public byte[] buffer = new byte[64]; + public byte[] buffer; public int cursor = Platform.BYTE_ARRAY_OFFSET; - public void grow(int neededSize) { + public BufferHolder() { + this(64); + } + + public BufferHolder(int size) { + buffer = new byte[size]; + } + + /** + * Grows the buffer to at least neededSize. If row is non-null, points the row to the buffer. + */ + public void grow(int neededSize, UnsafeRow row) { final int length = totalSize() + neededSize; if (buffer.length < length) { // This will not happen frequently, because the buffer is re-used. 
@@ -41,12 +50,23 @@ public void grow(int neededSize) { Platform.BYTE_ARRAY_OFFSET, totalSize()); buffer = tmp; + if (row != null) { + row.pointTo(buffer, length * 2); + } } } + public void grow(int neededSize) { + grow(neededSize, null); + } + public void reset() { cursor = Platform.BYTE_ARRAY_OFFSET; } + public void resetTo(int offset) { + assert(offset <= buffer.length); + cursor = Platform.BYTE_ARRAY_OFFSET + offset; + } public int totalSize() { return cursor - Platform.BYTE_ARRAY_OFFSET; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java index 048b7749d8fb4..e227c0dec9748 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java @@ -35,6 +35,7 @@ public class UnsafeRowWriter { // The offset of the global buffer where we start to write this row. private int startingOffset; private int nullBitsSize; + private UnsafeRow row; public void initialize(BufferHolder holder, int numFields) { this.holder = holder; @@ -43,7 +44,7 @@ public void initialize(BufferHolder holder, int numFields) { // grow the global buffer to make sure it has enough space to write fixed-length data. final int fixedSize = nullBitsSize + 8 * numFields; - holder.grow(fixedSize); + holder.grow(fixedSize, row); holder.cursor += fixedSize; // zero-out the null bits region @@ -52,12 +53,19 @@ public void initialize(BufferHolder holder, int numFields) { } } + public void initialize(UnsafeRow row, BufferHolder holder, int numFields) { + initialize(holder, numFields); + this.row = row; + } + private void zeroOutPaddingBytes(int numBytes) { if ((numBytes & 0x07) > 0) { Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L); } } + public BufferHolder holder() { return holder; } + public boolean isNullAt(int ordinal) { return BitSetMethods.isSet(holder.buffer, startingOffset, ordinal); } @@ -90,7 +98,7 @@ public void alignToWords(int numBytes) { if (remainder > 0) { final int paddingBytes = 8 - remainder; - holder.grow(paddingBytes); + holder.grow(paddingBytes, row); for (int i = 0; i < paddingBytes; i++) { Platform.putByte(holder.buffer, holder.cursor, (byte) 0); @@ -153,7 +161,7 @@ public void write(int ordinal, Decimal input, int precision, int scale) { } } else { // grow the global buffer before writing data. - holder.grow(16); + holder.grow(16, row); // zero-out the bytes Platform.putLong(holder.buffer, holder.cursor, 0L); @@ -185,7 +193,7 @@ public void write(int ordinal, UTF8String input) { final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); // grow the global buffer before writing data. - holder.grow(roundedSize); + holder.grow(roundedSize, row); zeroOutPaddingBytes(numBytes); @@ -206,7 +214,7 @@ public void write(int ordinal, byte[] input, int offset, int numBytes) { final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); // grow the global buffer before writing data. - holder.grow(roundedSize); + holder.grow(roundedSize, row); zeroOutPaddingBytes(numBytes); @@ -222,7 +230,7 @@ public void write(int ordinal, byte[] input, int offset, int numBytes) { public void write(int ordinal, CalendarInterval input) { // grow the global buffer before writing data. 
- holder.grow(16); + holder.grow(16, row); // Write the months and microseconds fields of Interval to the variable length portion. Platform.putLong(holder.buffer, holder.cursor, input.months); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java new file mode 100644 index 0000000000000..2ed30c1f5a8d9 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -0,0 +1,240 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.spark.sql.execution.datasources.parquet; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.apache.parquet.filter2.compat.RowGroupFilter.filterRowGroups; +import static org.apache.parquet.format.converter.ParquetMetadataConverter.NO_FILTER; +import static org.apache.parquet.format.converter.ParquetMetadataConverter.range; +import static org.apache.parquet.hadoop.ParquetFileReader.readFooter; +import static org.apache.parquet.hadoop.ParquetInputFormat.getFilter; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.mapreduce.InputSplit; +import org.apache.hadoop.mapreduce.RecordReader; +import org.apache.hadoop.mapreduce.TaskAttemptContext; +import org.apache.parquet.bytes.BytesInput; +import org.apache.parquet.bytes.BytesUtils; +import org.apache.parquet.column.ColumnDescriptor; +import org.apache.parquet.column.values.ValuesReader; +import org.apache.parquet.column.values.rle.RunLengthBitPackingHybridDecoder; +import org.apache.parquet.filter2.compat.FilterCompat; +import org.apache.parquet.hadoop.BadConfigurationException; +import org.apache.parquet.hadoop.ParquetFileReader; +import org.apache.parquet.hadoop.ParquetInputFormat; +import org.apache.parquet.hadoop.ParquetInputSplit; +import org.apache.parquet.hadoop.api.InitContext; +import org.apache.parquet.hadoop.api.ReadSupport; +import org.apache.parquet.hadoop.metadata.BlockMetaData; +import org.apache.parquet.hadoop.metadata.ParquetMetadata; +import org.apache.parquet.hadoop.util.ConfigurationUtil; +import org.apache.parquet.schema.MessageType; + +/** + * Base class for custom RecordReaaders for Parquet that directly materialize to `T`. + * This class handles computing row groups, filtering on them, setting up the column readers, + * etc. + * This is heavily based on parquet-mr's RecordReader. 
+ * TODO: move this to the parquet-mr project. There are performance benefits of doing it + * this way, albeit at a higher cost to implement. This base class is reusable. + */ +public abstract class SpecificParquetRecordReaderBase extends RecordReader { + protected Path file; + protected MessageType fileSchema; + protected MessageType requestedSchema; + protected ReadSupport readSupport; + + /** + * The total number of rows this RecordReader will eventually read. The sum of the + * rows of all the row groups. + */ + protected long totalRowCount; + + protected ParquetFileReader reader; + + public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext) + throws IOException, InterruptedException { + Configuration configuration = taskAttemptContext.getConfiguration(); + ParquetInputSplit split = (ParquetInputSplit)inputSplit; + this.file = split.getPath(); + long[] rowGroupOffsets = split.getRowGroupOffsets(); + + ParquetMetadata footer; + List blocks; + + // if task.side.metadata is set, rowGroupOffsets is null + if (rowGroupOffsets == null) { + // then we need to apply the predicate push down filter + footer = readFooter(configuration, file, range(split.getStart(), split.getEnd())); + MessageType fileSchema = footer.getFileMetaData().getSchema(); + FilterCompat.Filter filter = getFilter(configuration); + blocks = filterRowGroups(filter, footer.getBlocks(), fileSchema); + } else { + // otherwise we find the row groups that were selected on the client + footer = readFooter(configuration, file, NO_FILTER); + Set offsets = new HashSet<>(); + for (long offset : rowGroupOffsets) { + offsets.add(offset); + } + blocks = new ArrayList<>(); + for (BlockMetaData block : footer.getBlocks()) { + if (offsets.contains(block.getStartingPos())) { + blocks.add(block); + } + } + // verify we found them all + if (blocks.size() != rowGroupOffsets.length) { + long[] foundRowGroupOffsets = new long[footer.getBlocks().size()]; + for (int i = 0; i < foundRowGroupOffsets.length; i++) { + foundRowGroupOffsets[i] = footer.getBlocks().get(i).getStartingPos(); + } + // this should never happen. + // provide a good error message in case there's a bug + throw new IllegalStateException( + "All the offsets listed in the split should be found in the file." + + " expected: " + Arrays.toString(rowGroupOffsets) + + " found: " + blocks + + " out of: " + Arrays.toString(foundRowGroupOffsets) + + " in range " + split.getStart() + ", " + split.getEnd()); + } + } + MessageType fileSchema = footer.getFileMetaData().getSchema(); + Map fileMetadata = footer.getFileMetaData().getKeyValueMetaData(); + this.readSupport = getReadSupportInstance( + (Class>) getReadSupportClass(configuration)); + ReadSupport.ReadContext readContext = readSupport.init(new InitContext( + taskAttemptContext.getConfiguration(), toSetMultiMap(fileMetadata), fileSchema)); + this.requestedSchema = readContext.getRequestedSchema(); + this.fileSchema = fileSchema; + this.reader = new ParquetFileReader(configuration, file, blocks, requestedSchema.getColumns()); + for (BlockMetaData block : blocks) { + this.totalRowCount += block.getRowCount(); + } + } + + @Override + public Void getCurrentKey() throws IOException, InterruptedException { + return null; + } + + @Override + public void close() throws IOException { + if (reader != null) { + reader.close(); + reader = null; + } + } + + /** + * Utility classes to abstract over different way to read ints with different encodings. + * TODO: remove this layer of abstraction? 
+ */ + abstract static class IntIterator { + abstract int nextInt() throws IOException; + } + + protected static final class ValuesReaderIntIterator extends IntIterator { + ValuesReader delegate; + + public ValuesReaderIntIterator(ValuesReader delegate) { + this.delegate = delegate; + } + + @Override + int nextInt() throws IOException { + return delegate.readInteger(); + } + } + + protected static final class RLEIntIterator extends IntIterator { + RunLengthBitPackingHybridDecoder delegate; + + public RLEIntIterator(RunLengthBitPackingHybridDecoder delegate) { + this.delegate = delegate; + } + + @Override + int nextInt() throws IOException { + return delegate.readInt(); + } + } + + protected static final class NullIntIterator extends IntIterator { + @Override + int nextInt() throws IOException { return 0; } + } + + /** + * Creates a reader for definition and repetition levels, returning an optimized one if + * the levels are not needed. + */ + static protected IntIterator createRLEIterator(int maxLevel, BytesInput bytes, + ColumnDescriptor descriptor) throws IOException { + try { + if (maxLevel == 0) return new NullIntIterator(); + return new RLEIntIterator( + new RunLengthBitPackingHybridDecoder( + BytesUtils.getWidthFromMaxInt(maxLevel), + new ByteArrayInputStream(bytes.toByteArray()))); + } catch (IOException e) { + throw new IOException("could not read levels in page for col " + descriptor, e); + } + } + + private static Map> toSetMultiMap(Map map) { + Map> setMultiMap = new HashMap<>(); + for (Map.Entry entry : map.entrySet()) { + Set set = new HashSet<>(); + set.add(entry.getValue()); + setMultiMap.put(entry.getKey(), Collections.unmodifiableSet(set)); + } + return Collections.unmodifiableMap(setMultiMap); + } + + private static Class getReadSupportClass(Configuration configuration) { + return ConfigurationUtil.getClassFromConfig(configuration, + ParquetInputFormat.READ_SUPPORT_CLASS, ReadSupport.class); + } + + /** + * @param readSupportClass to instantiate + * @return the configured read support + */ + private static ReadSupport getReadSupportInstance( + Class> readSupportClass){ + try { + return readSupportClass.newInstance(); + } catch (InstantiationException e) { + throw new BadConfigurationException("could not instantiate read support class", e); + } catch (IllegalAccessException e) { + throw new BadConfigurationException("could not instantiate read support class", e); + } + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java new file mode 100644 index 0000000000000..8a92e489ccb7c --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java @@ -0,0 +1,593 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.List; + +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder; +import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.types.UTF8String; + +import static org.apache.parquet.column.ValuesType.DEFINITION_LEVEL; +import static org.apache.parquet.column.ValuesType.REPETITION_LEVEL; +import static org.apache.parquet.column.ValuesType.VALUES; + +import org.apache.hadoop.mapreduce.InputSplit; +import org.apache.hadoop.mapreduce.TaskAttemptContext; +import org.apache.parquet.Preconditions; +import org.apache.parquet.column.ColumnDescriptor; +import org.apache.parquet.column.Dictionary; +import org.apache.parquet.column.Encoding; +import org.apache.parquet.column.page.DataPage; +import org.apache.parquet.column.page.DataPageV1; +import org.apache.parquet.column.page.DataPageV2; +import org.apache.parquet.column.page.DictionaryPage; +import org.apache.parquet.column.page.PageReadStore; +import org.apache.parquet.column.page.PageReader; +import org.apache.parquet.column.values.ValuesReader; +import org.apache.parquet.io.api.Binary; +import org.apache.parquet.schema.OriginalType; +import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Type; + +/** + * A specialized RecordReader that reads into UnsafeRows directly using the Parquet column APIs. + * + * This is somewhat based on parquet-mr's ColumnReader. + * + * TODO: handle complex types, decimal requiring more than 8 bytes, INT96. Schema mismatch. + * All of these can be handled efficiently and easily with codegen. + */ +public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBase { + /** + * Batch of unsafe rows that we assemble and the current index we've returned. Everytime this + * batch is used up (batchIdx == numBatched), we populated the batch. + */ + private UnsafeRow[] rows = new UnsafeRow[64]; + private int batchIdx = 0; + private int numBatched = 0; + + /** + * Used to write variable length columns. Same length as `rows`. + */ + private UnsafeRowWriter[] rowWriters = null; + /** + * True if the row contains variable length fields. + */ + private boolean containsVarLenFields; + + /** + * The number of bytes in the fixed length portion of the row. + */ + private int fixedSizeBytes; + + /** + * For each request column, the reader to read this column. + * columnsReaders[i] populated the UnsafeRow's attribute at i. + */ + private ColumnReader[] columnReaders; + + /** + * The number of rows that have been returned. + */ + private long rowsReturned; + + /** + * The number of rows that have been reading, including the current in flight row group. + */ + private long totalCountLoadedSoFar = 0; + + /** + * For each column, the annotated original type. 
+ */ + private OriginalType[] originalTypes; + + /** + * The default size for varlen columns. The row grows as necessary to accommodate the + * largest column. + */ + private static final int DEFAULT_VAR_LEN_SIZE = 32; + + /** + * Implementation of RecordReader API. + */ + @Override + public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext) + throws IOException, InterruptedException { + super.initialize(inputSplit, taskAttemptContext); + + /** + * Check that the requested schema is supported. + */ + if (requestedSchema.getFieldCount() == 0) { + // TODO: what does this mean? + throw new IOException("Empty request schema not supported."); + } + int numVarLenFields = 0; + originalTypes = new OriginalType[requestedSchema.getFieldCount()]; + for (int i = 0; i < requestedSchema.getFieldCount(); ++i) { + Type t = requestedSchema.getFields().get(i); + if (!t.isPrimitive() || t.isRepetition(Type.Repetition.REPEATED)) { + throw new IOException("Complex types not supported."); + } + PrimitiveType primitiveType = t.asPrimitiveType(); + + originalTypes[i] = t.getOriginalType(); + + // TODO: Be extremely cautious in what is supported. Expand this. + if (originalTypes[i] != null && originalTypes[i] != OriginalType.DECIMAL && + originalTypes[i] != OriginalType.UTF8 && originalTypes[i] != OriginalType.DATE) { + throw new IOException("Unsupported type: " + t); + } + if (originalTypes[i] == OriginalType.DECIMAL && + primitiveType.getDecimalMetadata().getPrecision() > + CatalystSchemaConverter.MAX_PRECISION_FOR_INT64()) { + throw new IOException("Decimal with high precision is not supported."); + } + if (primitiveType.getPrimitiveTypeName() == PrimitiveType.PrimitiveTypeName.INT96) { + throw new IOException("Int96 not supported."); + } + ColumnDescriptor fd = fileSchema.getColumnDescription(requestedSchema.getPaths().get(i)); + if (!fd.equals(requestedSchema.getColumns().get(i))) { + throw new IOException("Schema evolution not supported."); + } + + if (primitiveType.getPrimitiveTypeName() == PrimitiveType.PrimitiveTypeName.BINARY) { + ++numVarLenFields; + } + } + + /** + * Initialize rows and rowWriters. These objects are reused across all rows in the relation. + */ + int rowByteSize = UnsafeRow.calculateBitSetWidthInBytes(requestedSchema.getFieldCount()); + rowByteSize += 8 * requestedSchema.getFieldCount(); + fixedSizeBytes = rowByteSize; + rowByteSize += numVarLenFields * DEFAULT_VAR_LEN_SIZE; + containsVarLenFields = numVarLenFields > 0; + rowWriters = new UnsafeRowWriter[rows.length]; + + for (int i = 0; i < rows.length; ++i) { + rows[i] = new UnsafeRow(); + rowWriters[i] = new UnsafeRowWriter(); + BufferHolder holder = new BufferHolder(rowByteSize); + rowWriters[i].initialize(rows[i], holder, requestedSchema.getFieldCount()); + rows[i].pointTo(holder.buffer, Platform.BYTE_ARRAY_OFFSET, requestedSchema.getFieldCount(), + holder.buffer.length); + } + } + + @Override + public boolean nextKeyValue() throws IOException, InterruptedException { + if (batchIdx >= numBatched) { + if (!loadBatch()) return false; + } + ++batchIdx; + return true; + } + + @Override + public UnsafeRow getCurrentValue() throws IOException, InterruptedException { + return rows[batchIdx - 1]; + } + + @Override + public float getProgress() throws IOException, InterruptedException { + return (float) rowsReturned / totalRowCount; + } + + /** + * Decodes a batch of values into `rows`. This function is the hot path. 
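From the caller's side the reader is driven through the standard Hadoop RecordReader protocol; a hypothetical loop (the `split` and `hadoopAttemptContext` values are assumed to come from the enclosing RDD, as in SqlNewHadoopRDD above) looks like:
```
val reader = new UnsafeRowParquetRecordReader()
reader.initialize(split, hadoopAttemptContext)
try {
  while (reader.nextKeyValue()) {
    val row = reader.getCurrentValue
    // consume `row` here; its backing buffers are reused once the next batch is loaded
  }
} finally {
  reader.close()
}
```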
+ */ + private boolean loadBatch() throws IOException { + // no more records left + if (rowsReturned >= totalRowCount) { return false; } + checkEndOfRowGroup(); + + int num = (int)Math.min(rows.length, totalCountLoadedSoFar - rowsReturned); + rowsReturned += num; + + if (containsVarLenFields) { + for (int i = 0; i < rowWriters.length; ++i) { + rowWriters[i].holder().resetTo(fixedSizeBytes); + } + } + + for (int i = 0; i < columnReaders.length; ++i) { + switch (columnReaders[i].descriptor.getType()) { + case BOOLEAN: + decodeBooleanBatch(i, num); + break; + case INT32: + if (originalTypes[i] == OriginalType.DECIMAL) { + decodeIntAsDecimalBatch(i, num); + } else { + decodeIntBatch(i, num); + } + break; + case INT64: + Preconditions.checkState(originalTypes[i] == null + || originalTypes[i] == OriginalType.DECIMAL, + "Unexpected original type: " + originalTypes[i]); + decodeLongBatch(i, num); + break; + case FLOAT: + decodeFloatBatch(i, num); + break; + case DOUBLE: + decodeDoubleBatch(i, num); + break; + case BINARY: + decodeBinaryBatch(i, num); + break; + case FIXED_LEN_BYTE_ARRAY: + Preconditions.checkState(originalTypes[i] == OriginalType.DECIMAL, + "Unexpected original type: " + originalTypes[i]); + decodeFixedLenArrayAsDecimalBatch(i, num); + break; + case INT96: + throw new IOException("Unsupported " + columnReaders[i].descriptor.getType()); + } + numBatched = num; + batchIdx = 0; + } + return true; + } + + private void decodeBooleanBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + rows[n].setBoolean(col, columnReaders[col].nextBoolean()); + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeIntBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + rows[n].setInt(col, columnReaders[col].nextInt()); + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeIntAsDecimalBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + // Since this is stored as an INT, it is always a compact decimal. Just set it as a long. 
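+          // For example (illustrative value), a DECIMAL(9, 2) column holding 12.34 arrives
+          // here as the unscaled INT32 value 1234, and Decimal(1234, 9, 2) rebuilds 12.34.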
+ rows[n].setLong(col, columnReaders[col].nextInt()); + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeLongBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + rows[n].setLong(col, columnReaders[col].nextLong()); + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeFloatBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + rows[n].setFloat(col, columnReaders[col].nextFloat()); + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeDoubleBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + rows[n].setDouble(col, columnReaders[col].nextDouble()); + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeBinaryBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + ByteBuffer bytes = columnReaders[col].nextBinary().toByteBuffer(); + int len = bytes.limit() - bytes.position(); + if (originalTypes[col] == OriginalType.UTF8) { + UTF8String str = UTF8String.fromBytes(bytes.array(), bytes.position(), len); + rowWriters[n].write(col, str); + } else { + rowWriters[n].write(col, bytes.array(), bytes.position(), len); + } + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeFixedLenArrayAsDecimalBatch(int col, int num) throws IOException { + PrimitiveType type = requestedSchema.getFields().get(col).asPrimitiveType(); + int precision = type.getDecimalMetadata().getPrecision(); + int scale = type.getDecimalMetadata().getScale(); + Preconditions.checkState(precision <= CatalystSchemaConverter.MAX_PRECISION_FOR_INT64(), + "Unsupported precision."); + + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + Binary v = columnReaders[col].nextBinary(); + // Constructs a `Decimal` with an unscaled `Long` value if possible. + long unscaled = CatalystRowConverter.binaryToUnscaledLong(v); + rows[n].setDecimal(col, Decimal.apply(unscaled, precision, scale), precision); + } else { + rows[n].setNullAt(col); + } + } + } + + /** + * + * Decoder to return values from a single column. + */ + private static final class ColumnReader { + /** + * Total number of values read. + */ + private long valuesRead; + + /** + * value that indicates the end of the current page. That is, + * if valuesRead == endOfPageValueCount, we are at the end of the page. + */ + private long endOfPageValueCount; + + /** + * The dictionary, if this column has dictionary encoding. + */ + private final Dictionary dictionary; + + /** + * If true, the current page is dictionary encoded. + */ + private boolean useDictionary; + + /** + * Maximum definition level for this column. + */ + private final int maxDefLevel; + + /** + * Repetition/Definition/Value readers. + */ + private IntIterator repetitionLevelColumn; + private IntIterator definitionLevelColumn; + private ValuesReader dataColumn; + + /** + * Total number of values in this column (in this row group). + */ + private final long totalValueCount; + + /** + * Total values in the current page. 
+ */ + private int pageValueCount; + + private final PageReader pageReader; + private final ColumnDescriptor descriptor; + + public ColumnReader(ColumnDescriptor descriptor, PageReader pageReader) + throws IOException { + this.descriptor = descriptor; + this.pageReader = pageReader; + this.maxDefLevel = descriptor.getMaxDefinitionLevel(); + + DictionaryPage dictionaryPage = pageReader.readDictionaryPage(); + if (dictionaryPage != null) { + try { + this.dictionary = dictionaryPage.getEncoding().initDictionary(descriptor, dictionaryPage); + this.useDictionary = true; + } catch (IOException e) { + throw new IOException("could not decode the dictionary for " + descriptor, e); + } + } else { + this.dictionary = null; + this.useDictionary = false; + } + this.totalValueCount = pageReader.getTotalValueCount(); + if (totalValueCount == 0) { + throw new IOException("totalValueCount == 0"); + } + } + + /** + * TODO: Hoist the useDictionary branch to decode*Batch and make the batch page aligned. + */ + public boolean nextBoolean() { + if (!useDictionary) { + return dataColumn.readBoolean(); + } else { + return dictionary.decodeToBoolean(dataColumn.readValueDictionaryId()); + } + } + + public int nextInt() { + if (!useDictionary) { + return dataColumn.readInteger(); + } else { + return dictionary.decodeToInt(dataColumn.readValueDictionaryId()); + } + } + + public long nextLong() { + if (!useDictionary) { + return dataColumn.readLong(); + } else { + return dictionary.decodeToLong(dataColumn.readValueDictionaryId()); + } + } + + public float nextFloat() { + if (!useDictionary) { + return dataColumn.readFloat(); + } else { + return dictionary.decodeToFloat(dataColumn.readValueDictionaryId()); + } + } + + public double nextDouble() { + if (!useDictionary) { + return dataColumn.readDouble(); + } else { + return dictionary.decodeToDouble(dataColumn.readValueDictionaryId()); + } + } + + public Binary nextBinary() { + if (!useDictionary) { + return dataColumn.readBytes(); + } else { + return dictionary.decodeToBinary(dataColumn.readValueDictionaryId()); + } + } + + /** + * Advances to the next value. Returns true if the value is non-null. + */ + private boolean next() throws IOException { + if (valuesRead >= endOfPageValueCount) { + if (valuesRead >= totalValueCount) { + // How do we get here? Throw end of stream exception? + return false; + } + readPage(); + } + ++valuesRead; + // TODO: Don't read for flat schemas + //repetitionLevel = repetitionLevelColumn.nextInt(); + return definitionLevelColumn.nextInt() == maxDefLevel; + } + + private void readPage() throws IOException { + DataPage page = pageReader.readPage(); + // TODO: Why is this a visitor? 
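+      // DataPage.accept dispatches on the concrete page type, so v1 and v2 pages each
+      // reach their own read path below without an instanceof cascade.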
+ page.accept(new DataPage.Visitor() { + @Override + public Void visit(DataPageV1 dataPageV1) { + try { + readPageV1(dataPageV1); + return null; + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public Void visit(DataPageV2 dataPageV2) { + try { + readPageV2(dataPageV2); + return null; + } catch (IOException e) { + throw new RuntimeException(e); + } + } + }); + } + + private void initDataReader(Encoding dataEncoding, byte[] bytes, int offset, int valueCount) + throws IOException { + this.pageValueCount = valueCount; + this.endOfPageValueCount = valuesRead + pageValueCount; + if (dataEncoding.usesDictionary()) { + if (dictionary == null) { + throw new IOException( + "could not read page in col " + descriptor + + " as the dictionary was missing for encoding " + dataEncoding); + } + this.dataColumn = dataEncoding.getDictionaryBasedValuesReader( + descriptor, VALUES, dictionary); + this.useDictionary = true; + } else { + this.dataColumn = dataEncoding.getValuesReader(descriptor, VALUES); + this.useDictionary = false; + } + + try { + dataColumn.initFromPage(pageValueCount, bytes, offset); + } catch (IOException e) { + throw new IOException("could not read page in col " + descriptor, e); + } + } + + private void readPageV1(DataPageV1 page) throws IOException { + ValuesReader rlReader = page.getRlEncoding().getValuesReader(descriptor, REPETITION_LEVEL); + ValuesReader dlReader = page.getDlEncoding().getValuesReader(descriptor, DEFINITION_LEVEL); + this.repetitionLevelColumn = new ValuesReaderIntIterator(rlReader); + this.definitionLevelColumn = new ValuesReaderIntIterator(dlReader); + try { + byte[] bytes = page.getBytes().toByteArray(); + rlReader.initFromPage(pageValueCount, bytes, 0); + int next = rlReader.getNextOffset(); + dlReader.initFromPage(pageValueCount, bytes, next); + next = dlReader.getNextOffset(); + initDataReader(page.getValueEncoding(), bytes, next, page.getValueCount()); + } catch (IOException e) { + throw new IOException("could not read page " + page + " in col " + descriptor, e); + } + } + + private void readPageV2(DataPageV2 page) throws IOException { + this.repetitionLevelColumn = createRLEIterator(descriptor.getMaxRepetitionLevel(), + page.getRepetitionLevels(), descriptor); + this.definitionLevelColumn = createRLEIterator(descriptor.getMaxDefinitionLevel(), + page.getDefinitionLevels(), descriptor); + try { + initDataReader(page.getDataEncoding(), page.getData().toByteArray(), 0, + page.getValueCount()); + } catch (IOException e) { + throw new IOException("could not read page " + page + " in col " + descriptor, e); + } + } + } + + private void checkEndOfRowGroup() throws IOException { + if (rowsReturned != totalCountLoadedSoFar) return; + PageReadStore pages = reader.readNextRowGroup(); + if (pages == null) { + throw new IOException("expecting more rows but reached last block. 
Read " + + rowsReturned + " out of " + totalRowCount); + } + List columns = requestedSchema.getColumns(); + columnReaders = new ColumnReader[columns.size()]; + for (int i = 0; i < columns.size(); ++i) { + columnReaders[i] = new ColumnReader(columns.get(i), pages.getPageReader(columns.get(i))); + } + totalCountLoadedSoFar += pages.getRowCount(); + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala index 1f653cd3d3cb1..94298fae2d69b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala @@ -370,35 +370,13 @@ private[parquet] class CatalystRowConverter( protected def decimalFromBinary(value: Binary): Decimal = { if (precision <= CatalystSchemaConverter.MAX_PRECISION_FOR_INT64) { // Constructs a `Decimal` with an unscaled `Long` value if possible. - val unscaled = binaryToUnscaledLong(value) + val unscaled = CatalystRowConverter.binaryToUnscaledLong(value) Decimal(unscaled, precision, scale) } else { // Otherwise, resorts to an unscaled `BigInteger` instead. Decimal(new BigDecimal(new BigInteger(value.getBytes), scale), precision, scale) } } - - private def binaryToUnscaledLong(binary: Binary): Long = { - // The underlying `ByteBuffer` implementation is guaranteed to be `HeapByteBuffer`, so here - // we are using `Binary.toByteBuffer.array()` to steal the underlying byte array without - // copying it. - val buffer = binary.toByteBuffer - val bytes = buffer.array() - val start = buffer.position() - val end = buffer.limit() - - var unscaled = 0L - var i = start - - while (i < end) { - unscaled = (unscaled << 8) | (bytes(i) & 0xff) - i += 1 - } - - val bits = 8 * (end - start) - unscaled = (unscaled << (64 - bits)) >> (64 - bits) - unscaled - } } private class CatalystIntDictionaryAwareDecimalConverter( @@ -658,3 +636,27 @@ private[parquet] class CatalystRowConverter( override def start(): Unit = elementConverter.start() } } + +private[parquet] object CatalystRowConverter { + def binaryToUnscaledLong(binary: Binary): Long = { + // The underlying `ByteBuffer` implementation is guaranteed to be `HeapByteBuffer`, so here + // we are using `Binary.toByteBuffer.array()` to steal the underlying byte array without + // copying it. 
+ val buffer = binary.toByteBuffer + val bytes = buffer.array() + val start = buffer.position() + val end = buffer.limit() + + var unscaled = 0L + var i = start + + while (i < end) { + unscaled = (unscaled << 8) | (bytes(i) & 0xff) + i += 1 + } + + val bits = 8 * (end - start) + unscaled = (unscaled << (64 - bits)) >> (64 - bits) + unscaled + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 458786f77af3f..c8028a5ef5528 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -337,7 +337,9 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } - test("SPARK-11661 Still pushdown filters returned by unhandledFilters") { + // Renable when we can toggle custom ParquetRecordReader on/off. The custom reader does + // not do row by row filtering (and we probably don't want to push that). + ignore("SPARK-11661 Still pushdown filters returned by unhandledFilters") { import testImplicits._ withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { withTempPath { dir => From 9c0654d36c6d171dd273850c2cc2f415cc2a5a6b Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 18 Nov 2015 18:41:40 -0800 Subject: [PATCH 126/149] Revert "[SPARK-11544][SQL] sqlContext doesn't use PathFilter" This reverts commit 54db79702513e11335c33bcf3a03c59e965e6f16. --- .../apache/spark/sql/sources/interfaces.scala | 25 +++---------- .../datasources/json/JsonSuite.scala | 36 ++----------------- 2 files changed, 7 insertions(+), 54 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index f9465157c936d..b3d3bdf50df63 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -21,8 +21,7 @@ import scala.collection.mutable import scala.util.Try import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{PathFilter, FileStatus, FileSystem, Path} -import org.apache.hadoop.mapred.{JobConf, FileInputFormat} +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.{Logging, SparkContext} @@ -448,15 +447,9 @@ abstract class HadoopFsRelation private[sql]( val hdfsPath = new Path(path) val fs = hdfsPath.getFileSystem(hadoopConf) val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + logInfo(s"Listing $qualified on driver") - // Dummy jobconf to get to the pathFilter defined in configuration - val jobConf = new JobConf(hadoopConf, this.getClass()) - val pathFilter = FileInputFormat.getInputPathFilter(jobConf) - if (pathFilter != null) { - Try(fs.listStatus(qualified, pathFilter)).getOrElse(Array.empty) - } else { - Try(fs.listStatus(qualified)).getOrElse(Array.empty) - } + Try(fs.listStatus(qualified)).getOrElse(Array.empty) }.filterNot { status => val name = status.getPath.getName name.toLowerCase == "_temporary" || name.startsWith(".") @@ -854,16 +847,8 @@ private[sql] object HadoopFsRelation extends Logging { if (name == "_temporary" || name.startsWith(".")) { Array.empty } else { - // Dummy jobconf to get to the pathFilter 
defined in configuration - val jobConf = new JobConf(fs.getConf, this.getClass()) - val pathFilter = FileInputFormat.getInputPathFilter(jobConf) - if (pathFilter != null) { - val (dirs, files) = fs.listStatus(status.getPath, pathFilter).partition(_.isDir) - files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) - } else { - val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDir) - files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) - } + val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDir) + files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index f09b61e838159..6042b1178affe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -19,27 +19,19 @@ package org.apache.spark.sql.execution.datasources.json import java.io.{File, StringWriter} import java.sql.{Date, Timestamp} -import scala.collection.JavaConverters._ import com.fasterxml.jackson.core.JsonFactory -import org.apache.commons.io.FileUtils -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{Path, PathFilter} +import org.apache.spark.rdd.RDD import org.scalactic.Tolerance._ -import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} +import org.apache.spark.sql.execution.datasources.{ResolvedDataSource, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.InferSchema.compatibleType import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class TestFileFilter extends PathFilter { - override def accept(path: Path): Boolean = path.getParent.getName != "p=2" -} - class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { import testImplicits._ @@ -1398,28 +1390,4 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) } } - - test("SPARK-11544 test pathfilter") { - withTempPath { dir => - val path = dir.getCanonicalPath - - val df = sqlContext.range(2) - df.write.json(path + "/p=1") - df.write.json(path + "/p=2") - assert(sqlContext.read.json(path).count() === 4) - - val clonedConf = new Configuration(hadoopConfiguration) - try { - hadoopConfiguration.setClass( - "mapreduce.input.pathFilter.class", - classOf[TestFileFilter], - classOf[PathFilter]) - assert(sqlContext.read.json(path).count() === 2) - } finally { - // Hadoop 1 doesn't have `Configuration.unset` - hadoopConfiguration.clear() - clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) - } - } - } } From 67c75828ff4df2e305bdf5d6be5a11201d1da3f3 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Wed, 18 Nov 2015 18:49:46 -0800 Subject: [PATCH 127/149] [SPARK-11816][ML] fix some style issue in ML/MLlib examples jira: https://issues.apache.org/jira/browse/SPARK-11816 Currently I only fixed some obvious comments issue like // scalastyle:off println on the bottom. Yet the style in examples is not quite consistent, like only half of the examples are with // Example usage: ./bin/run-example mllib.FPGrowthExample \, Author: Yuhao Yang Closes #9808 from hhbyyh/exampleStyle. 
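The convention these fixes enforce, shown as a standalone sketch (the object name is made up; only the two scalastyle markers come from the patch):
```
// scalastyle:off println
package org.apache.spark.examples.ml

object SomeExample {
  def main(args: Array[String]): Unit = {
    println("println is only allowed inside the off/on window")
  }
}
// scalastyle:on println
```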
--- .../java/org/apache/spark/examples/ml/JavaKMeansExample.java | 2 +- .../apache/spark/examples/ml/AFTSurvivalRegressionExample.scala | 2 +- .../spark/examples/ml/DecisionTreeClassificationExample.scala | 1 + .../spark/examples/ml/DecisionTreeRegressionExample.scala | 1 + .../examples/ml/MultilayerPerceptronClassifierExample.scala | 2 +- 5 files changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java index be2bf0c7b465c..47665ff2b1f3c 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java @@ -41,7 +41,7 @@ * An example demonstrating a k-means clustering. * Run with *
- * bin/run-example ml.JavaSimpleParamsExample
+ * bin/run-example ml.JavaKMeansExample
 *
    */ public class JavaKMeansExample { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala index 5da285e83681f..f4b3613ccb94f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala @@ -59,4 +59,4 @@ object AFTSurvivalRegressionExample { sc.stop() } } -// scalastyle:off println +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala index ff8a0a90f1e44..db024b5cad935 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala @@ -90,3 +90,4 @@ object DecisionTreeClassificationExample { // $example off$ } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala index fc402724d2156..ad01f55df72b5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala @@ -78,3 +78,4 @@ object DecisionTreeRegressionExample { // $example off$ } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala index 146b83c8be490..9c98076bd24b1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala @@ -66,4 +66,4 @@ object MultilayerPerceptronClassifierExample { sc.stop() } } -// scalastyle:off println +// scalastyle:on println From fc3f77b42d62ca789d0ee07403795978961991c7 Mon Sep 17 00:00:00 2001 From: "navis.ryu" Date: Wed, 18 Nov 2015 19:37:14 -0800 Subject: [PATCH 128/149] [SPARK-11614][SQL] serde parameters should be set only when all params are ready see HIVE-7975 and HIVE-12373 With changed semantic of setters in thrift objects in hive, setter should be called only after all parameters are set. It's not problem of current state but will be a problem in some day. Author: navis.ryu Closes #9580 from navis/SPARK-11614. 
--- .../scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index f4d45714fae4e..9a981d02ad67c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -804,12 +804,13 @@ private[hive] case class MetastoreRelation val serdeInfo = new org.apache.hadoop.hive.metastore.api.SerDeInfo sd.setSerdeInfo(serdeInfo) + // maps and lists should be set only after all elements are ready (see HIVE-7975) serdeInfo.setSerializationLib(p.storage.serde) val serdeParameters = new java.util.HashMap[String, String]() - serdeInfo.setParameters(serdeParameters) table.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } p.storage.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } + serdeInfo.setParameters(serdeParameters) new Partition(hiveQlTable, tPartition) } From d02d5b9295b169c3ebb0967453b2835edb8a121f Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 18 Nov 2015 21:44:01 -0800 Subject: [PATCH 129/149] [SPARK-11842][ML] Small cleanups to existing Readers and Writers Updates: * Add repartition(1) to save() methods' saving of data for LogisticRegressionModel, LinearRegressionModel. * Strengthen privacy to class and companion object for Writers and Readers * Change LogisticRegressionSuite read/write test to fit intercept * Add Since versions for read/write methods in Pipeline, LogisticRegression * Switch from hand-written class names in Readers to using getClass CC: mengxr CC: yanboliang Would you mind taking a look at this PR? mengxr might not be able to soon. Thank you! Author: Joseph K. Bradley Closes #9829 from jkbradley/ml-io-cleanups. 
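One of the idioms the cleanup standardizes on, isolated here for reference (PipelineModel is just one of the affected classes; the snippet is illustrative, not part of the diff):
```
import org.apache.spark.ml.PipelineModel

// Derive the class-name check from the class itself instead of a hand-written string:
val className = classOf[PipelineModel].getName   // "org.apache.spark.ml.PipelineModel"
```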
--- .../scala/org/apache/spark/ml/Pipeline.scala | 22 +++++++++++++------ .../classification/LogisticRegression.scala | 19 ++++++++++------ .../spark/ml/feature/CountVectorizer.scala | 2 +- .../org/apache/spark/ml/feature/IDF.scala | 2 +- .../spark/ml/feature/MinMaxScaler.scala | 2 +- .../spark/ml/feature/StandardScaler.scala | 2 +- .../spark/ml/feature/StringIndexer.scala | 2 +- .../apache/spark/ml/recommendation/ALS.scala | 6 ++--- .../ml/regression/LinearRegression.scala | 4 ++-- .../LogisticRegressionSuite.scala | 2 +- 10 files changed, 38 insertions(+), 25 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index b0f22e042ec56..6f15b37abcb30 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -27,7 +27,7 @@ import org.json4s._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{SparkContext, Logging} -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{Since, DeveloperApi, Experimental} import org.apache.spark.ml.param.{Param, ParamMap, Params} import org.apache.spark.ml.util.MLReader import org.apache.spark.ml.util.MLWriter @@ -174,16 +174,20 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] with M theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur)) } + @Since("1.6.0") override def write: MLWriter = new Pipeline.PipelineWriter(this) } +@Since("1.6.0") object Pipeline extends MLReadable[Pipeline] { + @Since("1.6.0") override def read: MLReader[Pipeline] = new PipelineReader + @Since("1.6.0") override def load(path: String): Pipeline = super.load(path) - private[ml] class PipelineWriter(instance: Pipeline) extends MLWriter { + private[Pipeline] class PipelineWriter(instance: Pipeline) extends MLWriter { SharedReadWrite.validateStages(instance.getStages) @@ -191,10 +195,10 @@ object Pipeline extends MLReadable[Pipeline] { SharedReadWrite.saveImpl(instance, instance.getStages, sc, path) } - private[ml] class PipelineReader extends MLReader[Pipeline] { + private class PipelineReader extends MLReader[Pipeline] { /** Checked against metadata when loading model */ - private val className = "org.apache.spark.ml.Pipeline" + private val className = classOf[Pipeline].getName override def load(path: String): Pipeline = { val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path) @@ -333,18 +337,22 @@ class PipelineModel private[ml] ( new PipelineModel(uid, stages.map(_.copy(extra))).setParent(parent) } + @Since("1.6.0") override def write: MLWriter = new PipelineModel.PipelineModelWriter(this) } +@Since("1.6.0") object PipelineModel extends MLReadable[PipelineModel] { import Pipeline.SharedReadWrite + @Since("1.6.0") override def read: MLReader[PipelineModel] = new PipelineModelReader + @Since("1.6.0") override def load(path: String): PipelineModel = super.load(path) - private[ml] class PipelineModelWriter(instance: PipelineModel) extends MLWriter { + private[PipelineModel] class PipelineModelWriter(instance: PipelineModel) extends MLWriter { SharedReadWrite.validateStages(instance.stages.asInstanceOf[Array[PipelineStage]]) @@ -352,10 +360,10 @@ object PipelineModel extends MLReadable[PipelineModel] { instance.stages.asInstanceOf[Array[PipelineStage]], sc, path) } - private[ml] class PipelineModelReader extends MLReader[PipelineModel] { + private class PipelineModelReader extends 
MLReader[PipelineModel] { /** Checked against metadata when loading model */ - private val className = "org.apache.spark.ml.PipelineModel" + private val className = classOf[PipelineModel].getName override def load(path: String): PipelineModel = { val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index a3cc49f7f018c..418bbdc9a058f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -24,7 +24,7 @@ import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, import org.apache.hadoop.fs.Path import org.apache.spark.{Logging, SparkException} -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -525,18 +525,23 @@ class LogisticRegressionModel private[ml] ( * * This also does not save the [[parent]] currently. */ + @Since("1.6.0") override def write: MLWriter = new LogisticRegressionModel.LogisticRegressionModelWriter(this) } +@Since("1.6.0") object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] { + @Since("1.6.0") override def read: MLReader[LogisticRegressionModel] = new LogisticRegressionModelReader + @Since("1.6.0") override def load(path: String): LogisticRegressionModel = super.load(path) /** [[MLWriter]] instance for [[LogisticRegressionModel]] */ - private[classification] class LogisticRegressionModelWriter(instance: LogisticRegressionModel) + private[LogisticRegressionModel] + class LogisticRegressionModelWriter(instance: LogisticRegressionModel) extends MLWriter with Logging { private case class Data( @@ -552,15 +557,15 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] { val data = Data(instance.numClasses, instance.numFeatures, instance.intercept, instance.coefficients) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).write.format("parquet").save(dataPath) + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } - private[classification] class LogisticRegressionModelReader + private class LogisticRegressionModelReader extends MLReader[LogisticRegressionModel] { /** Checked against metadata when loading model */ - private val className = "org.apache.spark.ml.classification.LogisticRegressionModel" + private val className = classOf[LogisticRegressionModel].getName override def load(path: String): LogisticRegressionModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) @@ -603,7 +608,7 @@ private[classification] class MultiClassSummarizer extends Serializable { * @return This MultilabelSummarizer */ def add(label: Double, weight: Double = 1.0): this.type = { - require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0") + require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0") if (weight == 0.0) return this @@ -839,7 +844,7 @@ private class LogisticAggregator( instance match { case Instance(label, weight, features) => require(dim == features.size, s"Dimensions mismatch when adding new instance." 
+ s" Expecting $dim but got ${features.size}.") - require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0") + require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0") if (weight == 0.0) return this diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 4969cf42450d2..b9e2144c0ad40 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -266,7 +266,7 @@ object CountVectorizerModel extends MLReadable[CountVectorizerModel] { private class CountVectorizerModelReader extends MLReader[CountVectorizerModel] { - private val className = "org.apache.spark.ml.feature.CountVectorizerModel" + private val className = classOf[CountVectorizerModel].getName override def load(path: String): CountVectorizerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index 0e00ef6f2ee20..f7b0f29a27c2d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -155,7 +155,7 @@ object IDFModel extends MLReadable[IDFModel] { private class IDFModelReader extends MLReader[IDFModel] { - private val className = "org.apache.spark.ml.feature.IDFModel" + private val className = classOf[IDFModel].getName override def load(path: String): IDFModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index ed24eabb50444..c2866f5eceff3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -210,7 +210,7 @@ object MinMaxScalerModel extends MLReadable[MinMaxScalerModel] { private class MinMaxScalerModelReader extends MLReader[MinMaxScalerModel] { - private val className = "org.apache.spark.ml.feature.MinMaxScalerModel" + private val className = classOf[MinMaxScalerModel].getName override def load(path: String): MinMaxScalerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 1f689c1da1ba9..6d545219ebf49 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -180,7 +180,7 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] { private class StandardScalerModelReader extends MLReader[StandardScalerModel] { - private val className = "org.apache.spark.ml.feature.StandardScalerModel" + private val className = classOf[StandardScalerModel].getName override def load(path: String): StandardScalerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 97a2e4f6d6ca4..5c40c35eeaa48 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -210,7 +210,7 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] { private class StringIndexerModelReader extends MLReader[StringIndexerModel] { - private val className = "org.apache.spark.ml.feature.StringIndexerModel" + private val className = classOf[StringIndexerModel].getName override def load(path: String): StringIndexerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 795b73c4c2121..4d35177ad9b0f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -237,7 +237,7 @@ object ALSModel extends MLReadable[ALSModel] { @Since("1.6.0") override def load(path: String): ALSModel = super.load(path) - private[recommendation] class ALSModelWriter(instance: ALSModel) extends MLWriter { + private[ALSModel] class ALSModelWriter(instance: ALSModel) extends MLWriter { override protected def saveImpl(path: String): Unit = { val extraMetadata = render("rank" -> instance.rank) @@ -249,10 +249,10 @@ object ALSModel extends MLReadable[ALSModel] { } } - private[recommendation] class ALSModelReader extends MLReader[ALSModel] { + private class ALSModelReader extends MLReader[ALSModel] { /** Checked against metadata when loading model */ - private val className = "org.apache.spark.ml.recommendation.ALSModel" + private val className = classOf[ALSModel].getName override def load(path: String): ALSModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 7ba1a60edaf71..70ccec766c471 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -467,14 +467,14 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] { // Save model data: intercept, coefficients val data = Data(instance.intercept, instance.coefficients) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).write.format("parquet").save(dataPath) + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } private class LinearRegressionModelReader extends MLReader[LinearRegressionModel] { /** Checked against metadata when loading model */ - private val className = "org.apache.spark.ml.regression.LinearRegressionModel" + private val className = classOf[LinearRegressionModel].getName override def load(path: String): LinearRegressionModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 48ce1bb630685..a9a6ff8a783d5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -898,7 +898,7 @@ object LogisticRegressionSuite { "regParam" -> 0.01, "elasticNetParam" -> 0.1, "maxIter" -> 2, // intentionally small - "fitIntercept" -> false, + "fitIntercept" -> true, "tol" -> 0.8, 
"standardization" -> false, "threshold" -> 0.6 From 1a93323c5bab18ed7e55bf6f7b13aae88cb9721c Mon Sep 17 00:00:00 2001 From: felixcheung Date: Wed, 18 Nov 2015 23:32:49 -0800 Subject: [PATCH 130/149] [SPARK-11339][SPARKR] Document the list of functions in R base package that are masked by functions with same name in SparkR Added tests for function that are reported as masked, to make sure the base:: or stats:: function can be called. For those we can't call, added them to SparkR programming guide. It would seem to me `table, sample, subset, filter, cov` not working are not actually expected - I investigated/experimented with them but couldn't get them to work. It looks like as they are defined in base or stats they are missing the S3 generic, eg. ``` > methods("transform") [1] transform,ANY-method transform.data.frame [3] transform,DataFrame-method transform.default see '?methods' for accessing help and source code > methods("subset") [1] subset.data.frame subset,DataFrame-method subset.default [4] subset.matrix see '?methods' for accessing help and source code Warning message: In .S3methods(generic.function, class, parent.frame()) : function 'subset' appears not to be S3 generic; found functions that look like S3 methods ``` Any idea? More information on masking: http://www.ats.ucla.edu/stat/r/faq/referencing_objects.htm http://www.sfu.ca/~sweldon/howTo/guide4.pdf This is what the output doc looks like (minus css): ![image](https://cloud.githubusercontent.com/assets/8969467/11229714/2946e5de-8d4d-11e5-94b0-dda9696b6fdd.png) Author: felixcheung Closes #9785 from felixcheung/rmasked. --- R/pkg/R/DataFrame.R | 2 +- R/pkg/R/functions.R | 2 +- R/pkg/R/generics.R | 4 ++-- R/pkg/inst/tests/test_mllib.R | 5 +++++ R/pkg/inst/tests/test_sparkSQL.R | 33 +++++++++++++++++++++++++++- docs/sparkr.md | 37 +++++++++++++++++++++++++++++++- 6 files changed, 77 insertions(+), 6 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 34177e3cdd94f..06b0108b1389e 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2152,7 +2152,7 @@ setMethod("with", }) #' Returns the column types of a DataFrame. -#' +#' #' @name coltypes #' @title Get column types of a DataFrame #' @family dataframe_funcs diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index ff0f438045c14..25a1f22101494 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -2204,7 +2204,7 @@ setMethod("denseRank", #' @export #' @examples \dontrun{lag(df$c)} setMethod("lag", - signature(x = "characterOrColumn", offset = "numeric", defaultValue = "ANY"), + signature(x = "characterOrColumn"), function(x, offset, defaultValue = NULL) { col <- if (class(x) == "Column") { x@jc diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 0dcd05438222b..71004a05ba611 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -539,7 +539,7 @@ setGeneric("showDF", function(x,...) { standardGeneric("showDF") }) # @rdname subset # @export -setGeneric("subset", function(x, subset, select, ...) { standardGeneric("subset") }) +setGeneric("subset", function(x, ...) { standardGeneric("subset") }) #' @rdname agg #' @export @@ -790,7 +790,7 @@ setGeneric("kurtosis", function(x) { standardGeneric("kurtosis") }) #' @rdname lag #' @export -setGeneric("lag", function(x, offset, defaultValue = NULL) { standardGeneric("lag") }) +setGeneric("lag", function(x, ...) 
{ standardGeneric("lag") }) #' @rdname last #' @export diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R index d497ad8c9daa3..e0667e5e22c18 100644 --- a/R/pkg/inst/tests/test_mllib.R +++ b/R/pkg/inst/tests/test_mllib.R @@ -31,6 +31,11 @@ test_that("glm and predict", { model <- glm(Sepal_Width ~ Sepal_Length, training, family = "gaussian") prediction <- predict(model, test) expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") + + # Test stats::predict is working + x <- rnorm(15) + y <- x + rnorm(15) + expect_equal(length(predict(lm(y ~ x))), 15) }) test_that("glm should work with long formula", { diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index d9a94faff7ac0..3f4f319fe745d 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -433,6 +433,10 @@ test_that("table() returns a new DataFrame", { expect_is(tabledf, "DataFrame") expect_equal(count(tabledf), 3) dropTempTable(sqlContext, "table1") + + # Test base::table is working + #a <- letters[1:3] + #expect_equal(class(table(a, sample(a))), "table") }) test_that("toRDD() returns an RRDD", { @@ -673,6 +677,9 @@ test_that("sample on a DataFrame", { # Also test sample_frac sampled3 <- sample_frac(df, FALSE, 0.1, 0) # set seed for predictable result expect_true(count(sampled3) < 3) + + # Test base::sample is working + #expect_equal(length(sample(1:12)), 12) }) test_that("select operators", { @@ -753,6 +760,9 @@ test_that("subsetting", { df6 <- subset(df, df$age %in% c(30), c(1,2)) expect_equal(count(df6), 1) expect_equal(columns(df6), c("name", "age")) + + # Test base::subset is working + expect_equal(nrow(subset(airquality, Temp > 80, select = c(Ozone, Temp))), 68) }) test_that("selectExpr() on a DataFrame", { @@ -888,6 +898,9 @@ test_that("column functions", { expect_equal(result, list(list(3L, 2L, 1L), list(6L, 5L, 4L))) result <- collect(select(df, sort_array(df[[1]])))[[1]] expect_equal(result, list(list(1L, 2L, 3L), list(4L, 5L, 6L))) + + # Test that stats::lag is working + expect_equal(length(lag(ldeaths, 12)), 72) }) # test_that("column binary mathfunctions", { @@ -1086,7 +1099,7 @@ test_that("group by, agg functions", { gd3_local <- collect(agg(gd3, var(df8$age))) expect_equal(162, gd3_local[gd3_local$name == "Justin",][1, 2]) - # make sure base:: or stats::sd, var are working + # Test stats::sd, stats::var are working expect_true(abs(sd(1:2) - 0.7071068) < 1e-6) expect_true(abs(var(1:5, 1:5) - 2.5) < 1e-6) @@ -1138,6 +1151,9 @@ test_that("filter() on a DataFrame", { expect_equal(count(filtered5), 1) filtered6 <- where(df, df$age %in% c(19, 30)) expect_equal(count(filtered6), 2) + + # Test stats::filter is working + #expect_true(is.ts(filter(1:100, rep(1, 3)))) }) test_that("join() and merge() on a DataFrame", { @@ -1284,6 +1300,12 @@ test_that("unionAll(), rbind(), except(), and intersect() on a DataFrame", { expect_is(unioned, "DataFrame") expect_equal(count(intersected), 1) expect_equal(first(intersected)$name, "Andy") + + # Test base::rbind is working + expect_equal(length(rbind(1:4, c = 2, a = 10, 10, deparse.level = 0)), 16) + + # Test base::intersect is working + expect_equal(length(intersect(1:20, 3:23)), 18) }) test_that("withColumn() and withColumnRenamed()", { @@ -1365,6 +1387,9 @@ test_that("describe() and summarize() on a DataFrame", { stats2 <- summary(df) expect_equal(collect(stats2)[4, "name"], "Andy") expect_equal(collect(stats2)[5, "age"], "30") + + # Test base::summary is working + 
expect_equal(length(summary(attenu, digits = 4)), 35) }) test_that("dropna() and na.omit() on a DataFrame", { @@ -1448,6 +1473,9 @@ test_that("dropna() and na.omit() on a DataFrame", { expect_identical(expected, actual) actual <- collect(na.omit(df, minNonNulls = 3, cols = c("name", "age", "height"))) expect_identical(expected, actual) + + # Test stats::na.omit is working + expect_equal(nrow(na.omit(data.frame(x = c(0, 10, NA)))), 2) }) test_that("fillna() on a DataFrame", { @@ -1510,6 +1538,9 @@ test_that("cov() and corr() on a DataFrame", { expect_true(abs(result - 1.0) < 1e-12) result <- corr(df, "singles", "doubles", "pearson") expect_true(abs(result - 1.0) < 1e-12) + + # Test stats::cov is working + #expect_true(abs(max(cov(swiss)) - 1739.295) < 1e-3) }) test_that("freqItems() on a DataFrame", { diff --git a/docs/sparkr.md b/docs/sparkr.md index a744b76be7466..cfb9b41350f45 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -286,7 +286,7 @@ head(teenagers) # Machine Learning -SparkR allows the fitting of generalized linear models over DataFrames using the [glm()](api/R/glm.html) function. Under the hood, SparkR uses MLlib to train a model of the specified family. Currently the gaussian and binomial families are supported. We support a subset of the available R formula operators for model fitting, including '~', '.', ':', '+', and '-'. +SparkR allows the fitting of generalized linear models over DataFrames using the [glm()](api/R/glm.html) function. Under the hood, SparkR uses MLlib to train a model of the specified family. Currently the gaussian and binomial families are supported. We support a subset of the available R formula operators for model fitting, including '~', '.', ':', '+', and '-'. The [summary()](api/R/summary.html) function gives the summary of a model produced by [glm()](api/R/glm.html). @@ -351,3 +351,38 @@ summary(model) ##Sepal_Width 0.404655 {% endhighlight %} + +# R Function Name Conflicts + +When loading and attaching a new package in R, it is possible to have a name [conflict](https://stat.ethz.ch/R-manual/R-devel/library/base/html/library.html), where a +function is masking another function. + +The following functions are masked by the SparkR package: + + + + + + + + + + + + + + + + + + + +
+<table class="table">
+  <tr><th>Masked function</th><th>How to Access</th></tr>
+  <tr>
+    <td><code>cov</code> in <code>package:stats</code></td>
+    <td><code><pre>stats::cov(x, y = NULL, use = "everything",
+           method = c("pearson", "kendall", "spearman"))</pre></code></td>
+  </tr>
+  <tr>
+    <td><code>filter</code> in <code>package:stats</code></td>
+    <td><code><pre>stats::filter(x, filter, method = c("convolution", "recursive"),
+              sides = 2, circular = FALSE, init)</pre></code></td>
+  </tr>
+  <tr>
+    <td><code>sample</code> in <code>package:base</code></td>
+    <td><code>base::sample(x, size, replace = FALSE, prob = NULL)</code></td>
+  </tr>
+  <tr>
+    <td><code>table</code> in <code>package:base</code></td>
+    <td><code><pre>base::table(...,
+            exclude = if (useNA == "no") c(NA, NaN),
+            useNA = c("no", "ifany", "always"),
+            dnn = list.names(...), deparse.level = 1)</pre></code></td>
+  </tr>
+</table>
    + +You can inspect the search path in R with [`search()`](https://stat.ethz.ch/R-manual/R-devel/library/base/html/search.html) + From f449992009becc8f7c7f06cda522b9beaa1e263c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 19 Nov 2015 10:48:04 -0800 Subject: [PATCH 131/149] [SPARK-11849][SQL] Analyzer should replace current_date and current_timestamp with literals We currently rely on the optimizer's constant folding to replace current_timestamp and current_date. However, this can still result in different values for different instances of current_timestamp/current_date if the optimizer is not running fast enough. A better solution is to replace these functions in the analyzer in one shot. Author: Reynold Xin Closes #9833 from rxin/SPARK-11849. --- .../sql/catalyst/analysis/Analyzer.scala | 27 ++++++++++--- .../sql/catalyst/analysis/AnalysisSuite.scala | 38 +++++++++++++++++++ 2 files changed, 60 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index f00c451b5981a..84781cd57f3dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -65,9 +65,8 @@ class Analyzer( lazy val batches: Seq[Batch] = Seq( Batch("Substitution", fixedPoint, - CTESubstitution :: - WindowsSubstitution :: - Nil : _*), + CTESubstitution, + WindowsSubstitution), Batch("Resolution", fixedPoint, ResolveRelations :: ResolveReferences :: @@ -84,7 +83,8 @@ class Analyzer( HiveTypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*), Batch("Nondeterministic", Once, - PullOutNondeterministic), + PullOutNondeterministic, + ComputeCurrentTime), Batch("UDF", Once, HandleNullInputsForUDF), Batch("Cleanup", fixedPoint, @@ -1076,7 +1076,7 @@ class Analyzer( override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case p if !p.resolved => p // Skip unresolved nodes. - case plan => plan transformExpressionsUp { + case p => p transformExpressionsUp { case udf @ ScalaUDF(func, _, inputs, _) => val parameterTypes = ScalaReflection.getParameterTypes(func) @@ -1162,3 +1162,20 @@ object CleanupAliases extends Rule[LogicalPlan] { } } } + +/** + * Computes the current date and time to make sure we return the same result in a single query. 
+ */ +object ComputeCurrentTime extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = { + val dateExpr = CurrentDate() + val timeExpr = CurrentTimestamp() + val currentDate = Literal.create(dateExpr.eval(EmptyRow), dateExpr.dataType) + val currentTime = Literal.create(timeExpr.eval(EmptyRow), timeExpr.dataType) + + plan transformAllExpressions { + case CurrentDate() => currentDate + case CurrentTimestamp() => currentTime + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 08586a97411ae..e051069951887 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ class AnalysisSuite extends AnalysisTest { @@ -218,4 +219,41 @@ class AnalysisSuite extends AnalysisTest { udf4) // checkUDF(udf4, expected4) } + + test("analyzer should replace current_timestamp with literals") { + val in = Project(Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()), + LocalRelation()) + + val min = System.currentTimeMillis() * 1000 + val plan = in.analyze.asInstanceOf[Project] + val max = (System.currentTimeMillis() + 1) * 1000 + + val lits = new scala.collection.mutable.ArrayBuffer[Long] + plan.transformAllExpressions { case e: Literal => + lits += e.value.asInstanceOf[Long] + e + } + assert(lits.size == 2) + assert(lits(0) >= min && lits(0) <= max) + assert(lits(1) >= min && lits(1) <= max) + assert(lits(0) == lits(1)) + } + + test("analyzer should replace current_date with literals") { + val in = Project(Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")()), LocalRelation()) + + val min = DateTimeUtils.millisToDays(System.currentTimeMillis()) + val plan = in.analyze.asInstanceOf[Project] + val max = DateTimeUtils.millisToDays(System.currentTimeMillis()) + + val lits = new scala.collection.mutable.ArrayBuffer[Int] + plan.transformAllExpressions { case e: Literal => + lits += e.value.asInstanceOf[Int] + e + } + assert(lits.size == 2) + assert(lits(0) >= min && lits(0) <= max) + assert(lits(1) >= min && lits(1) <= max) + assert(lits(0) == lits(1)) + } } From 962878843b611fa6229e3ee67bb22e2a4bc283cd Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 19 Nov 2015 11:02:17 -0800 Subject: [PATCH 132/149] [SPARK-11840][SQL] Restore the 1.5's behavior of planning a single distinct aggregation. The impact of this change is for a query that has a single distinct column and does not have any grouping expression like `SELECT COUNT(DISTINCT a) FROM table` The plan will be changed from ``` AGG-2 (count distinct) Shuffle to a single reducer Partial-AGG-2 (count distinct) AGG-1 (grouping on a) Shuffle by a Partial-AGG-1 (grouping on 1) ``` to the following one (1.5 uses this) ``` AGG-2 AGG-1 (grouping on a) Shuffle to a single reducer Partial-AGG-1(grouping on a) ``` The first plan is more robust. However, to better benchmark the impact of this change, we should use 1.5's plan and use the conf of `spark.sql.specializeSingleDistinctAggPlanning` to control the plan. 
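As a quick illustration (not part of the patch): the flag named above is an ordinary SQL conf, so the two plan shapes can be compared directly from a `SQLContext`. The table name `events` and column `a` below are hypothetical.

```scala
// Rough sketch, assuming a Spark 1.6 SQLContext named sqlContext and a registered table "events".
// Toggling this conf controls which of the two plan shapes shown above is produced for a
// single-distinct aggregation; explain(true) prints the logical and physical plans for inspection.
sqlContext.setConf("spark.sql.specializeSingleDistinctAggPlanning", "true")
sqlContext.sql("SELECT COUNT(DISTINCT a) FROM events").explain(true)
```
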
Author: Yin Huai Closes #9828 from yhuai/distinctRewriter. --- .../sql/catalyst/analysis/DistinctAggregationRewriter.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala index c0c960471a61a..9c78f6d4cc71b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala @@ -126,8 +126,8 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP val shouldRewrite = if (conf.specializeSingleDistinctAggPlanning) { // When the flag is set to specialize single distinct agg planning, // we will rely on our Aggregation strategy to handle queries with a single - // distinct column and this aggregate operator does have grouping expressions. - distinctAggGroups.size > 1 || (distinctAggGroups.size == 1 && a.groupingExpressions.isEmpty) + // distinct column. + distinctAggGroups.size > 1 } else { distinctAggGroups.size >= 1 } From 72d150c271d2b206148fd0917a0def263445121b Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 19 Nov 2015 11:57:50 -0800 Subject: [PATCH 133/149] [SPARK-11830][CORE] Make NettyRpcEnv bind to the specified host This PR includes the following change: 1. Bind NettyRpcEnv to the specified host 2. Fix the port information in the log for NettyRpcEnv. 3. Fix the service name of NettyRpcEnv. Author: zsxwing Author: Shixiong Zhu Closes #9821 from zsxwing/SPARK-11830. --- .../src/main/scala/org/apache/spark/SparkEnv.scala | 9 ++++++++- .../org/apache/spark/rpc/netty/NettyRpcEnv.scala | 7 +++---- .../org/apache/spark/network/TransportContext.java | 8 +++++++- .../spark/network/server/TransportServer.java | 14 ++++++++++---- 4 files changed, 28 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 4474a83bedbdb..88df27f733f2a 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -258,8 +258,15 @@ object SparkEnv extends Logging { if (rpcEnv.isInstanceOf[AkkaRpcEnv]) { rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem } else { + val actorSystemPort = if (port == 0) 0 else rpcEnv.address.port + 1 // Create a ActorSystem for legacy codes - AkkaUtils.createActorSystem(actorSystemName, hostname, port, conf, securityManager)._1 + AkkaUtils.createActorSystem( + actorSystemName + "ActorSystem", + hostname, + actorSystemPort, + conf, + securityManager + )._1 } // Figure out which port Akka actually bound to in case the original port is 0 or occupied. 
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 3e0c497969502..3ce359868039b 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -102,7 +102,7 @@ private[netty] class NettyRpcEnv( } else { java.util.Collections.emptyList() } - server = transportContext.createServer(port, bootstraps) + server = transportContext.createServer(host, port, bootstraps) dispatcher.registerRpcEndpoint( RpcEndpointVerifier.NAME, new RpcEndpointVerifier(this, dispatcher)) } @@ -337,10 +337,10 @@ private[netty] class NettyRpcEnvFactory extends RpcEnvFactory with Logging { if (!config.clientMode) { val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort => nettyEnv.startServer(actualPort) - (nettyEnv, actualPort) + (nettyEnv, nettyEnv.address.port) } try { - Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, "NettyRpcEnv")._1 + Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, config.name)._1 } catch { case NonFatal(e) => nettyEnv.shutdown() @@ -370,7 +370,6 @@ private[netty] class NettyRpcEnvFactory extends RpcEnvFactory with Logging { * @param conf Spark configuration. * @param endpointAddress The address where the endpoint is listening. * @param nettyEnv The RpcEnv associated with this ref. - * @param local Whether the referenced endpoint lives in the same process. */ private[netty] class NettyRpcEndpointRef( @transient private val conf: SparkConf, diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java index 1b64b863a9fe5..238710d17249a 100644 --- a/network/common/src/main/java/org/apache/spark/network/TransportContext.java +++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java @@ -94,7 +94,13 @@ public TransportClientFactory createClientFactory() { /** Create a server which will attempt to bind to a specific port. */ public TransportServer createServer(int port, List bootstraps) { - return new TransportServer(this, port, rpcHandler, bootstraps); + return new TransportServer(this, null, port, rpcHandler, bootstraps); + } + + /** Create a server which will attempt to bind to a specific host and port. */ + public TransportServer createServer( + String host, int port, List bootstraps) { + return new TransportServer(this, host, port, rpcHandler, bootstraps); } /** Creates a new server, binding to any available ephemeral port. */ diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java index f4fadb1ee3b8d..baae235e02205 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java @@ -55,9 +55,13 @@ public class TransportServer implements Closeable { private ChannelFuture channelFuture; private int port = -1; - /** Creates a TransportServer that binds to the given port, or to any available if 0. */ + /** + * Creates a TransportServer that binds to the given host and the given port, or to any available + * if 0. If you don't want to bind to any special host, set "hostToBind" to null. 
+ * */ public TransportServer( TransportContext context, + String hostToBind, int portToBind, RpcHandler appRpcHandler, List bootstraps) { @@ -67,7 +71,7 @@ public TransportServer( this.bootstraps = Lists.newArrayList(Preconditions.checkNotNull(bootstraps)); try { - init(portToBind); + init(hostToBind, portToBind); } catch (RuntimeException e) { JavaUtils.closeQuietly(this); throw e; @@ -81,7 +85,7 @@ public int getPort() { return port; } - private void init(int portToBind) { + private void init(String hostToBind, int portToBind) { IOMode ioMode = IOMode.valueOf(conf.ioMode()); EventLoopGroup bossGroup = @@ -120,7 +124,9 @@ protected void initChannel(SocketChannel ch) throws Exception { } }); - channelFuture = bootstrap.bind(new InetSocketAddress(portToBind)); + InetSocketAddress address = hostToBind == null ? + new InetSocketAddress(portToBind): new InetSocketAddress(hostToBind, portToBind); + channelFuture = bootstrap.bind(address); channelFuture.syncUninterruptibly(); port = ((InetSocketAddress) channelFuture.channel().localAddress()).getPort(); From 276a7e130252c0e7aba702ae5570b3c4f424b23b Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 19 Nov 2015 12:45:04 -0800 Subject: [PATCH 134/149] [SPARK-11633][SQL] LogicalRDD throws TreeNode Exception : Failed to Copy Node When handling self joins, the implementation did not consider the case insensitivity of HiveContext. It could cause an exception as shown in the JIRA: ``` TreeNodeException: Failed to copy node. ``` The fix is low risk. It avoids unnecessary attribute replacement. It should not affect the existing behavior of self joins. Also added the test case to cover this case. Author: gatorsmile Closes #9762 from gatorsmile/joinMakeCopy. --- .../apache/spark/sql/execution/ExistingRDD.scala | 4 ++++ .../org/apache/spark/sql/DataFrameSuite.scala | 14 ++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 62620ec642c78..623348f6768a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -74,6 +74,10 @@ private[sql] case class LogicalRDD( override def children: Seq[LogicalPlan] = Nil + override protected final def otherCopyArgs: Seq[AnyRef] = { + sqlContext :: Nil + } + override def newInstance(): LogicalRDD.this.type = LogicalRDD(output.map(_.newInstance()), rdd)(sqlContext).asInstanceOf[this.type] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 6399b0165c4c3..dd6d06512ff60 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1110,6 +1110,20 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } } + // This test case is to verify a bug when making a new instance of LogicalRDD. 
+ test("SPARK-11633: LogicalRDD throws TreeNode Exception: Failed to Copy Node") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + val rdd = sparkContext.makeRDD(Seq(Row(1, 3), Row(2, 1))) + val df = sqlContext.createDataFrame( + rdd, + new StructType().add("f1", IntegerType).add("f2", IntegerType), + needsConversion = false).select($"F1", $"f2".as("f2")) + val df1 = df.as("a") + val df2 = df.as("b") + checkAnswer(df1.join(df2, $"a.f2" === $"b.f2"), Row(1, 3, 1, 3) :: Row(2, 1, 2, 1) :: Nil) + } + } + test("SPARK-10656: completely support special chars") { val df = Seq(1 -> "a").toDF("i_$.a", "d^'a.") checkAnswer(df.select(df("*")), Row(1, "a")) From 7d4aba18722727c85893ad8d8f07d4494665dcfc Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 19 Nov 2015 12:46:36 -0800 Subject: [PATCH 135/149] [SPARK-11848][SQL] Support EXPLAIN in DataSet APIs When debugging DataSet API, I always need to print the logical and physical plans. I am wondering if we should provide a simple API for EXPLAIN? Author: gatorsmile Closes #9832 from gatorsmile/explainDS. --- .../org/apache/spark/sql/DataFrame.scala | 23 +------------------ .../spark/sql/execution/Queryable.scala | 21 +++++++++++++++++ 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 3ba4ba18d2122..98358127e2709 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} -import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, QueryExecution, Queryable, SQLExecution} +import org.apache.spark.sql.execution.{EvaluatePython, FileRelation, LogicalRDD, QueryExecution, Queryable, SQLExecution} import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.JacksonGenerator import org.apache.spark.sql.sources.HadoopFsRelation @@ -308,27 +308,6 @@ class DataFrame private[sql]( def printSchema(): Unit = println(schema.treeString) // scalastyle:on println - /** - * Prints the plans (logical and physical) to the console for debugging purposes. - * @group basic - * @since 1.3.0 - */ - def explain(extended: Boolean): Unit = { - val explain = ExplainCommand(queryExecution.logical, extended = extended) - withPlan(explain).queryExecution.executedPlan.executeCollect().foreach { - // scalastyle:off println - r => println(r.getString(0)) - // scalastyle:on println - } - } - - /** - * Only prints the physical plan to the console for debugging purposes. - * @group basic - * @since 1.3.0 - */ - def explain(): Unit = explain(extended = false) - /** * Returns true if the `collect` and `take` methods can be run locally * (without any Spark executors). 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala index 9ca383896a09b..e86a52c149a2f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.types.StructType import scala.util.control.NonFatal @@ -25,6 +26,7 @@ import scala.util.control.NonFatal private[sql] trait Queryable { def schema: StructType def queryExecution: QueryExecution + def sqlContext: SQLContext override def toString: String = { try { @@ -34,4 +36,23 @@ private[sql] trait Queryable { s"Invalid tree; ${e.getMessage}:\n$queryExecution" } } + + /** + * Prints the plans (logical and physical) to the console for debugging purposes. + * @since 1.3.0 + */ + def explain(extended: Boolean): Unit = { + val explain = ExplainCommand(queryExecution.logical, extended = extended) + sqlContext.executePlan(explain).executedPlan.executeCollect().foreach { + // scalastyle:off println + r => println(r.getString(0)) + // scalastyle:on println + } + } + + /** + * Only prints the physical plan to the console for debugging purposes. + * @since 1.3.0 + */ + def explain(): Unit = explain(extended = false) } From 47d1c2325caaf9ffe31695b6fff529314b8582f7 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 19 Nov 2015 12:54:25 -0800 Subject: [PATCH 136/149] [SPARK-11750][SQL] revert SPARK-11727 and code clean up After some experiment, I found it's not convenient to have separate encoder builders: `FlatEncoder` and `ProductEncoder`. For example, when create encoders for `ScalaUDF`, we have no idea if the type `T` is flat or not. So I revert the splitting change in https://github.com/apache/spark/pull/9693, while still keeping the bug fixes and tests. Author: Wenchen Fan Closes #9726 from cloud-fan/follow. 
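As a rough sketch (not part of the patch), this is what the consolidated entry point looks like after the revert: flatness is inferred from the type argument instead of being chosen by the caller, so call sites no longer pick between `FlatEncoder` and `ProductEncoder`. The value names below are illustrative.

```scala
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

// Non-Product types yield a flat encoder with a single "value" column.
val intEncoder = ExpressionEncoder[Int]()

// Product types (case classes, tuples) yield one column per constructor field.
val pairEncoder = ExpressionEncoder[(Int, String)]()
```
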
--- .../scala/org/apache/spark/sql/Encoder.scala | 16 +- .../spark/sql/catalyst/ScalaReflection.scala | 354 +++++--------- .../catalyst/encoders/ExpressionEncoder.scala | 19 +- .../sql/catalyst/encoders/FlatEncoder.scala | 50 -- .../catalyst/encoders/ProductEncoder.scala | 452 ------------------ .../sql/catalyst/encoders/RowEncoder.scala | 12 +- .../sql/catalyst/expressions/objects.scala | 7 +- .../sql/catalyst/ScalaReflectionSuite.scala | 68 --- .../encoders/ExpressionEncoderSuite.scala | 218 ++++++++- .../catalyst/encoders/FlatEncoderSuite.scala | 99 ---- .../encoders/ProductEncoderSuite.scala | 156 ------ .../org/apache/spark/sql/GroupedDataset.scala | 4 +- .../org/apache/spark/sql/SQLImplicits.scala | 23 +- .../org/apache/spark/sql/functions.scala | 4 +- 14 files changed, 364 insertions(+), 1118 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoder.scala delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index d54f2854fb33f..86bb536459035 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -45,14 +45,14 @@ trait Encoder[T] extends Serializable { */ object Encoders { - def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true) - def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true) - def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true) - def INT: Encoder[java.lang.Integer] = ExpressionEncoder(flat = true) - def LONG: Encoder[java.lang.Long] = ExpressionEncoder(flat = true) - def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder(flat = true) - def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder(flat = true) - def STRING: Encoder[java.lang.String] = ExpressionEncoder(flat = true) + def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder() + def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder() + def SHORT: Encoder[java.lang.Short] = ExpressionEncoder() + def INT: Encoder[java.lang.Integer] = ExpressionEncoder() + def LONG: Encoder[java.lang.Long] = ExpressionEncoder() + def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder() + def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder() + def STRING: Encoder[java.lang.String] = ExpressionEncoder() /** * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 59ccf356f2c48..33ae700706dae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -50,39 +50,29 @@ object ScalaReflection extends ScalaReflection { * Unlike `schemaFor`, this function doesn't do any massaging of types into the Spark SQL type * system. 
As a result, ObjectType will be returned for things like boxed Integers */ - def dataTypeFor(tpe: `Type`): DataType = tpe match { - case t if t <:< definitions.IntTpe => IntegerType - case t if t <:< definitions.LongTpe => LongType - case t if t <:< definitions.DoubleTpe => DoubleType - case t if t <:< definitions.FloatTpe => FloatType - case t if t <:< definitions.ShortTpe => ShortType - case t if t <:< definitions.ByteTpe => ByteType - case t if t <:< definitions.BooleanTpe => BooleanType - case t if t <:< localTypeOf[Array[Byte]] => BinaryType - case _ => - val className: String = tpe.erasure.typeSymbol.asClass.fullName - className match { - case "scala.Array" => - val TypeRef(_, _, Seq(arrayType)) = tpe - val cls = arrayType match { - case t if t <:< definitions.IntTpe => classOf[Array[Int]] - case t if t <:< definitions.LongTpe => classOf[Array[Long]] - case t if t <:< definitions.DoubleTpe => classOf[Array[Double]] - case t if t <:< definitions.FloatTpe => classOf[Array[Float]] - case t if t <:< definitions.ShortTpe => classOf[Array[Short]] - case t if t <:< definitions.ByteTpe => classOf[Array[Byte]] - case t if t <:< definitions.BooleanTpe => classOf[Array[Boolean]] - case other => - // There is probably a better way to do this, but I couldn't find it... - val elementType = dataTypeFor(other).asInstanceOf[ObjectType].cls - java.lang.reflect.Array.newInstance(elementType, 1).getClass + def dataTypeFor[T : TypeTag]: DataType = dataTypeFor(localTypeOf[T]) - } - ObjectType(cls) - case other => - val clazz = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) - ObjectType(clazz) - } + private def dataTypeFor(tpe: `Type`): DataType = ScalaReflectionLock.synchronized { + tpe match { + case t if t <:< definitions.IntTpe => IntegerType + case t if t <:< definitions.LongTpe => LongType + case t if t <:< definitions.DoubleTpe => DoubleType + case t if t <:< definitions.FloatTpe => FloatType + case t if t <:< definitions.ShortTpe => ShortType + case t if t <:< definitions.ByteTpe => ByteType + case t if t <:< definitions.BooleanTpe => BooleanType + case t if t <:< localTypeOf[Array[Byte]] => BinaryType + case _ => + val className: String = tpe.erasure.typeSymbol.asClass.fullName + className match { + case "scala.Array" => + val TypeRef(_, _, Seq(elementType)) = tpe + arrayClassFor(elementType) + case other => + val clazz = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) + ObjectType(clazz) + } + } } /** @@ -90,7 +80,7 @@ object ScalaReflection extends ScalaReflection { * Array[T]. Special handling is performed for primitive types to map them back to their raw * JVM form instead of the Scala Array that handles auto boxing. */ - def arrayClassFor(tpe: `Type`): DataType = { + private def arrayClassFor(tpe: `Type`): DataType = ScalaReflectionLock.synchronized { val cls = tpe match { case t if t <:< definitions.IntTpe => classOf[Array[Int]] case t if t <:< definitions.LongTpe => classOf[Array[Long]] @@ -108,6 +98,15 @@ object ScalaReflection extends ScalaReflection { ObjectType(cls) } + /** + * Returns true if the value of this data type is same between internal and external. + */ + def isNativeType(dt: DataType): Boolean = dt match { + case BooleanType | ByteType | ShortType | IntegerType | LongType | + FloatType | DoubleType | BinaryType => true + case _ => false + } + /** * Returns an expression that can be used to construct an object of type `T` given an input * row with a compatible schema. 
Fields of the row will be extracted using UnresolvedAttributes @@ -116,63 +115,33 @@ object ScalaReflection extends ScalaReflection { * * When used on a primitive type, the constructor will instead default to extracting the value * from ordinal 0 (since there are no names to map to). The actual location can be moved by - * calling unbind/bind with a new schema. + * calling resolve/bind with a new schema. */ - def constructorFor[T : TypeTag]: Expression = constructorFor(typeOf[T], None) + def constructorFor[T : TypeTag]: Expression = constructorFor(localTypeOf[T], None) private def constructorFor( tpe: `Type`, path: Option[Expression]): Expression = ScalaReflectionLock.synchronized { /** Returns the current path with a sub-field extracted. */ - def addToPath(part: String): Expression = - path - .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) - .getOrElse(UnresolvedAttribute(part)) + def addToPath(part: String): Expression = path + .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) + .getOrElse(UnresolvedAttribute(part)) /** Returns the current path with a field at ordinal extracted. */ - def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = - path - .map(p => GetStructField(p, StructField(s"_$ordinal", dataType), ordinal)) - .getOrElse(BoundReference(ordinal, dataType, false)) + def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = path + .map(p => GetInternalRowField(p, ordinal, dataType)) + .getOrElse(BoundReference(ordinal, dataType, false)) - /** Returns the current path or throws an error. */ - def getPath = path.getOrElse(BoundReference(0, schemaFor(tpe).dataType, true)) + /** Returns the current path or `BoundReference`. */ + def getPath: Expression = path.getOrElse(BoundReference(0, schemaFor(tpe).dataType, true)) tpe match { - case t if !dataTypeFor(t).isInstanceOf[ObjectType] => - getPath + case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath case t if t <:< localTypeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t - val boxedType = optType match { - // For primitive types we must manually box the primitive value. 
- case t if t <:< definitions.IntTpe => Some(classOf[java.lang.Integer]) - case t if t <:< definitions.LongTpe => Some(classOf[java.lang.Long]) - case t if t <:< definitions.DoubleTpe => Some(classOf[java.lang.Double]) - case t if t <:< definitions.FloatTpe => Some(classOf[java.lang.Float]) - case t if t <:< definitions.ShortTpe => Some(classOf[java.lang.Short]) - case t if t <:< definitions.ByteTpe => Some(classOf[java.lang.Byte]) - case t if t <:< definitions.BooleanTpe => Some(classOf[java.lang.Boolean]) - case _ => None - } - - boxedType.map { boxedType => - val objectType = ObjectType(boxedType) - WrapOption( - objectType, - NewInstance( - boxedType, - getPath :: Nil, - propagateNull = true, - objectType)) - }.getOrElse { - val className: String = optType.erasure.typeSymbol.asClass.fullName - val cls = Utils.classForName(className) - val objectType = ObjectType(cls) - - WrapOption(objectType, constructorFor(optType, path)) - } + WrapOption(constructorFor(optType, path)) case t if t <:< localTypeOf[java.lang.Integer] => val boxedType = classOf[java.lang.Integer] @@ -231,11 +200,11 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[java.math.BigDecimal] => Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) + case t if t <:< localTypeOf[BigDecimal] => + Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal])) + case t if t <:< localTypeOf[Array[_]] => val TypeRef(_, _, Seq(elementType)) = t - val elementDataType = dataTypeFor(elementType) - val Schema(dataType, nullable) = schemaFor(elementType) - val primitiveMethod = elementType match { case t if t <:< definitions.IntTpe => Some("toIntArray") case t if t <:< definitions.LongTpe => Some("toLongArray") @@ -248,57 +217,52 @@ object ScalaReflection extends ScalaReflection { } primitiveMethod.map { method => - Invoke(getPath, method, dataTypeFor(t)) + Invoke(getPath, method, arrayClassFor(elementType)) }.getOrElse { - val returnType = dataTypeFor(t) Invoke( - MapObjects(p => constructorFor(elementType, Some(p)), getPath, dataType), + MapObjects( + p => constructorFor(elementType, Some(p)), + getPath, + schemaFor(elementType).dataType), "array", - returnType) + arrayClassFor(elementType)) } + case t if t <:< localTypeOf[Seq[_]] => + val TypeRef(_, _, Seq(elementType)) = t + val arrayData = + Invoke( + MapObjects( + p => constructorFor(elementType, Some(p)), + getPath, + schemaFor(elementType).dataType), + "array", + ObjectType(classOf[Array[Any]])) + + StaticInvoke( + scala.collection.mutable.WrappedArray, + ObjectType(classOf[Seq[_]]), + "make", + arrayData :: Nil) + case t if t <:< localTypeOf[Map[_, _]] => val TypeRef(_, _, Seq(keyType, valueType)) = t - val Schema(keyDataType, _) = schemaFor(keyType) - val Schema(valueDataType, valueNullable) = schemaFor(valueType) - - val primitiveMethodKey = keyType match { - case t if t <:< definitions.IntTpe => Some("toIntArray") - case t if t <:< definitions.LongTpe => Some("toLongArray") - case t if t <:< definitions.DoubleTpe => Some("toDoubleArray") - case t if t <:< definitions.FloatTpe => Some("toFloatArray") - case t if t <:< definitions.ShortTpe => Some("toShortArray") - case t if t <:< definitions.ByteTpe => Some("toByteArray") - case t if t <:< definitions.BooleanTpe => Some("toBooleanArray") - case _ => None - } val keyData = Invoke( MapObjects( p => constructorFor(keyType, Some(p)), - Invoke(getPath, "keyArray", ArrayType(keyDataType)), - keyDataType), + Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)), + 
schemaFor(keyType).dataType), "array", ObjectType(classOf[Array[Any]])) - val primitiveMethodValue = valueType match { - case t if t <:< definitions.IntTpe => Some("toIntArray") - case t if t <:< definitions.LongTpe => Some("toLongArray") - case t if t <:< definitions.DoubleTpe => Some("toDoubleArray") - case t if t <:< definitions.FloatTpe => Some("toFloatArray") - case t if t <:< definitions.ShortTpe => Some("toShortArray") - case t if t <:< definitions.ByteTpe => Some("toByteArray") - case t if t <:< definitions.BooleanTpe => Some("toBooleanArray") - case _ => None - } - val valueData = Invoke( MapObjects( p => constructorFor(valueType, Some(p)), - Invoke(getPath, "valueArray", ArrayType(valueDataType)), - valueDataType), + Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)), + schemaFor(valueType).dataType), "array", ObjectType(classOf[Array[Any]])) @@ -308,40 +272,6 @@ object ScalaReflection extends ScalaReflection { "toScalaMap", keyData :: valueData :: Nil) - case t if t <:< localTypeOf[Seq[_]] => - val TypeRef(_, _, Seq(elementType)) = t - val elementDataType = dataTypeFor(elementType) - val Schema(dataType, nullable) = schemaFor(elementType) - - // Avoid boxing when possible by just wrapping a primitive array. - val primitiveMethod = elementType match { - case _ if nullable => None - case t if t <:< definitions.IntTpe => Some("toIntArray") - case t if t <:< definitions.LongTpe => Some("toLongArray") - case t if t <:< definitions.DoubleTpe => Some("toDoubleArray") - case t if t <:< definitions.FloatTpe => Some("toFloatArray") - case t if t <:< definitions.ShortTpe => Some("toShortArray") - case t if t <:< definitions.ByteTpe => Some("toByteArray") - case t if t <:< definitions.BooleanTpe => Some("toBooleanArray") - case _ => None - } - - val arrayData = primitiveMethod.map { method => - Invoke(getPath, method, arrayClassFor(elementType)) - }.getOrElse { - Invoke( - MapObjects(p => constructorFor(elementType, Some(p)), getPath, dataType), - "array", - arrayClassFor(elementType)) - } - - StaticInvoke( - scala.collection.mutable.WrappedArray, - ObjectType(classOf[Seq[_]]), - "make", - arrayData :: Nil) - - case t if t <:< localTypeOf[Product] => val formalTypeArgs = t.typeSymbol.asClass.typeParams val TypeRef(_, _, actualTypeArgs) = t @@ -361,8 +291,7 @@ object ScalaReflection extends ScalaReflection { } } - val className: String = t.erasure.typeSymbol.asClass.fullName - val cls = Utils.classForName(className) + val cls = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) val arguments = params.head.zipWithIndex.map { case (p, i) => val fieldName = p.name.toString @@ -370,7 +299,7 @@ object ScalaReflection extends ScalaReflection { val dataType = schemaFor(fieldType).dataType // For tuples, we based grab the inner fields by ordinal instead of name. - if (className startsWith "scala.Tuple") { + if (cls.getName startsWith "scala.Tuple") { constructorFor(fieldType, Some(addToPathOrdinal(i, dataType))) } else { constructorFor(fieldType, Some(addToPath(fieldName))) @@ -388,22 +317,19 @@ object ScalaReflection extends ScalaReflection { } else { newInstance } - } } /** Returns expressions for extracting all the fields from the given type. 
*/ def extractorsFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = { - ScalaReflectionLock.synchronized { - extractorFor(inputObject, typeTag[T].tpe) match { - case s: CreateNamedStruct => s - case o => CreateNamedStruct(expressions.Literal("value") :: o :: Nil) - } + extractorFor(inputObject, localTypeOf[T]) match { + case s: CreateNamedStruct => s + case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil) } } /** Helper for extracting internal fields from a case class. */ - protected def extractorFor( + private def extractorFor( inputObject: Expression, tpe: `Type`): Expression = ScalaReflectionLock.synchronized { if (!inputObject.dataType.isInstanceOf[ObjectType]) { @@ -491,51 +417,36 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[Array[_]] => val TypeRef(_, _, Seq(elementType)) = t - val elementDataType = dataTypeFor(elementType) - val Schema(dataType, nullable) = schemaFor(elementType) - - if (!elementDataType.isInstanceOf[AtomicType]) { - MapObjects(extractorFor(_, elementType), inputObject, elementDataType) - } else { - NewInstance( - classOf[GenericArrayData], - inputObject :: Nil, - dataType = ArrayType(dataType, nullable)) - } + toCatalystArray(inputObject, elementType) case t if t <:< localTypeOf[Seq[_]] => val TypeRef(_, _, Seq(elementType)) = t - val elementDataType = dataTypeFor(elementType) - val Schema(dataType, nullable) = schemaFor(elementType) - - if (dataType.isInstanceOf[AtomicType]) { - NewInstance( - classOf[GenericArrayData], - inputObject :: Nil, - dataType = ArrayType(dataType, nullable)) - } else { - MapObjects(extractorFor(_, elementType), inputObject, elementDataType) - } + toCatalystArray(inputObject, elementType) case t if t <:< localTypeOf[Map[_, _]] => val TypeRef(_, _, Seq(keyType, valueType)) = t - val Schema(keyDataType, _) = schemaFor(keyType) - val Schema(valueDataType, valueNullable) = schemaFor(valueType) - val rawMap = inputObject val keys = - NewInstance( - classOf[GenericArrayData], - Invoke(rawMap, "keys", ObjectType(classOf[scala.collection.GenIterable[_]])) :: Nil, - dataType = ObjectType(classOf[ArrayData])) + Invoke( + Invoke(inputObject, "keysIterator", + ObjectType(classOf[scala.collection.Iterator[_]])), + "toSeq", + ObjectType(classOf[scala.collection.Seq[_]])) + val convertedKeys = toCatalystArray(keys, keyType) + val values = - NewInstance( - classOf[GenericArrayData], - Invoke(rawMap, "values", ObjectType(classOf[scala.collection.GenIterable[_]])) :: Nil, - dataType = ObjectType(classOf[ArrayData])) + Invoke( + Invoke(inputObject, "valuesIterator", + ObjectType(classOf[scala.collection.Iterator[_]])), + "toSeq", + ObjectType(classOf[scala.collection.Seq[_]])) + val convertedValues = toCatalystArray(values, valueType) + + val Schema(keyDataType, _) = schemaFor(keyType) + val Schema(valueDataType, valueNullable) = schemaFor(valueType) NewInstance( classOf[ArrayBasedMapData], - keys :: values :: Nil, + convertedKeys :: convertedValues :: Nil, dataType = MapType(keyDataType, valueDataType, valueNullable)) case t if t <:< localTypeOf[String] => @@ -558,6 +469,7 @@ object ScalaReflection extends ScalaReflection { DateType, "fromJavaDate", inputObject :: Nil) + case t if t <:< localTypeOf[BigDecimal] => StaticInvoke( Decimal, @@ -587,26 +499,24 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[java.lang.Boolean] => Invoke(inputObject, "booleanValue", BooleanType) - case t if t <:< definitions.IntTpe => - BoundReference(0, IntegerType, false) - 
case t if t <:< definitions.LongTpe => - BoundReference(0, LongType, false) - case t if t <:< definitions.DoubleTpe => - BoundReference(0, DoubleType, false) - case t if t <:< definitions.FloatTpe => - BoundReference(0, FloatType, false) - case t if t <:< definitions.ShortTpe => - BoundReference(0, ShortType, false) - case t if t <:< definitions.ByteTpe => - BoundReference(0, ByteType, false) - case t if t <:< definitions.BooleanTpe => - BoundReference(0, BooleanType, false) - case other => throw new UnsupportedOperationException(s"Extractor for type $other is not supported") } } } + + private def toCatalystArray(input: Expression, elementType: `Type`): Expression = { + val externalDataType = dataTypeFor(elementType) + val Schema(catalystType, nullable) = schemaFor(elementType) + if (isNativeType(catalystType)) { + NewInstance( + classOf[GenericArrayData], + input :: Nil, + dataType = ArrayType(catalystType, nullable)) + } else { + MapObjects(extractorFor(_, elementType), input, externalDataType) + } + } } /** @@ -635,8 +545,7 @@ trait ScalaReflection { } /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ - def schemaFor[T: TypeTag]: Schema = - ScalaReflectionLock.synchronized { schemaFor(localTypeOf[T]) } + def schemaFor[T: TypeTag]: Schema = schemaFor(localTypeOf[T]) /** * Return the Scala Type for `T` in the current classloader mirror. @@ -736,39 +645,4 @@ trait ScalaReflection { assert(methods.length == 1) methods.head.getParameterTypes } - - def typeOfObject: PartialFunction[Any, DataType] = { - // The data type can be determined without ambiguity. - case obj: Boolean => BooleanType - case obj: Array[Byte] => BinaryType - case obj: String => StringType - case obj: UTF8String => StringType - case obj: Byte => ByteType - case obj: Short => ShortType - case obj: Int => IntegerType - case obj: Long => LongType - case obj: Float => FloatType - case obj: Double => DoubleType - case obj: java.sql.Date => DateType - case obj: java.math.BigDecimal => DecimalType.SYSTEM_DEFAULT - case obj: Decimal => DecimalType.SYSTEM_DEFAULT - case obj: java.sql.Timestamp => TimestampType - case null => NullType - // For other cases, there is no obvious mapping from the type of the given object to a - // Catalyst data type. A user should provide his/her specific rules - // (in a user-defined PartialFunction) to infer the Catalyst data type for other types of - // objects and then compose the user-defined PartialFunction with this one. - } - - implicit class CaseClassRelation[A <: Product : TypeTag](data: Seq[A]) { - - /** - * Implicitly added to Sequences of case class objects. Returns a catalyst logical relation - * for the the data in the sequence. 
- */ - def asRelation: LocalRelation = { - val output = attributesFor[A] - LocalRelation.fromProduct(output, data) - } - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 456b595008479..6eeba1442c1f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -30,10 +30,10 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.types.{NullType, StructField, ObjectType, StructType} +import org.apache.spark.sql.types.{StructField, ObjectType, StructType} /** - * A factory for constructing encoders that convert objects and primitves to and from the + * A factory for constructing encoders that convert objects and primitives to and from the * internal row format using catalyst expressions and code generation. By default, the * expressions used to retrieve values from an input row when producing an object will be created as * follows: @@ -44,20 +44,21 @@ import org.apache.spark.sql.types.{NullType, StructField, ObjectType, StructType * to the name `value`. */ object ExpressionEncoder { - def apply[T : TypeTag](flat: Boolean = false): ExpressionEncoder[T] = { + def apply[T : TypeTag](): ExpressionEncoder[T] = { // We convert the not-serializable TypeTag into StructType and ClassTag. val mirror = typeTag[T].mirror val cls = mirror.runtimeClass(typeTag[T].tpe) + val flat = !classOf[Product].isAssignableFrom(cls) - val inputObject = BoundReference(0, ObjectType(cls), nullable = true) - val extractExpression = ScalaReflection.extractorsFor[T](inputObject) - val constructExpression = ScalaReflection.constructorFor[T] + val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = true) + val toRowExpression = ScalaReflection.extractorsFor[T](inputObject) + val fromRowExpression = ScalaReflection.constructorFor[T] new ExpressionEncoder[T]( - extractExpression.dataType, + toRowExpression.dataType, flat, - extractExpression.flatten, - constructExpression, + toRowExpression.flatten, + fromRowExpression, ClassTag[T](cls)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoder.scala deleted file mode 100644 index 6d307ab13a9fc..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoder.scala +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.encoders - -import scala.reflect.ClassTag -import scala.reflect.runtime.universe.{typeTag, TypeTag} - -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.catalyst.expressions.{Literal, CreateNamedStruct, BoundReference} -import org.apache.spark.sql.catalyst.ScalaReflection - -object FlatEncoder { - import ScalaReflection.schemaFor - import ScalaReflection.dataTypeFor - - def apply[T : TypeTag]: ExpressionEncoder[T] = { - // We convert the not-serializable TypeTag into StructType and ClassTag. - val tpe = typeTag[T].tpe - val mirror = typeTag[T].mirror - val cls = mirror.runtimeClass(tpe) - assert(!schemaFor(tpe).dataType.isInstanceOf[StructType]) - - val input = BoundReference(0, dataTypeFor(tpe), nullable = true) - val toRowExpression = CreateNamedStruct( - Literal("value") :: ProductEncoder.extractorFor(input, tpe) :: Nil) - val fromRowExpression = ProductEncoder.constructorFor(tpe) - - new ExpressionEncoder[T]( - toRowExpression.dataType, - flat = true, - toRowExpression.flatten, - fromRowExpression, - ClassTag[T](cls)) - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala deleted file mode 100644 index 2914c6ee790ce..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala +++ /dev/null @@ -1,452 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.encoders - -import org.apache.spark.util.Utils -import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.sql.types._ -import org.apache.spark.sql.catalyst.ScalaReflectionLock -import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue} -import org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.{DateTimeUtils, ArrayBasedMapData, GenericArrayData} - -import scala.reflect.ClassTag - -object ProductEncoder { - import ScalaReflection.universe._ - import ScalaReflection.mirror - import ScalaReflection.localTypeOf - import ScalaReflection.dataTypeFor - import ScalaReflection.Schema - import ScalaReflection.schemaFor - import ScalaReflection.arrayClassFor - - def apply[T <: Product : TypeTag]: ExpressionEncoder[T] = { - // We convert the not-serializable TypeTag into StructType and ClassTag. - val tpe = typeTag[T].tpe - val mirror = typeTag[T].mirror - val cls = mirror.runtimeClass(tpe) - - val inputObject = BoundReference(0, ObjectType(cls), nullable = true) - val toRowExpression = extractorFor(inputObject, tpe).asInstanceOf[CreateNamedStruct] - val fromRowExpression = constructorFor(tpe) - - new ExpressionEncoder[T]( - toRowExpression.dataType, - flat = false, - toRowExpression.flatten, - fromRowExpression, - ClassTag[T](cls)) - } - - // The Predef.Map is scala.collection.immutable.Map. - // Since the map values can be mutable, we explicitly import scala.collection.Map at here. - import scala.collection.Map - - def extractorFor( - inputObject: Expression, - tpe: `Type`): Expression = ScalaReflectionLock.synchronized { - if (!inputObject.dataType.isInstanceOf[ObjectType]) { - inputObject - } else { - tpe match { - case t if t <:< localTypeOf[Option[_]] => - val TypeRef(_, _, Seq(optType)) = t - optType match { - // For primitive types we must manually unbox the value of the object. - case t if t <:< definitions.IntTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Integer]), inputObject), - "intValue", - IntegerType) - case t if t <:< definitions.LongTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Long]), inputObject), - "longValue", - LongType) - case t if t <:< definitions.DoubleTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Double]), inputObject), - "doubleValue", - DoubleType) - case t if t <:< definitions.FloatTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Float]), inputObject), - "floatValue", - FloatType) - case t if t <:< definitions.ShortTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Short]), inputObject), - "shortValue", - ShortType) - case t if t <:< definitions.ByteTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Byte]), inputObject), - "byteValue", - ByteType) - case t if t <:< definitions.BooleanTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Boolean]), inputObject), - "booleanValue", - BooleanType) - - // For non-primitives, we can just extract the object from the Option and then recurse. 
- case other => - val className: String = optType.erasure.typeSymbol.asClass.fullName - val classObj = Utils.classForName(className) - val optionObjectType = ObjectType(classObj) - - val unwrapped = UnwrapOption(optionObjectType, inputObject) - expressions.If( - IsNull(unwrapped), - expressions.Literal.create(null, schemaFor(optType).dataType), - extractorFor(unwrapped, optType)) - } - - case t if t <:< localTypeOf[Product] => - val formalTypeArgs = t.typeSymbol.asClass.typeParams - val TypeRef(_, _, actualTypeArgs) = t - val constructorSymbol = t.member(nme.CONSTRUCTOR) - val params = if (constructorSymbol.isMethod) { - constructorSymbol.asMethod.paramss - } else { - // Find the primary constructor, and use its parameter ordering. - val primaryConstructorSymbol: Option[Symbol] = - constructorSymbol.asTerm.alternatives.find(s => - s.isMethod && s.asMethod.isPrimaryConstructor) - - if (primaryConstructorSymbol.isEmpty) { - sys.error("Internal SQL error: Product object did not have a primary constructor.") - } else { - primaryConstructorSymbol.get.asMethod.paramss - } - } - - CreateNamedStruct(params.head.flatMap { p => - val fieldName = p.name.toString - val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) - val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType)) - expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType) :: Nil - }) - - case t if t <:< localTypeOf[Array[_]] => - val TypeRef(_, _, Seq(elementType)) = t - toCatalystArray(inputObject, elementType) - - case t if t <:< localTypeOf[Seq[_]] => - val TypeRef(_, _, Seq(elementType)) = t - toCatalystArray(inputObject, elementType) - - case t if t <:< localTypeOf[Map[_, _]] => - val TypeRef(_, _, Seq(keyType, valueType)) = t - - val keys = - Invoke( - Invoke(inputObject, "keysIterator", - ObjectType(classOf[scala.collection.Iterator[_]])), - "toSeq", - ObjectType(classOf[scala.collection.Seq[_]])) - val convertedKeys = toCatalystArray(keys, keyType) - - val values = - Invoke( - Invoke(inputObject, "valuesIterator", - ObjectType(classOf[scala.collection.Iterator[_]])), - "toSeq", - ObjectType(classOf[scala.collection.Seq[_]])) - val convertedValues = toCatalystArray(values, valueType) - - val Schema(keyDataType, _) = schemaFor(keyType) - val Schema(valueDataType, valueNullable) = schemaFor(valueType) - NewInstance( - classOf[ArrayBasedMapData], - convertedKeys :: convertedValues :: Nil, - dataType = MapType(keyDataType, valueDataType, valueNullable)) - - case t if t <:< localTypeOf[String] => - StaticInvoke( - classOf[UTF8String], - StringType, - "fromString", - inputObject :: Nil) - - case t if t <:< localTypeOf[java.sql.Timestamp] => - StaticInvoke( - DateTimeUtils, - TimestampType, - "fromJavaTimestamp", - inputObject :: Nil) - - case t if t <:< localTypeOf[java.sql.Date] => - StaticInvoke( - DateTimeUtils, - DateType, - "fromJavaDate", - inputObject :: Nil) - - case t if t <:< localTypeOf[BigDecimal] => - StaticInvoke( - Decimal, - DecimalType.SYSTEM_DEFAULT, - "apply", - inputObject :: Nil) - - case t if t <:< localTypeOf[java.math.BigDecimal] => - StaticInvoke( - Decimal, - DecimalType.SYSTEM_DEFAULT, - "apply", - inputObject :: Nil) - - case t if t <:< localTypeOf[java.lang.Integer] => - Invoke(inputObject, "intValue", IntegerType) - case t if t <:< localTypeOf[java.lang.Long] => - Invoke(inputObject, "longValue", LongType) - case t if t <:< localTypeOf[java.lang.Double] => - Invoke(inputObject, "doubleValue", DoubleType) - case t if t <:< localTypeOf[java.lang.Float] => - 
Invoke(inputObject, "floatValue", FloatType) - case t if t <:< localTypeOf[java.lang.Short] => - Invoke(inputObject, "shortValue", ShortType) - case t if t <:< localTypeOf[java.lang.Byte] => - Invoke(inputObject, "byteValue", ByteType) - case t if t <:< localTypeOf[java.lang.Boolean] => - Invoke(inputObject, "booleanValue", BooleanType) - - case other => - throw new UnsupportedOperationException(s"Encoder for type $other is not supported") - } - } - } - - private def toCatalystArray(input: Expression, elementType: `Type`): Expression = { - val externalDataType = dataTypeFor(elementType) - val Schema(catalystType, nullable) = schemaFor(elementType) - if (RowEncoder.isNativeType(catalystType)) { - NewInstance( - classOf[GenericArrayData], - input :: Nil, - dataType = ArrayType(catalystType, nullable)) - } else { - MapObjects(extractorFor(_, elementType), input, externalDataType) - } - } - - def constructorFor( - tpe: `Type`, - path: Option[Expression] = None): Expression = ScalaReflectionLock.synchronized { - - /** Returns the current path with a sub-field extracted. */ - def addToPath(part: String): Expression = path - .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) - .getOrElse(UnresolvedAttribute(part)) - - /** Returns the current path with a field at ordinal extracted. */ - def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = path - .map(p => GetInternalRowField(p, ordinal, dataType)) - .getOrElse(BoundReference(ordinal, dataType, false)) - - /** Returns the current path or `BoundReference`. */ - def getPath: Expression = path.getOrElse(BoundReference(0, schemaFor(tpe).dataType, true)) - - tpe match { - case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath - - case t if t <:< localTypeOf[Option[_]] => - val TypeRef(_, _, Seq(optType)) = t - WrapOption(null, constructorFor(optType, path)) - - case t if t <:< localTypeOf[java.lang.Integer] => - val boxedType = classOf[java.lang.Integer] - val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) - - case t if t <:< localTypeOf[java.lang.Long] => - val boxedType = classOf[java.lang.Long] - val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) - - case t if t <:< localTypeOf[java.lang.Double] => - val boxedType = classOf[java.lang.Double] - val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) - - case t if t <:< localTypeOf[java.lang.Float] => - val boxedType = classOf[java.lang.Float] - val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) - - case t if t <:< localTypeOf[java.lang.Short] => - val boxedType = classOf[java.lang.Short] - val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) - - case t if t <:< localTypeOf[java.lang.Byte] => - val boxedType = classOf[java.lang.Byte] - val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) - - case t if t <:< localTypeOf[java.lang.Boolean] => - val boxedType = classOf[java.lang.Boolean] - val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) - - case t if t <:< localTypeOf[java.sql.Date] => - StaticInvoke( - DateTimeUtils, - ObjectType(classOf[java.sql.Date]), - "toJavaDate", - getPath :: Nil, - propagateNull = true) - - case t if t <:< 
localTypeOf[java.sql.Timestamp] => - StaticInvoke( - DateTimeUtils, - ObjectType(classOf[java.sql.Timestamp]), - "toJavaTimestamp", - getPath :: Nil, - propagateNull = true) - - case t if t <:< localTypeOf[java.lang.String] => - Invoke(getPath, "toString", ObjectType(classOf[String])) - - case t if t <:< localTypeOf[java.math.BigDecimal] => - Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) - - case t if t <:< localTypeOf[BigDecimal] => - Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal])) - - case t if t <:< localTypeOf[Array[_]] => - val TypeRef(_, _, Seq(elementType)) = t - val primitiveMethod = elementType match { - case t if t <:< definitions.IntTpe => Some("toIntArray") - case t if t <:< definitions.LongTpe => Some("toLongArray") - case t if t <:< definitions.DoubleTpe => Some("toDoubleArray") - case t if t <:< definitions.FloatTpe => Some("toFloatArray") - case t if t <:< definitions.ShortTpe => Some("toShortArray") - case t if t <:< definitions.ByteTpe => Some("toByteArray") - case t if t <:< definitions.BooleanTpe => Some("toBooleanArray") - case _ => None - } - - primitiveMethod.map { method => - Invoke(getPath, method, arrayClassFor(elementType)) - }.getOrElse { - Invoke( - MapObjects( - p => constructorFor(elementType, Some(p)), - getPath, - schemaFor(elementType).dataType), - "array", - arrayClassFor(elementType)) - } - - case t if t <:< localTypeOf[Seq[_]] => - val TypeRef(_, _, Seq(elementType)) = t - val arrayData = - Invoke( - MapObjects( - p => constructorFor(elementType, Some(p)), - getPath, - schemaFor(elementType).dataType), - "array", - ObjectType(classOf[Array[Any]])) - - StaticInvoke( - scala.collection.mutable.WrappedArray, - ObjectType(classOf[Seq[_]]), - "make", - arrayData :: Nil) - - case t if t <:< localTypeOf[Map[_, _]] => - val TypeRef(_, _, Seq(keyType, valueType)) = t - - val keyData = - Invoke( - MapObjects( - p => constructorFor(keyType, Some(p)), - Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)), - schemaFor(keyType).dataType), - "array", - ObjectType(classOf[Array[Any]])) - - val valueData = - Invoke( - MapObjects( - p => constructorFor(valueType, Some(p)), - Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)), - schemaFor(valueType).dataType), - "array", - ObjectType(classOf[Array[Any]])) - - StaticInvoke( - ArrayBasedMapData, - ObjectType(classOf[Map[_, _]]), - "toScalaMap", - keyData :: valueData :: Nil) - - case t if t <:< localTypeOf[Product] => - val formalTypeArgs = t.typeSymbol.asClass.typeParams - val TypeRef(_, _, actualTypeArgs) = t - val constructorSymbol = t.member(nme.CONSTRUCTOR) - val params = if (constructorSymbol.isMethod) { - constructorSymbol.asMethod.paramss - } else { - // Find the primary constructor, and use its parameter ordering. - val primaryConstructorSymbol: Option[Symbol] = - constructorSymbol.asTerm.alternatives.find(s => - s.isMethod && s.asMethod.isPrimaryConstructor) - - if (primaryConstructorSymbol.isEmpty) { - sys.error("Internal SQL error: Product object did not have a primary constructor.") - } else { - primaryConstructorSymbol.get.asMethod.paramss - } - } - - val cls = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) - - val arguments = params.head.zipWithIndex.map { case (p, i) => - val fieldName = p.name.toString - val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) - val dataType = schemaFor(fieldType).dataType - - // For tuples, we based grab the inner fields by ordinal instead of name. 
- if (cls.getName startsWith "scala.Tuple") { - constructorFor(fieldType, Some(addToPathOrdinal(i, dataType))) - } else { - constructorFor(fieldType, Some(addToPath(fieldName))) - } - } - - val newInstance = NewInstance(cls, arguments, propagateNull = false, ObjectType(cls)) - - if (path.nonEmpty) { - expressions.If( - IsNull(getPath), - expressions.Literal.create(null, ObjectType(cls)), - newInstance - ) - } else { - newInstance - } - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 9bb1602494b68..4cda4824acdc3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -23,6 +23,7 @@ import scala.reflect.ClassTag import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, DateTimeUtils} +import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -132,17 +133,8 @@ object RowEncoder { CreateStruct(convertedFields) } - /** - * Returns true if the value of this data type is same between internal and external. - */ - def isNativeType(dt: DataType): Boolean = dt match { - case BooleanType | ByteType | ShortType | IntegerType | LongType | - FloatType | DoubleType | BinaryType => true - case _ => false - } - private def externalDataTypeFor(dt: DataType): DataType = dt match { - case _ if isNativeType(dt) => dt + case _ if ScalaReflection.isNativeType(dt) => dt case TimestampType => ObjectType(classOf[java.sql.Timestamp]) case DateType => ObjectType(classOf[java.sql.Date]) case _: DecimalType => ObjectType(classOf[java.math.BigDecimal]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index f865a9408ef4e..ef7399e0196ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -24,7 +24,6 @@ import org.apache.spark.SparkConf import org.apache.spark.serializer._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer -import org.apache.spark.sql.catalyst.encoders.ProductEncoder import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation} import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.catalyst.InternalRow @@ -300,10 +299,9 @@ case class UnwrapOption( /** * Converts the result of evaluating `child` into an option, checking both the isNull bit and * (in the case of reference types) equality with null. - * @param optionType The datatype to be held inside of the Option. * @param child The expression to evaluate and wrap. 
*/ -case class WrapOption(optionType: DataType, child: Expression) +case class WrapOption(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def dataType: DataType = ObjectType(classOf[Option[_]]) @@ -316,14 +314,13 @@ case class WrapOption(optionType: DataType, child: Expression) throw new UnsupportedOperationException("Only code-generated evaluation is supported") override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val javaType = ctx.javaType(optionType) val inputObject = child.gen(ctx) s""" ${inputObject.code} boolean ${ev.isNull} = false; - scala.Option<$javaType> ${ev.value} = + scala.Option ${ev.value} = ${inputObject.isNull} ? scala.Option$$.MODULE$$.apply(null) : new scala.Some(${inputObject.value}); """ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 4ea410d492b01..c2aace1ef238e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -186,74 +186,6 @@ class ScalaReflectionSuite extends SparkFunSuite { nullable = true)) } - test("get data type of a value") { - // BooleanType - assert(BooleanType === typeOfObject(true)) - assert(BooleanType === typeOfObject(false)) - - // BinaryType - assert(BinaryType === typeOfObject("string".getBytes)) - - // StringType - assert(StringType === typeOfObject("string")) - - // ByteType - assert(ByteType === typeOfObject(127.toByte)) - - // ShortType - assert(ShortType === typeOfObject(32767.toShort)) - - // IntegerType - assert(IntegerType === typeOfObject(2147483647)) - - // LongType - assert(LongType === typeOfObject(9223372036854775807L)) - - // FloatType - assert(FloatType === typeOfObject(3.4028235E38.toFloat)) - - // DoubleType - assert(DoubleType === typeOfObject(1.7976931348623157E308)) - - // DecimalType - assert(DecimalType.SYSTEM_DEFAULT === - typeOfObject(new java.math.BigDecimal("1.7976931348623157E318"))) - - // DateType - assert(DateType === typeOfObject(Date.valueOf("2014-07-25"))) - - // TimestampType - assert(TimestampType === typeOfObject(Timestamp.valueOf("2014-07-25 10:26:00"))) - - // NullType - assert(NullType === typeOfObject(null)) - - def typeOfObject1: PartialFunction[Any, DataType] = typeOfObject orElse { - case value: java.math.BigInteger => DecimalType.SYSTEM_DEFAULT - case value: java.math.BigDecimal => DecimalType.SYSTEM_DEFAULT - case _ => StringType - } - - assert(DecimalType.SYSTEM_DEFAULT === typeOfObject1( - new BigInteger("92233720368547758070"))) - assert(DecimalType.SYSTEM_DEFAULT === typeOfObject1( - new java.math.BigDecimal("1.7976931348623157E318"))) - assert(StringType === typeOfObject1(BigInt("92233720368547758070"))) - - def typeOfObject2: PartialFunction[Any, DataType] = typeOfObject orElse { - case value: java.math.BigInteger => DecimalType.SYSTEM_DEFAULT - } - - intercept[MatchError](typeOfObject2(BigInt("92233720368547758070"))) - - def typeOfObject3: PartialFunction[Any, DataType] = typeOfObject orElse { - case c: Seq[_] => ArrayType(typeOfObject3(c.head)) - } - - assert(ArrayType(IntegerType) === typeOfObject3(Seq(1, 2, 3))) - assert(ArrayType(ArrayType(IntegerType)) === typeOfObject3(Seq(Seq(1, 2, 3)))) - } - test("convert PrimitiveData to catalyst") { val data = PrimitiveData(1, 1, 1, 1, 1, 1, true) val convertedData = InternalRow(1, 1.toLong, 1.toDouble, 
1.toFloat, 1.toShort, 1.toByte, true) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index cde0364f3dd9d..76459b34a484f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -17,24 +17,234 @@ package org.apache.spark.sql.catalyst.encoders +import java.sql.{Timestamp, Date} import java.util.Arrays import java.util.concurrent.ConcurrentMap +import scala.collection.mutable.ArrayBuffer +import scala.reflect.runtime.universe.TypeTag import com.google.common.collect.MapMaker import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Encoders import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData} import org.apache.spark.sql.types.ArrayType -abstract class ExpressionEncoderSuite extends SparkFunSuite { - val outers: ConcurrentMap[String, AnyRef] = new MapMaker().weakValues().makeMap() +case class RepeatedStruct(s: Seq[PrimitiveData]) - protected def encodeDecodeTest[T]( +case class NestedArray(a: Array[Array[Int]]) { + override def equals(other: Any): Boolean = other match { + case NestedArray(otherArray) => + java.util.Arrays.deepEquals( + a.asInstanceOf[Array[AnyRef]], + otherArray.asInstanceOf[Array[AnyRef]]) + case _ => false + } +} + +case class BoxedData( + intField: java.lang.Integer, + longField: java.lang.Long, + doubleField: java.lang.Double, + floatField: java.lang.Float, + shortField: java.lang.Short, + byteField: java.lang.Byte, + booleanField: java.lang.Boolean) + +case class RepeatedData( + arrayField: Seq[Int], + arrayFieldContainsNull: Seq[java.lang.Integer], + mapField: scala.collection.Map[Int, Long], + mapFieldNull: scala.collection.Map[Int, java.lang.Long], + structField: PrimitiveData) + +case class SpecificCollection(l: List[Int]) + +/** For testing Kryo serialization based encoder. */ +class KryoSerializable(val value: Int) { + override def equals(other: Any): Boolean = { + this.value == other.asInstanceOf[KryoSerializable].value + } +} + +/** For testing Java serialization based encoder. 
*/ +class JavaSerializable(val value: Int) extends Serializable { + override def equals(other: Any): Boolean = { + this.value == other.asInstanceOf[JavaSerializable].value + } +} + +class ExpressionEncoderSuite extends SparkFunSuite { + implicit def encoder[T : TypeTag]: ExpressionEncoder[T] = ExpressionEncoder() + + // test flat encoders + encodeDecodeTest(false, "primitive boolean") + encodeDecodeTest(-3.toByte, "primitive byte") + encodeDecodeTest(-3.toShort, "primitive short") + encodeDecodeTest(-3, "primitive int") + encodeDecodeTest(-3L, "primitive long") + encodeDecodeTest(-3.7f, "primitive float") + encodeDecodeTest(-3.7, "primitive double") + + encodeDecodeTest(new java.lang.Boolean(false), "boxed boolean") + encodeDecodeTest(new java.lang.Byte(-3.toByte), "boxed byte") + encodeDecodeTest(new java.lang.Short(-3.toShort), "boxed short") + encodeDecodeTest(new java.lang.Integer(-3), "boxed int") + encodeDecodeTest(new java.lang.Long(-3L), "boxed long") + encodeDecodeTest(new java.lang.Float(-3.7f), "boxed float") + encodeDecodeTest(new java.lang.Double(-3.7), "boxed double") + + encodeDecodeTest(BigDecimal("32131413.211321313"), "scala decimal") + // encodeDecodeTest(new java.math.BigDecimal("231341.23123"), "java decimal") + + encodeDecodeTest("hello", "string") + encodeDecodeTest(Date.valueOf("2012-12-23"), "date") + encodeDecodeTest(Timestamp.valueOf("2016-01-29 10:00:00"), "timestamp") + encodeDecodeTest(Array[Byte](13, 21, -23), "binary") + + encodeDecodeTest(Seq(31, -123, 4), "seq of int") + encodeDecodeTest(Seq("abc", "xyz"), "seq of string") + encodeDecodeTest(Seq("abc", null, "xyz"), "seq of string with null") + encodeDecodeTest(Seq.empty[Int], "empty seq of int") + encodeDecodeTest(Seq.empty[String], "empty seq of string") + + encodeDecodeTest(Seq(Seq(31, -123), null, Seq(4, 67)), "seq of seq of int") + encodeDecodeTest(Seq(Seq("abc", "xyz"), Seq[String](null), null, Seq("1", null, "2")), + "seq of seq of string") + + encodeDecodeTest(Array(31, -123, 4), "array of int") + encodeDecodeTest(Array("abc", "xyz"), "array of string") + encodeDecodeTest(Array("a", null, "x"), "array of string with null") + encodeDecodeTest(Array.empty[Int], "empty array of int") + encodeDecodeTest(Array.empty[String], "empty array of string") + + encodeDecodeTest(Array(Array(31, -123), null, Array(4, 67)), "array of array of int") + encodeDecodeTest(Array(Array("abc", "xyz"), Array[String](null), null, Array("1", null, "2")), + "array of array of string") + + encodeDecodeTest(Map(1 -> "a", 2 -> "b"), "map") + encodeDecodeTest(Map(1 -> "a", 2 -> null), "map with null") + encodeDecodeTest(Map(1 -> Map("a" -> 1), 2 -> Map("b" -> 2)), "map of map") + + // Kryo encoders + encodeDecodeTest("hello", "kryo string")(encoderFor(Encoders.kryo[String])) + encodeDecodeTest(new KryoSerializable(15), "kryo object")( + encoderFor(Encoders.kryo[KryoSerializable])) + + // Java encoders + encodeDecodeTest("hello", "java string")(encoderFor(Encoders.javaSerialization[String])) + encodeDecodeTest(new JavaSerializable(15), "java object")( + encoderFor(Encoders.javaSerialization[JavaSerializable])) + + // test product encoders + private def productTest[T <: Product : ExpressionEncoder](input: T): Unit = { + encodeDecodeTest(input, input.getClass.getSimpleName) + } + + case class InnerClass(i: Int) + productTest(InnerClass(1)) + + productTest(PrimitiveData(1, 1, 1, 1, 1, 1, true)) + + productTest( + OptionalData(Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true), + Some(PrimitiveData(1, 1, 1, 1, 1, 1, 
true)))) + + productTest(OptionalData(None, None, None, None, None, None, None, None)) + + productTest(BoxedData(1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true)) + + productTest(BoxedData(null, null, null, null, null, null, null)) + + productTest(RepeatedStruct(PrimitiveData(1, 1, 1, 1, 1, 1, true) :: Nil)) + + productTest((1, "test", PrimitiveData(1, 1, 1, 1, 1, 1, true))) + + productTest( + RepeatedData( + Seq(1, 2), + Seq(new Integer(1), null, new Integer(2)), + Map(1 -> 2L), + Map(1 -> null), + PrimitiveData(1, 1, 1, 1, 1, 1, true))) + + productTest(NestedArray(Array(Array(1, -2, 3), null, Array(4, 5, -6)))) + + productTest(("Seq[(String, String)]", + Seq(("a", "b")))) + productTest(("Seq[(Int, Int)]", + Seq((1, 2)))) + productTest(("Seq[(Long, Long)]", + Seq((1L, 2L)))) + productTest(("Seq[(Float, Float)]", + Seq((1.toFloat, 2.toFloat)))) + productTest(("Seq[(Double, Double)]", + Seq((1.toDouble, 2.toDouble)))) + productTest(("Seq[(Short, Short)]", + Seq((1.toShort, 2.toShort)))) + productTest(("Seq[(Byte, Byte)]", + Seq((1.toByte, 2.toByte)))) + productTest(("Seq[(Boolean, Boolean)]", + Seq((true, false)))) + + productTest(("ArrayBuffer[(String, String)]", + ArrayBuffer(("a", "b")))) + productTest(("ArrayBuffer[(Int, Int)]", + ArrayBuffer((1, 2)))) + productTest(("ArrayBuffer[(Long, Long)]", + ArrayBuffer((1L, 2L)))) + productTest(("ArrayBuffer[(Float, Float)]", + ArrayBuffer((1.toFloat, 2.toFloat)))) + productTest(("ArrayBuffer[(Double, Double)]", + ArrayBuffer((1.toDouble, 2.toDouble)))) + productTest(("ArrayBuffer[(Short, Short)]", + ArrayBuffer((1.toShort, 2.toShort)))) + productTest(("ArrayBuffer[(Byte, Byte)]", + ArrayBuffer((1.toByte, 2.toByte)))) + productTest(("ArrayBuffer[(Boolean, Boolean)]", + ArrayBuffer((true, false)))) + + productTest(("Seq[Seq[(Int, Int)]]", + Seq(Seq((1, 2))))) + + // test for ExpressionEncoder.tuple + encodeDecodeTest( + 1 -> 10L, + "tuple with 2 flat encoders")( + ExpressionEncoder.tuple(ExpressionEncoder[Int], ExpressionEncoder[Long])) + + encodeDecodeTest( + (PrimitiveData(1, 1, 1, 1, 1, 1, true), (3, 30L)), + "tuple with 2 product encoders")( + ExpressionEncoder.tuple(ExpressionEncoder[PrimitiveData], ExpressionEncoder[(Int, Long)])) + + encodeDecodeTest( + (PrimitiveData(1, 1, 1, 1, 1, 1, true), 3), + "tuple with flat encoder and product encoder")( + ExpressionEncoder.tuple(ExpressionEncoder[PrimitiveData], ExpressionEncoder[Int])) + + encodeDecodeTest( + (3, PrimitiveData(1, 1, 1, 1, 1, 1, true)), + "tuple with product encoder and flat encoder")( + ExpressionEncoder.tuple(ExpressionEncoder[Int], ExpressionEncoder[PrimitiveData])) + + encodeDecodeTest( + (1, (10, 100L)), + "nested tuple encoder") { + val intEnc = ExpressionEncoder[Int] + val longEnc = ExpressionEncoder[Long] + ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc)) + } + + private val outers: ConcurrentMap[String, AnyRef] = new MapMaker().weakValues().makeMap() + outers.put(getClass.getName, this) + private def encodeDecodeTest[T : ExpressionEncoder]( input: T, - encoder: ExpressionEncoder[T], testName: String): Unit = { test(s"encode/decode for $testName: $input") { + val encoder = implicitly[ExpressionEncoder[T]] val row = encoder.toRow(input) val schema = encoder.schema.toAttributes val boundEncoder = encoder.resolve(schema, outers).bind(schema) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala deleted file mode 
100644 index 07523d49f4266..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.encoders - -import java.sql.{Date, Timestamp} -import org.apache.spark.sql.Encoders - -class FlatEncoderSuite extends ExpressionEncoderSuite { - encodeDecodeTest(false, FlatEncoder[Boolean], "primitive boolean") - encodeDecodeTest(-3.toByte, FlatEncoder[Byte], "primitive byte") - encodeDecodeTest(-3.toShort, FlatEncoder[Short], "primitive short") - encodeDecodeTest(-3, FlatEncoder[Int], "primitive int") - encodeDecodeTest(-3L, FlatEncoder[Long], "primitive long") - encodeDecodeTest(-3.7f, FlatEncoder[Float], "primitive float") - encodeDecodeTest(-3.7, FlatEncoder[Double], "primitive double") - - encodeDecodeTest(new java.lang.Boolean(false), FlatEncoder[java.lang.Boolean], "boxed boolean") - encodeDecodeTest(new java.lang.Byte(-3.toByte), FlatEncoder[java.lang.Byte], "boxed byte") - encodeDecodeTest(new java.lang.Short(-3.toShort), FlatEncoder[java.lang.Short], "boxed short") - encodeDecodeTest(new java.lang.Integer(-3), FlatEncoder[java.lang.Integer], "boxed int") - encodeDecodeTest(new java.lang.Long(-3L), FlatEncoder[java.lang.Long], "boxed long") - encodeDecodeTest(new java.lang.Float(-3.7f), FlatEncoder[java.lang.Float], "boxed float") - encodeDecodeTest(new java.lang.Double(-3.7), FlatEncoder[java.lang.Double], "boxed double") - - encodeDecodeTest(BigDecimal("32131413.211321313"), FlatEncoder[BigDecimal], "scala decimal") - type JDecimal = java.math.BigDecimal - // encodeDecodeTest(new JDecimal("231341.23123"), FlatEncoder[JDecimal], "java decimal") - - encodeDecodeTest("hello", FlatEncoder[String], "string") - encodeDecodeTest(Date.valueOf("2012-12-23"), FlatEncoder[Date], "date") - encodeDecodeTest(Timestamp.valueOf("2016-01-29 10:00:00"), FlatEncoder[Timestamp], "timestamp") - encodeDecodeTest(Array[Byte](13, 21, -23), FlatEncoder[Array[Byte]], "binary") - - encodeDecodeTest(Seq(31, -123, 4), FlatEncoder[Seq[Int]], "seq of int") - encodeDecodeTest(Seq("abc", "xyz"), FlatEncoder[Seq[String]], "seq of string") - encodeDecodeTest(Seq("abc", null, "xyz"), FlatEncoder[Seq[String]], "seq of string with null") - encodeDecodeTest(Seq.empty[Int], FlatEncoder[Seq[Int]], "empty seq of int") - encodeDecodeTest(Seq.empty[String], FlatEncoder[Seq[String]], "empty seq of string") - - encodeDecodeTest(Seq(Seq(31, -123), null, Seq(4, 67)), - FlatEncoder[Seq[Seq[Int]]], "seq of seq of int") - encodeDecodeTest(Seq(Seq("abc", "xyz"), Seq[String](null), null, Seq("1", null, "2")), - FlatEncoder[Seq[Seq[String]]], "seq of seq of string") - - encodeDecodeTest(Array(31, -123, 4), 
FlatEncoder[Array[Int]], "array of int") - encodeDecodeTest(Array("abc", "xyz"), FlatEncoder[Array[String]], "array of string") - encodeDecodeTest(Array("a", null, "x"), FlatEncoder[Array[String]], "array of string with null") - encodeDecodeTest(Array.empty[Int], FlatEncoder[Array[Int]], "empty array of int") - encodeDecodeTest(Array.empty[String], FlatEncoder[Array[String]], "empty array of string") - - encodeDecodeTest(Array(Array(31, -123), null, Array(4, 67)), - FlatEncoder[Array[Array[Int]]], "array of array of int") - encodeDecodeTest(Array(Array("abc", "xyz"), Array[String](null), null, Array("1", null, "2")), - FlatEncoder[Array[Array[String]]], "array of array of string") - - encodeDecodeTest(Map(1 -> "a", 2 -> "b"), FlatEncoder[Map[Int, String]], "map") - encodeDecodeTest(Map(1 -> "a", 2 -> null), FlatEncoder[Map[Int, String]], "map with null") - encodeDecodeTest(Map(1 -> Map("a" -> 1), 2 -> Map("b" -> 2)), - FlatEncoder[Map[Int, Map[String, Int]]], "map of map") - - // Kryo encoders - encodeDecodeTest("hello", encoderFor(Encoders.kryo[String]), "kryo string") - encodeDecodeTest(new KryoSerializable(15), - encoderFor(Encoders.kryo[KryoSerializable]), "kryo object") - - // Java encoders - encodeDecodeTest("hello", encoderFor(Encoders.javaSerialization[String]), "java string") - encodeDecodeTest(new JavaSerializable(15), - encoderFor(Encoders.javaSerialization[JavaSerializable]), "java object") -} - -/** For testing Kryo serialization based encoder. */ -class KryoSerializable(val value: Int) { - override def equals(other: Any): Boolean = { - this.value == other.asInstanceOf[KryoSerializable].value - } -} - -/** For testing Java serialization based encoder. */ -class JavaSerializable(val value: Int) extends Serializable { - override def equals(other: Any): Boolean = { - this.value == other.asInstanceOf[JavaSerializable].value - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala deleted file mode 100644 index 1798514c5c38b..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala +++ /dev/null @@ -1,156 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.encoders - -import scala.collection.mutable.ArrayBuffer -import scala.reflect.runtime.universe.TypeTag - -import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData} - -case class RepeatedStruct(s: Seq[PrimitiveData]) - -case class NestedArray(a: Array[Array[Int]]) { - override def equals(other: Any): Boolean = other match { - case NestedArray(otherArray) => - java.util.Arrays.deepEquals( - a.asInstanceOf[Array[AnyRef]], - otherArray.asInstanceOf[Array[AnyRef]]) - case _ => false - } -} - -case class BoxedData( - intField: java.lang.Integer, - longField: java.lang.Long, - doubleField: java.lang.Double, - floatField: java.lang.Float, - shortField: java.lang.Short, - byteField: java.lang.Byte, - booleanField: java.lang.Boolean) - -case class RepeatedData( - arrayField: Seq[Int], - arrayFieldContainsNull: Seq[java.lang.Integer], - mapField: scala.collection.Map[Int, Long], - mapFieldNull: scala.collection.Map[Int, java.lang.Long], - structField: PrimitiveData) - -case class SpecificCollection(l: List[Int]) - -class ProductEncoderSuite extends ExpressionEncoderSuite { - outers.put(getClass.getName, this) - - case class InnerClass(i: Int) - productTest(InnerClass(1)) - - productTest(PrimitiveData(1, 1, 1, 1, 1, 1, true)) - - productTest( - OptionalData(Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true), - Some(PrimitiveData(1, 1, 1, 1, 1, 1, true)))) - - productTest(OptionalData(None, None, None, None, None, None, None, None)) - - productTest(BoxedData(1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true)) - - productTest(BoxedData(null, null, null, null, null, null, null)) - - productTest(RepeatedStruct(PrimitiveData(1, 1, 1, 1, 1, 1, true) :: Nil)) - - productTest((1, "test", PrimitiveData(1, 1, 1, 1, 1, 1, true))) - - productTest( - RepeatedData( - Seq(1, 2), - Seq(new Integer(1), null, new Integer(2)), - Map(1 -> 2L), - Map(1 -> null), - PrimitiveData(1, 1, 1, 1, 1, 1, true))) - - productTest(NestedArray(Array(Array(1, -2, 3), null, Array(4, 5, -6)))) - - productTest(("Seq[(String, String)]", - Seq(("a", "b")))) - productTest(("Seq[(Int, Int)]", - Seq((1, 2)))) - productTest(("Seq[(Long, Long)]", - Seq((1L, 2L)))) - productTest(("Seq[(Float, Float)]", - Seq((1.toFloat, 2.toFloat)))) - productTest(("Seq[(Double, Double)]", - Seq((1.toDouble, 2.toDouble)))) - productTest(("Seq[(Short, Short)]", - Seq((1.toShort, 2.toShort)))) - productTest(("Seq[(Byte, Byte)]", - Seq((1.toByte, 2.toByte)))) - productTest(("Seq[(Boolean, Boolean)]", - Seq((true, false)))) - - productTest(("ArrayBuffer[(String, String)]", - ArrayBuffer(("a", "b")))) - productTest(("ArrayBuffer[(Int, Int)]", - ArrayBuffer((1, 2)))) - productTest(("ArrayBuffer[(Long, Long)]", - ArrayBuffer((1L, 2L)))) - productTest(("ArrayBuffer[(Float, Float)]", - ArrayBuffer((1.toFloat, 2.toFloat)))) - productTest(("ArrayBuffer[(Double, Double)]", - ArrayBuffer((1.toDouble, 2.toDouble)))) - productTest(("ArrayBuffer[(Short, Short)]", - ArrayBuffer((1.toShort, 2.toShort)))) - productTest(("ArrayBuffer[(Byte, Byte)]", - ArrayBuffer((1.toByte, 2.toByte)))) - productTest(("ArrayBuffer[(Boolean, Boolean)]", - ArrayBuffer((true, false)))) - - productTest(("Seq[Seq[(Int, Int)]]", - Seq(Seq((1, 2))))) - - encodeDecodeTest( - 1 -> 10L, - ExpressionEncoder.tuple(FlatEncoder[Int], FlatEncoder[Long]), - "tuple with 2 flat encoders") - - encodeDecodeTest( - (PrimitiveData(1, 1, 1, 1, 1, 1, true), (3, 30L)), - ExpressionEncoder.tuple(ProductEncoder[PrimitiveData], ProductEncoder[(Int, Long)]), - 
"tuple with 2 product encoders") - - encodeDecodeTest( - (PrimitiveData(1, 1, 1, 1, 1, 1, true), 3), - ExpressionEncoder.tuple(ProductEncoder[PrimitiveData], FlatEncoder[Int]), - "tuple with flat encoder and product encoder") - - encodeDecodeTest( - (3, PrimitiveData(1, 1, 1, 1, 1, 1, true)), - ExpressionEncoder.tuple(FlatEncoder[Int], ProductEncoder[PrimitiveData]), - "tuple with product encoder and flat encoder") - - encodeDecodeTest( - (1, (10, 100L)), - { - val intEnc = FlatEncoder[Int] - val longEnc = FlatEncoder[Long] - ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc)) - }, - "nested tuple encoder") - - private def productTest[T <: Product : TypeTag](input: T): Unit = { - encodeDecodeTest(input, ProductEncoder[T], input.getClass.getSimpleName) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 7e5acbe8517d1..6de3dd626576a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -21,7 +21,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.function._ -import org.apache.spark.sql.catalyst.encoders.{FlatEncoder, ExpressionEncoder, encoderFor, OuterScopes} +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, OuterScopes} import org.apache.spark.sql.catalyst.expressions.{Alias, CreateStruct, Attribute} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution @@ -242,7 +242,7 @@ class GroupedDataset[K, T] private[sql]( * Returns a [[Dataset]] that contains a tuple with each key and the number of items present * for that key. */ - def count(): Dataset[(K, Long)] = agg(functions.count("*").as(FlatEncoder[Long])) + def count(): Dataset[(K, Long)] = agg(functions.count("*").as(ExpressionEncoder[Long])) /** * Applies the given function to each cogrouped data. 
For each unique group, the function will diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index 8471eea1b7d9c..25ffdcde17717 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -17,10 +17,6 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.encoders._ -import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.execution.datasources.LogicalRelation - import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag @@ -28,6 +24,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.types._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.types.StructField import org.apache.spark.unsafe.types.UTF8String @@ -37,16 +34,16 @@ import org.apache.spark.unsafe.types.UTF8String abstract class SQLImplicits { protected def _sqlContext: SQLContext - implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ProductEncoder[T] + implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder() - implicit def newIntEncoder: Encoder[Int] = FlatEncoder[Int] - implicit def newLongEncoder: Encoder[Long] = FlatEncoder[Long] - implicit def newDoubleEncoder: Encoder[Double] = FlatEncoder[Double] - implicit def newFloatEncoder: Encoder[Float] = FlatEncoder[Float] - implicit def newByteEncoder: Encoder[Byte] = FlatEncoder[Byte] - implicit def newShortEncoder: Encoder[Short] = FlatEncoder[Short] - implicit def newBooleanEncoder: Encoder[Boolean] = FlatEncoder[Boolean] - implicit def newStringEncoder: Encoder[String] = FlatEncoder[String] + implicit def newIntEncoder: Encoder[Int] = ExpressionEncoder() + implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder() + implicit def newDoubleEncoder: Encoder[Double] = ExpressionEncoder() + implicit def newFloatEncoder: Encoder[Float] = ExpressionEncoder() + implicit def newByteEncoder: Encoder[Byte] = ExpressionEncoder() + implicit def newShortEncoder: Encoder[Short] = ExpressionEncoder() + implicit def newBooleanEncoder: Encoder[Boolean] = ExpressionEncoder() + implicit def newStringEncoder: Encoder[String] = ExpressionEncoder() /** * Creates a [[Dataset]] from an RDD. 
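For context, a minimal usage sketch of the implicits above, assuming a Spark 1.6-era SQLContext named sqlContext and the usual DatasetHolder conversions brought in by `import sqlContext.implicits._`; the Person case class and value names are illustrative only and are not part of this patch. The point is simply that both the product and the primitive implicits now route through ExpressionEncoder():

  import sqlContext.implicits._   // brings newProductEncoder, newIntEncoder, etc. into scope

  case class Person(name: String, age: Int)   // hypothetical case class for the example

  // Resolved via newProductEncoder, i.e. ExpressionEncoder() with flat = false
  val people = Seq(Person("a", 1), Person("b", 2)).toDS()

  // Resolved via newIntEncoder, i.e. ExpressionEncoder() with flat = true
  val ints = Seq(1, 2, 3).toDS()

Either way the encoder is derived reflectively inside ExpressionEncoder.apply, which computes the flat flag from whether the class is a Product; that is why the call-site distinction between FlatEncoder and ProductEncoder could be dropped in this patch.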
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 95158de710acf..b27b1340cce46 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -26,7 +26,7 @@ import scala.util.Try import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection} import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star} -import org.apache.spark.sql.catalyst.encoders.FlatEncoder +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint @@ -267,7 +267,7 @@ object functions extends LegacyFunctions { * @since 1.3.0 */ def count(columnName: String): TypedColumn[Any, Long] = - count(Column(columnName)).as(FlatEncoder[Long]) + count(Column(columnName)).as(ExpressionEncoder[Long]) /** * Aggregate function: returns the number of distinct items in a group. From 4700074530d9a398843e13f0ef514be97a237cea Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Thu, 19 Nov 2015 13:08:01 -0800 Subject: [PATCH 137/149] [SPARK-11778][SQL] parse table name before it is passed to lookupRelation Fix a bug in DataFrameReader.table (table with schema name such as "db_name.table" doesn't work) Use SqlParser.parseTableIdentifier to parse the table name before lookupRelation. Author: Huaxin Gao Closes #9773 from huaxingao/spark-11778. --- .../scala/org/apache/spark/sql/DataFrameReader.scala | 3 ++- .../spark/sql/hive/HiveDataFrameAnalyticsSuite.scala | 10 ++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 5872fbded3833..dcb3737b70fbf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -313,7 +313,8 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * @since 1.4.0 */ def table(tableName: String): DataFrame = { - DataFrame(sqlContext, sqlContext.catalog.lookupRelation(TableIdentifier(tableName))) + DataFrame(sqlContext, + sqlContext.catalog.lookupRelation(SqlParser.parseTableIdentifier(tableName))) } /** diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala index 9864acf765265..f19a74d4b3724 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala @@ -34,10 +34,14 @@ class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with override def beforeAll() { testData = Seq((1, 2), (2, 2), (3, 4)).toDF("a", "b") hiveContext.registerDataFrameAsTable(testData, "mytable") + hiveContext.sql("create schema usrdb") + hiveContext.sql("create table usrdb.test(c1 int)") } override def afterAll(): Unit = { hiveContext.dropTempTable("mytable") + hiveContext.sql("drop table usrdb.test") + hiveContext.sql("drop schema usrdb") } test("rollup") { @@ -74,4 +78,10 @@ class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with sql("select a, b, sum(b) from mytable group 
by a, b with cube").collect() ) } + + // There was a bug in DataFrameFrameReader.table and it has problem for table with schema name, + // Before fix, it throw Exceptionorg.apache.spark.sql.catalyst.analysis.NoSuchTableException + test("table name with schema") { + hiveContext.read.table("usrdb.test") + } } From 599a8c6e2bf7da70b20ef3046f5ce099dfd637f8 Mon Sep 17 00:00:00 2001 From: David Tolpin Date: Thu, 19 Nov 2015 13:57:23 -0800 Subject: [PATCH 138/149] [SPARK-11812][PYSPARK] invFunc=None works properly with python's reduceByKeyAndWindow invFunc is optional and can be None. Instead of invFunc (the parameter) invReduceFunc (a local function) was checked for trueness (that is, not None, in this context). A local function is never None, thus the case of invFunc=None (a common one when inverse reduction is not defined) was treated incorrectly, resulting in loss of data. In addition, the docstring used wrong parameter names, also fixed. Author: David Tolpin Closes #9775 from dtolpin/master. --- python/pyspark/streaming/dstream.py | 6 +++--- python/pyspark/streaming/tests.py | 11 +++++++++++ 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index 698336cfce18d..acec850f02c2d 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -524,8 +524,8 @@ def reduceByKeyAndWindow(self, func, invFunc, windowDuration, slideDuration=None `invFunc` can be None, then it will reduce all the RDDs in window, could be slower than having `invFunc`. - @param reduceFunc: associative reduce function - @param invReduceFunc: inverse function of `reduceFunc` + @param func: associative reduce function + @param invFunc: inverse function of `reduceFunc` @param windowDuration: width of the window; must be a multiple of this DStream's batching interval @param slideDuration: sliding interval of the window (i.e., the interval after which @@ -556,7 +556,7 @@ def invReduceFunc(t, a, b): if kv[1] is not None else kv[0]) jreduceFunc = TransformFunction(self._sc, reduceFunc, reduced._jrdd_deserializer) - if invReduceFunc: + if invFunc: jinvReduceFunc = TransformFunction(self._sc, invReduceFunc, reduced._jrdd_deserializer) else: jinvReduceFunc = None diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 0bcd1f15532b5..3403f6d20d789 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -582,6 +582,17 @@ def test_reduce_by_invalid_window(self): self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 0.1, 0.1)) self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 1, 0.1)) + def test_reduce_by_key_and_window_with_none_invFunc(self): + input = [range(1), range(2), range(3), range(4), range(5), range(6)] + + def func(dstream): + return dstream.map(lambda x: (x, 1))\ + .reduceByKeyAndWindow(operator.add, None, 5, 1)\ + .filter(lambda kv: kv[1] > 0).count() + + expected = [[2], [4], [6], [6], [6], [6]] + self._test_func(input, func, expected) + class StreamingContextTests(PySparkStreamingTestCase): From 8a24eb41e34953196a2ff3fa164d82c814bb3a79 Mon Sep 17 00:00:00 2001 From: Dmitry Erastov Date: Fri, 13 Nov 2015 11:53:16 -0500 Subject: [PATCH 139/149] [SPARK-6990] Remove whitespace after '>' Closing '>' and method name shouldn't have whitespace between them, according to http://checkstyle.sourceforge.net/config_whitespace.html#GenericWhitespace --- 
.../org/apache/spark/streaming/JavaTrackStateByKeySuite.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java index eac4cdd14a683..89d0bb7b617e4 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java @@ -95,7 +95,7 @@ public Double call(Optional one, State state) { JavaTrackStateDStream stateDstream2 = wordsDstream.trackStateByKey( - StateSpec. function(trackStateFunc2) + StateSpec.function(trackStateFunc2) .initialState(initialRDD) .numPartitions(10) .partitioner(new HashPartitioner(10)) From 1df0b862b7d8cd6223f2dd752e9925bf9a4014a5 Mon Sep 17 00:00:00 2001 From: Dmitry Erastov Date: Tue, 10 Nov 2015 11:56:46 -0500 Subject: [PATCH 140/149] [SPARK-6990] Add Java linting script Invoke Checkstyle and print any errors to the console, failing the step. Use Google's style rules modified according to https://cwiki.apache.org/confluence/display/SPARK/Spark+Code+Style+Guide Some important checks are disabled (see TODOs in checkstyle.xml) due to multiple violations being present in the codebase. --- checkstyle.xml | 159 +++++++++++++++++++++++++++++++ dev/lint-java | 34 +++++++ dev/run-tests-jenkins.py | 1 + dev/run-tests.py | 7 ++ dev/sparktestsupport/__init__.py | 1 + pom.xml | 24 +++++ 6 files changed, 226 insertions(+) create mode 100644 checkstyle.xml create mode 100755 dev/lint-java diff --git a/checkstyle.xml b/checkstyle.xml new file mode 100644 index 0000000000000..403c5356f7dd8 --- /dev/null +++ b/checkstyle.xml @@ -0,0 +1,159 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/dev/lint-java b/dev/lint-java new file mode 100755 index 0000000000000..55b6b0c348416 --- /dev/null +++ b/dev/lint-java @@ -0,0 +1,34 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" +SPARK_ROOT_DIR="$(dirname $SCRIPT_DIR)" + +$SCRIPT_DIR/../build/mvn -Pkinesis-asl -Phive -Phive-thriftserver checkstyle:check > checkstyle.txt +$SCRIPT_DIR/../build/mvn -Pkinesis-asl -Pyarn -Phadoop-2.2 checkstyle:check >> checkstyle.txt + +ERRORS=$(grep ERROR checkstyle.txt) +rm checkstyle.txt + +if test ! -z "$ERRORS"; then + echo -e "Checkstyle checks failed at following occurrences:\n$ERRORS" + exit 1 +else + echo -e "Checkstyle checks passed." 
+fi diff --git a/dev/run-tests-jenkins.py b/dev/run-tests-jenkins.py index 623004310e189..622e5b6a17655 100755 --- a/dev/run-tests-jenkins.py +++ b/dev/run-tests-jenkins.py @@ -119,6 +119,7 @@ def run_tests(tests_timeout): ERROR_CODES["BLOCK_GENERAL"]: 'some tests', ERROR_CODES["BLOCK_RAT"]: 'RAT tests', ERROR_CODES["BLOCK_SCALA_STYLE"]: 'Scala style tests', + ERROR_CODES["BLOCK_JAVA_STYLE"]: 'Java style tests', ERROR_CODES["BLOCK_PYTHON_STYLE"]: 'Python style tests', ERROR_CODES["BLOCK_R_STYLE"]: 'R style tests', ERROR_CODES["BLOCK_DOCUMENTATION"]: 'to generate documentation', diff --git a/dev/run-tests.py b/dev/run-tests.py index 9e1abb0697192..e7e10f1d8c725 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -198,6 +198,11 @@ def run_scala_style_checks(): run_cmd([os.path.join(SPARK_HOME, "dev", "lint-scala")]) +def run_java_style_checks(): + set_title_and_block("Running Java style checks", "BLOCK_JAVA_STYLE") + run_cmd([os.path.join(SPARK_HOME, "dev", "lint-java")]) + + def run_python_style_checks(): set_title_and_block("Running Python style checks", "BLOCK_PYTHON_STYLE") run_cmd([os.path.join(SPARK_HOME, "dev", "lint-python")]) @@ -522,6 +527,8 @@ def main(): # style checks if not changed_files or any(f.endswith(".scala") for f in changed_files): run_scala_style_checks() + if not changed_files or any(f.endswith(".java") for f in changed_files): + run_java_style_checks() if not changed_files or any(f.endswith(".py") for f in changed_files): run_python_style_checks() if not changed_files or any(f.endswith(".R") for f in changed_files): diff --git a/dev/sparktestsupport/__init__.py b/dev/sparktestsupport/__init__.py index 8ab6d9e37ca2f..0e8032d13341e 100644 --- a/dev/sparktestsupport/__init__.py +++ b/dev/sparktestsupport/__init__.py @@ -31,5 +31,6 @@ "BLOCK_SPARK_UNIT_TESTS": 18, "BLOCK_PYSPARK_UNIT_TESTS": 19, "BLOCK_SPARKR_UNIT_TESTS": 20, + "BLOCK_JAVA_STYLE": 21, "BLOCK_TIMEOUT": 124 } diff --git a/pom.xml b/pom.xml index ad849112ce76c..cf9bdefed6993 100644 --- a/pom.xml +++ b/pom.xml @@ -2258,6 +2258,30 @@
    + + org.apache.maven.plugins + maven-checkstyle-plugin + 2.17 + + false + false + false + false + ${basedir}/src/main/java + ${basedir}/src/test/java + checkstyle.xml + ${basedir}/target/checkstyle-output.xml + ${project.build.sourceEncoding} + ${project.reporting.outputEncoding} + + + + + check + + + + org.apache.maven.plugins From 220dcb9869c4d5ae0743ae521e44304a877ecf61 Mon Sep 17 00:00:00 2001 From: Dmitry Erastov Date: Tue, 10 Nov 2015 11:57:20 -0500 Subject: [PATCH 141/149] [SPARK-6990] Fix some Checkstyle warnings --- .../apache/spark/util/collection/TimSort.java | 118 ++++++++++++------ .../unsafe/sort/UnsafeInMemorySorter.java | 2 +- .../examples/ml/JavaSimpleParamsExample.java | 2 +- .../spark/examples/mllib/JavaLDAExample.java | 3 +- .../mllib/JavaRecommendationExample.java | 2 +- .../streaming/JavaSqlNetworkWordCount.java | 4 +- .../shuffle/ExternalShuffleBlockResolver.java | 2 +- .../execution/UnsafeExternalRowSorter.java | 2 +- .../spark/sql/types/SQLUserDefinedType.java | 2 +- .../spark/streaming/util/WriteAheadLog.java | 10 +- .../apache/spark/tags/ExtendedHiveTest.java | 1 + .../apache/spark/tags/ExtendedYarnTest.java | 1 + .../apache/spark/unsafe/types/UTF8String.java | 8 +- 13 files changed, 98 insertions(+), 59 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/TimSort.java b/core/src/main/java/org/apache/spark/util/collection/TimSort.java index 40b5fb7fe4b49..0d476dc0748fd 100644 --- a/core/src/main/java/org/apache/spark/util/collection/TimSort.java +++ b/core/src/main/java/org/apache/spark/util/collection/TimSort.java @@ -120,8 +120,9 @@ public void sort(Buffer a, int lo, int hi, Comparator c) { assert c != null; int nRemaining = hi - lo; - if (nRemaining < 2) + if (nRemaining < 2) { return; // Arrays of size 0 and 1 are always sorted + } // If array is small, do a "mini-TimSort" with no merges if (nRemaining < MIN_MERGE) { @@ -184,8 +185,9 @@ public void sort(Buffer a, int lo, int hi, Comparator c) { @SuppressWarnings("fallthrough") private void binarySort(Buffer a, int lo, int hi, int start, Comparator c) { assert lo <= start && start <= hi; - if (start == lo) + if (start == lo) { start++; + } K key0 = s.newKey(); K key1 = s.newKey(); @@ -206,10 +208,11 @@ private void binarySort(Buffer a, int lo, int hi, int start, Comparator>> 1; - if (c.compare(pivot, s.getKey(a, mid, key1)) < 0) + if (c.compare(pivot, s.getKey(a, mid, key1)) < 0) { right = mid; - else + } else { left = mid + 1; + } } assert left == right; @@ -260,20 +263,23 @@ private void binarySort(Buffer a, int lo, int hi, int start, Comparator c) { assert lo < hi; int runHi = lo + 1; - if (runHi == hi) + if (runHi == hi) { return 1; + } K key0 = s.newKey(); K key1 = s.newKey(); // Find end of run, and reverse range if descending if (c.compare(s.getKey(a, runHi++, key0), s.getKey(a, lo, key1)) < 0) { // Descending - while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) < 0) + while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) < 0) { runHi++; + } reverseRange(a, lo, runHi); } else { // Ascending - while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) >= 0) + while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) >= 0) { runHi++; + } } return runHi - lo; @@ -445,8 +451,9 @@ private void mergeCollapse() { int n = stackSize - 2; if ( (n >= 1 && runLen[n-1] <= runLen[n] + runLen[n+1]) || (n >= 2 && runLen[n-2] <= runLen[n] + runLen[n-1])) { - if 
(runLen[n - 1] < runLen[n + 1]) + if (runLen[n - 1] < runLen[n + 1]) { n--; + } } else if (runLen[n] > runLen[n + 1]) { break; // Invariant is established } @@ -461,8 +468,9 @@ private void mergeCollapse() { private void mergeForceCollapse() { while (stackSize > 1) { int n = stackSize - 2; - if (n > 0 && runLen[n - 1] < runLen[n + 1]) + if (n > 0 && runLen[n - 1] < runLen[n + 1]) { n--; + } mergeAt(n); } } @@ -508,8 +516,9 @@ private void mergeAt(int i) { assert k >= 0; base1 += k; len1 -= k; - if (len1 == 0) + if (len1 == 0) { return; + } /* * Find where the last element of run1 goes in run2. Subsequent elements @@ -517,14 +526,16 @@ private void mergeAt(int i) { */ len2 = gallopLeft(s.getKey(a, base1 + len1 - 1, key0), a, base2, len2, len2 - 1, c); assert len2 >= 0; - if (len2 == 0) + if (len2 == 0) { return; + } // Merge remaining runs, using tmp array with min(len1, len2) elements - if (len1 <= len2) + if (len1 <= len2) { mergeLo(base1, len1, base2, len2); - else + } else { mergeHi(base1, len1, base2, len2); + } } /** @@ -557,11 +568,13 @@ private int gallopLeft(K key, Buffer a, int base, int len, int hint, Comparator< while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs, key0)) > 0) { lastOfs = ofs; ofs = (ofs << 1) + 1; - if (ofs <= 0) // int overflow + if (ofs <= 0) { // int overflow ofs = maxOfs; + } } - if (ofs > maxOfs) + if (ofs > maxOfs) { ofs = maxOfs; + } // Make offsets relative to base lastOfs += hint; @@ -572,11 +585,13 @@ private int gallopLeft(K key, Buffer a, int base, int len, int hint, Comparator< while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs, key0)) <= 0) { lastOfs = ofs; ofs = (ofs << 1) + 1; - if (ofs <= 0) // int overflow + if (ofs <= 0) { // int overflow ofs = maxOfs; + } } - if (ofs > maxOfs) + if (ofs > maxOfs) { ofs = maxOfs; + } // Make offsets relative to base int tmp = lastOfs; @@ -594,10 +609,11 @@ private int gallopLeft(K key, Buffer a, int base, int len, int hint, Comparator< while (lastOfs < ofs) { int m = lastOfs + ((ofs - lastOfs) >>> 1); - if (c.compare(key, s.getKey(a, base + m, key0)) > 0) + if (c.compare(key, s.getKey(a, base + m, key0)) > 0) { lastOfs = m + 1; // a[base + m] < key - else + } else { ofs = m; // key <= a[base + m] + } } assert lastOfs == ofs; // so a[base + ofs - 1] < key <= a[base + ofs] return ofs; @@ -629,11 +645,13 @@ private int gallopRight(K key, Buffer a, int base, int len, int hint, Comparator while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs, key1)) < 0) { lastOfs = ofs; ofs = (ofs << 1) + 1; - if (ofs <= 0) // int overflow + if (ofs <= 0) { // int overflow ofs = maxOfs; + } } - if (ofs > maxOfs) + if (ofs > maxOfs) { ofs = maxOfs; + } // Make offsets relative to b int tmp = lastOfs; @@ -645,11 +663,13 @@ private int gallopRight(K key, Buffer a, int base, int len, int hint, Comparator while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs, key1)) >= 0) { lastOfs = ofs; ofs = (ofs << 1) + 1; - if (ofs <= 0) // int overflow + if (ofs <= 0) { // int overflow ofs = maxOfs; + } } - if (ofs > maxOfs) + if (ofs > maxOfs) { ofs = maxOfs; + } // Make offsets relative to b lastOfs += hint; @@ -666,10 +686,11 @@ private int gallopRight(K key, Buffer a, int base, int len, int hint, Comparator while (lastOfs < ofs) { int m = lastOfs + ((ofs - lastOfs) >>> 1); - if (c.compare(key, s.getKey(a, base + m, key1)) < 0) + if (c.compare(key, s.getKey(a, base + m, key1)) < 0) { ofs = m; // key < a[b + m] - else + } else { lastOfs = m + 1; // a[b + m] <= key + } } assert lastOfs 
== ofs; // so a[b + ofs - 1] <= key < a[b + ofs] return ofs; @@ -735,14 +756,16 @@ private void mergeLo(int base1, int len1, int base2, int len2) { s.copyElement(a, cursor2++, a, dest++); count2++; count1 = 0; - if (--len2 == 0) + if (--len2 == 0) { break outer; + } } else { s.copyElement(tmp, cursor1++, a, dest++); count1++; count2 = 0; - if (--len1 == 1) + if (--len1 == 1) { break outer; + } } } while ((count1 | count2) < minGallop); @@ -759,12 +782,14 @@ private void mergeLo(int base1, int len1, int base2, int len2) { dest += count1; cursor1 += count1; len1 -= count1; - if (len1 <= 1) // len1 == 1 || len1 == 0 + if (len1 <= 1) { // len1 == 1 || len1 == 0 break outer; + } } s.copyElement(a, cursor2++, a, dest++); - if (--len2 == 0) + if (--len2 == 0) { break outer; + } count2 = gallopLeft(s.getKey(tmp, cursor1, key0), a, cursor2, len2, 0, c); if (count2 != 0) { @@ -772,16 +797,19 @@ private void mergeLo(int base1, int len1, int base2, int len2) { dest += count2; cursor2 += count2; len2 -= count2; - if (len2 == 0) + if (len2 == 0) { break outer; + } } s.copyElement(tmp, cursor1++, a, dest++); - if (--len1 == 1) + if (--len1 == 1) { break outer; + } minGallop--; } while (count1 >= MIN_GALLOP | count2 >= MIN_GALLOP); - if (minGallop < 0) + if (minGallop < 0) { minGallop = 0; + } minGallop += 2; // Penalize for leaving gallop mode } // End of "outer" loop this.minGallop = minGallop < 1 ? 1 : minGallop; // Write back to field @@ -857,14 +885,16 @@ private void mergeHi(int base1, int len1, int base2, int len2) { s.copyElement(a, cursor1--, a, dest--); count1++; count2 = 0; - if (--len1 == 0) + if (--len1 == 0) { break outer; + } } else { s.copyElement(tmp, cursor2--, a, dest--); count2++; count1 = 0; - if (--len2 == 1) + if (--len2 == 1) { break outer; + } } } while ((count1 | count2) < minGallop); @@ -881,12 +911,14 @@ private void mergeHi(int base1, int len1, int base2, int len2) { cursor1 -= count1; len1 -= count1; s.copyRange(a, cursor1 + 1, a, dest + 1, count1); - if (len1 == 0) + if (len1 == 0) { break outer; + } } s.copyElement(tmp, cursor2--, a, dest--); - if (--len2 == 1) + if (--len2 == 1) { break outer; + } count2 = len2 - gallopLeft(s.getKey(a, cursor1, key0), tmp, 0, len2, len2 - 1, c); if (count2 != 0) { @@ -894,16 +926,19 @@ private void mergeHi(int base1, int len1, int base2, int len2) { cursor2 -= count2; len2 -= count2; s.copyRange(tmp, cursor2 + 1, a, dest + 1, count2); - if (len2 <= 1) // len2 == 1 || len2 == 0 + if (len2 <= 1) { // len2 == 1 || len2 == 0 break outer; + } } s.copyElement(a, cursor1--, a, dest--); - if (--len1 == 0) + if (--len1 == 0) { break outer; + } minGallop--; } while (count1 >= MIN_GALLOP | count2 >= MIN_GALLOP); - if (minGallop < 0) + if (minGallop < 0) { minGallop = 0; + } minGallop += 2; // Penalize for leaving gallop mode } // End of "outer" loop this.minGallop = minGallop < 1 ? 1 : minGallop; // Write back to field @@ -943,10 +978,11 @@ private Buffer ensureCapacity(int minCapacity) { newSize |= newSize >> 16; newSize++; - if (newSize < 0) // Not bloody likely! + if (newSize < 0) { // Not bloody likely! 
newSize = minCapacity; - else + } else { newSize = Math.min(newSize, aLength >>> 1); + } tmp = s.allocate(newSize); tmpLength = newSize; diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index a218ad4623f46..9ce5fd2a1c8ef 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -180,7 +180,7 @@ private SortedIterator( this.sortBuffer = sortBuffer; } - public SortedIterator clone () { + public SortedIterator clone() { SortedIterator iter = new SortedIterator(memoryManager, sortBufferInsertPosition, sortBuffer); iter.position = position; iter.baseObject = baseObject; diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java index 94beeced3d479..ea83e8fef9eb9 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java @@ -77,7 +77,7 @@ public static void main(String[] args) { ParamMap paramMap = new ParamMap(); paramMap.put(lr.maxIter().w(20)); // Specify 1 Param. paramMap.put(lr.maxIter(), 30); // This overwrites the original maxIter. - double thresholds[] = {0.45, 0.55}; + double[] thresholds = {0.45, 0.55}; paramMap.put(lr.regParam().w(0.1), lr.thresholds().w(thresholds)); // Specify multiple Params. // One can also combine ParamMaps. diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java index fd53c81cc4974..de8e739ac9256 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java @@ -41,8 +41,9 @@ public static void main(String[] args) { public Vector call(String s) { String[] sarray = s.trim().split(" "); double[] values = new double[sarray.length]; - for (int i = 0; i < sarray.length; i++) + for (int i = 0; i < sarray.length; i++) { values[i] = Double.parseDouble(sarray[i]); + } return Vectors.dense(values); } } diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRecommendationExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRecommendationExample.java index 1065fde953b96..c179e7578cdfa 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRecommendationExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRecommendationExample.java @@ -29,7 +29,7 @@ // $example off$ public class JavaRecommendationExample { - public static void main(String args[]) { + public static void main(String[] args) { // $example on$ SparkConf conf = new SparkConf().setAppName("Java Collaborative Filtering Example"); JavaSparkContext jsc = new JavaSparkContext(conf); diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java index 46562ddbbcb57..3515d7be45d37 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java @@ -112,8 
+112,8 @@ public JavaRecord call(String word) { /** Lazily instantiated singleton instance of SQLContext */ class JavaSQLContextSingleton { - static private transient SQLContext instance = null; - static public SQLContext getInstance(SparkContext sparkContext) { + private static transient SQLContext instance = null; + public static SQLContext getInstance(SparkContext sparkContext) { if (instance == null) { instance = new SQLContext(sparkContext); } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index 0d4dd6afac769..e5cb68c8a4dbb 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -419,7 +419,7 @@ private static void storeVersion(DB db) throws IOException { public static class StoreVersion { - final static byte[] KEY = "StoreVersion".getBytes(Charsets.UTF_8); + static final byte[] KEY = "StoreVersion".getBytes(Charsets.UTF_8); public final int major; public final int minor; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 3986d6e18f770..352002b3499a2 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -51,7 +51,7 @@ final class UnsafeExternalRowSorter { private final PrefixComputer prefixComputer; private final UnsafeExternalSorter sorter; - public static abstract class PrefixComputer { + public abstract static class PrefixComputer { abstract long computePrefix(InternalRow row); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java b/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java index df64a878b6b36..1e4e5ede8cc11 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java @@ -41,5 +41,5 @@ * Returns an instance of the UserDefinedType which can serialize and deserialize the user * class to and from Catalyst built-in types. */ - Class > udt(); + Class> udt(); } diff --git a/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java index 3738fc1a235c2..2803cad8095dd 100644 --- a/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java +++ b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java @@ -37,26 +37,26 @@ public abstract class WriteAheadLog { * ensure that the written data is durable and readable (using the record handle) by the * time this function returns. */ - abstract public WriteAheadLogRecordHandle write(ByteBuffer record, long time); + public abstract WriteAheadLogRecordHandle write(ByteBuffer record, long time); /** * Read a written record based on the given record handle. */ - abstract public ByteBuffer read(WriteAheadLogRecordHandle handle); + public abstract ByteBuffer read(WriteAheadLogRecordHandle handle); /** * Read and return an iterator of all the records that have been written but not yet cleaned up. 
*/ - abstract public Iterator readAll(); + public abstract Iterator readAll(); /** * Clean all the records that are older than the threshold time. It can wait for * the completion of the deletion. */ - abstract public void clean(long threshTime, boolean waitForCompletion); + public abstract void clean(long threshTime, boolean waitForCompletion); /** * Close this log and release any resources. */ - abstract public void close(); + public abstract void close(); } diff --git a/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java b/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java index 1b0c416b0fe4e..83279e5e93c0e 100644 --- a/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java +++ b/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java @@ -18,6 +18,7 @@ package org.apache.spark.tags; import java.lang.annotation.*; + import org.scalatest.TagAnnotation; @TagAnnotation diff --git a/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java b/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java index 2a631bfc88cf0..108300168e173 100644 --- a/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java +++ b/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java @@ -18,6 +18,7 @@ package org.apache.spark.tags; import java.lang.annotation.*; + import org.scalatest.TagAnnotation; @TagAnnotation diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 4bd3fd7772079..5b61386808769 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -900,9 +900,9 @@ public int levenshteinDistance(UTF8String other) { m = swap; } - int p[] = new int[n + 1]; - int d[] = new int[n + 1]; - int swap[]; + int[] p = new int[n + 1]; + int[] d = new int[n + 1]; + int[] swap; int i, i_bytes, j, j_bytes, num_bytes_j, cost; @@ -965,7 +965,7 @@ public UTF8String soundex() { // first character must be a letter return this; } - byte sx[] = {'0', '0', '0', '0'}; + byte[] sx = {'0', '0', '0', '0'}; sx[0] = b; int sxi = 1; int idx = b - 'A'; From 4374ecc1a2114756b259613f2001667e33bb0288 Mon Sep 17 00:00:00 2001 From: Dmitry Erastov Date: Wed, 11 Nov 2015 09:52:19 -0500 Subject: [PATCH 142/149] [SPARK-6990] Suppress Checkstyle for TimSort The code was copied from a third-party source and needs to be in sync with that, so we shouldn't make our own modifications to it. The file contains some style violations, so suppress the checks. 
--- checkstyle-suppressions.xml | 33 +++++ checkstyle.xml | 5 + .../apache/spark/util/collection/TimSort.java | 118 ++++++------------ 3 files changed, 79 insertions(+), 77 deletions(-) create mode 100644 checkstyle-suppressions.xml diff --git a/checkstyle-suppressions.xml b/checkstyle-suppressions.xml new file mode 100644 index 0000000000000..9242be3d0357a --- /dev/null +++ b/checkstyle-suppressions.xml @@ -0,0 +1,33 @@ + + + + + + + + + diff --git a/checkstyle.xml b/checkstyle.xml index 403c5356f7dd8..b273e7be27a11 100644 --- a/checkstyle.xml +++ b/checkstyle.xml @@ -47,6 +47,11 @@ + + + + + diff --git a/core/src/main/java/org/apache/spark/util/collection/TimSort.java b/core/src/main/java/org/apache/spark/util/collection/TimSort.java index 0d476dc0748fd..40b5fb7fe4b49 100644 --- a/core/src/main/java/org/apache/spark/util/collection/TimSort.java +++ b/core/src/main/java/org/apache/spark/util/collection/TimSort.java @@ -120,9 +120,8 @@ public void sort(Buffer a, int lo, int hi, Comparator c) { assert c != null; int nRemaining = hi - lo; - if (nRemaining < 2) { + if (nRemaining < 2) return; // Arrays of size 0 and 1 are always sorted - } // If array is small, do a "mini-TimSort" with no merges if (nRemaining < MIN_MERGE) { @@ -185,9 +184,8 @@ public void sort(Buffer a, int lo, int hi, Comparator c) { @SuppressWarnings("fallthrough") private void binarySort(Buffer a, int lo, int hi, int start, Comparator c) { assert lo <= start && start <= hi; - if (start == lo) { + if (start == lo) start++; - } K key0 = s.newKey(); K key1 = s.newKey(); @@ -208,11 +206,10 @@ private void binarySort(Buffer a, int lo, int hi, int start, Comparator>> 1; - if (c.compare(pivot, s.getKey(a, mid, key1)) < 0) { + if (c.compare(pivot, s.getKey(a, mid, key1)) < 0) right = mid; - } else { + else left = mid + 1; - } } assert left == right; @@ -263,23 +260,20 @@ private void binarySort(Buffer a, int lo, int hi, int start, Comparator c) { assert lo < hi; int runHi = lo + 1; - if (runHi == hi) { + if (runHi == hi) return 1; - } K key0 = s.newKey(); K key1 = s.newKey(); // Find end of run, and reverse range if descending if (c.compare(s.getKey(a, runHi++, key0), s.getKey(a, lo, key1)) < 0) { // Descending - while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) < 0) { + while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) < 0) runHi++; - } reverseRange(a, lo, runHi); } else { // Ascending - while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) >= 0) { + while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) >= 0) runHi++; - } } return runHi - lo; @@ -451,9 +445,8 @@ private void mergeCollapse() { int n = stackSize - 2; if ( (n >= 1 && runLen[n-1] <= runLen[n] + runLen[n+1]) || (n >= 2 && runLen[n-2] <= runLen[n] + runLen[n-1])) { - if (runLen[n - 1] < runLen[n + 1]) { + if (runLen[n - 1] < runLen[n + 1]) n--; - } } else if (runLen[n] > runLen[n + 1]) { break; // Invariant is established } @@ -468,9 +461,8 @@ private void mergeCollapse() { private void mergeForceCollapse() { while (stackSize > 1) { int n = stackSize - 2; - if (n > 0 && runLen[n - 1] < runLen[n + 1]) { + if (n > 0 && runLen[n - 1] < runLen[n + 1]) n--; - } mergeAt(n); } } @@ -516,9 +508,8 @@ private void mergeAt(int i) { assert k >= 0; base1 += k; len1 -= k; - if (len1 == 0) { + if (len1 == 0) return; - } /* * Find where the last element of run1 goes in run2. 
Subsequent elements @@ -526,16 +517,14 @@ private void mergeAt(int i) { */ len2 = gallopLeft(s.getKey(a, base1 + len1 - 1, key0), a, base2, len2, len2 - 1, c); assert len2 >= 0; - if (len2 == 0) { + if (len2 == 0) return; - } // Merge remaining runs, using tmp array with min(len1, len2) elements - if (len1 <= len2) { + if (len1 <= len2) mergeLo(base1, len1, base2, len2); - } else { + else mergeHi(base1, len1, base2, len2); - } } /** @@ -568,13 +557,11 @@ private int gallopLeft(K key, Buffer a, int base, int len, int hint, Comparator< while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs, key0)) > 0) { lastOfs = ofs; ofs = (ofs << 1) + 1; - if (ofs <= 0) { // int overflow + if (ofs <= 0) // int overflow ofs = maxOfs; - } } - if (ofs > maxOfs) { + if (ofs > maxOfs) ofs = maxOfs; - } // Make offsets relative to base lastOfs += hint; @@ -585,13 +572,11 @@ private int gallopLeft(K key, Buffer a, int base, int len, int hint, Comparator< while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs, key0)) <= 0) { lastOfs = ofs; ofs = (ofs << 1) + 1; - if (ofs <= 0) { // int overflow + if (ofs <= 0) // int overflow ofs = maxOfs; - } } - if (ofs > maxOfs) { + if (ofs > maxOfs) ofs = maxOfs; - } // Make offsets relative to base int tmp = lastOfs; @@ -609,11 +594,10 @@ private int gallopLeft(K key, Buffer a, int base, int len, int hint, Comparator< while (lastOfs < ofs) { int m = lastOfs + ((ofs - lastOfs) >>> 1); - if (c.compare(key, s.getKey(a, base + m, key0)) > 0) { + if (c.compare(key, s.getKey(a, base + m, key0)) > 0) lastOfs = m + 1; // a[base + m] < key - } else { + else ofs = m; // key <= a[base + m] - } } assert lastOfs == ofs; // so a[base + ofs - 1] < key <= a[base + ofs] return ofs; @@ -645,13 +629,11 @@ private int gallopRight(K key, Buffer a, int base, int len, int hint, Comparator while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs, key1)) < 0) { lastOfs = ofs; ofs = (ofs << 1) + 1; - if (ofs <= 0) { // int overflow + if (ofs <= 0) // int overflow ofs = maxOfs; - } } - if (ofs > maxOfs) { + if (ofs > maxOfs) ofs = maxOfs; - } // Make offsets relative to b int tmp = lastOfs; @@ -663,13 +645,11 @@ private int gallopRight(K key, Buffer a, int base, int len, int hint, Comparator while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs, key1)) >= 0) { lastOfs = ofs; ofs = (ofs << 1) + 1; - if (ofs <= 0) { // int overflow + if (ofs <= 0) // int overflow ofs = maxOfs; - } } - if (ofs > maxOfs) { + if (ofs > maxOfs) ofs = maxOfs; - } // Make offsets relative to b lastOfs += hint; @@ -686,11 +666,10 @@ private int gallopRight(K key, Buffer a, int base, int len, int hint, Comparator while (lastOfs < ofs) { int m = lastOfs + ((ofs - lastOfs) >>> 1); - if (c.compare(key, s.getKey(a, base + m, key1)) < 0) { + if (c.compare(key, s.getKey(a, base + m, key1)) < 0) ofs = m; // key < a[b + m] - } else { + else lastOfs = m + 1; // a[b + m] <= key - } } assert lastOfs == ofs; // so a[b + ofs - 1] <= key < a[b + ofs] return ofs; @@ -756,16 +735,14 @@ private void mergeLo(int base1, int len1, int base2, int len2) { s.copyElement(a, cursor2++, a, dest++); count2++; count1 = 0; - if (--len2 == 0) { + if (--len2 == 0) break outer; - } } else { s.copyElement(tmp, cursor1++, a, dest++); count1++; count2 = 0; - if (--len1 == 1) { + if (--len1 == 1) break outer; - } } } while ((count1 | count2) < minGallop); @@ -782,14 +759,12 @@ private void mergeLo(int base1, int len1, int base2, int len2) { dest += count1; cursor1 += count1; len1 -= count1; - if (len1 <= 1) { // 
len1 == 1 || len1 == 0 + if (len1 <= 1) // len1 == 1 || len1 == 0 break outer; - } } s.copyElement(a, cursor2++, a, dest++); - if (--len2 == 0) { + if (--len2 == 0) break outer; - } count2 = gallopLeft(s.getKey(tmp, cursor1, key0), a, cursor2, len2, 0, c); if (count2 != 0) { @@ -797,19 +772,16 @@ private void mergeLo(int base1, int len1, int base2, int len2) { dest += count2; cursor2 += count2; len2 -= count2; - if (len2 == 0) { + if (len2 == 0) break outer; - } } s.copyElement(tmp, cursor1++, a, dest++); - if (--len1 == 1) { + if (--len1 == 1) break outer; - } minGallop--; } while (count1 >= MIN_GALLOP | count2 >= MIN_GALLOP); - if (minGallop < 0) { + if (minGallop < 0) minGallop = 0; - } minGallop += 2; // Penalize for leaving gallop mode } // End of "outer" loop this.minGallop = minGallop < 1 ? 1 : minGallop; // Write back to field @@ -885,16 +857,14 @@ private void mergeHi(int base1, int len1, int base2, int len2) { s.copyElement(a, cursor1--, a, dest--); count1++; count2 = 0; - if (--len1 == 0) { + if (--len1 == 0) break outer; - } } else { s.copyElement(tmp, cursor2--, a, dest--); count2++; count1 = 0; - if (--len2 == 1) { + if (--len2 == 1) break outer; - } } } while ((count1 | count2) < minGallop); @@ -911,14 +881,12 @@ private void mergeHi(int base1, int len1, int base2, int len2) { cursor1 -= count1; len1 -= count1; s.copyRange(a, cursor1 + 1, a, dest + 1, count1); - if (len1 == 0) { + if (len1 == 0) break outer; - } } s.copyElement(tmp, cursor2--, a, dest--); - if (--len2 == 1) { + if (--len2 == 1) break outer; - } count2 = len2 - gallopLeft(s.getKey(a, cursor1, key0), tmp, 0, len2, len2 - 1, c); if (count2 != 0) { @@ -926,19 +894,16 @@ private void mergeHi(int base1, int len1, int base2, int len2) { cursor2 -= count2; len2 -= count2; s.copyRange(tmp, cursor2 + 1, a, dest + 1, count2); - if (len2 <= 1) { // len2 == 1 || len2 == 0 + if (len2 <= 1) // len2 == 1 || len2 == 0 break outer; - } } s.copyElement(a, cursor1--, a, dest--); - if (--len1 == 0) { + if (--len1 == 0) break outer; - } minGallop--; } while (count1 >= MIN_GALLOP | count2 >= MIN_GALLOP); - if (minGallop < 0) { + if (minGallop < 0) minGallop = 0; - } minGallop += 2; // Penalize for leaving gallop mode } // End of "outer" loop this.minGallop = minGallop < 1 ? 1 : minGallop; // Write back to field @@ -978,11 +943,10 @@ private Buffer ensureCapacity(int minCapacity) { newSize |= newSize >> 16; newSize++; - if (newSize < 0) { // Not bloody likely! + if (newSize < 0) // Not bloody likely! 
newSize = minCapacity; - } else { + else newSize = Math.min(newSize, aLength >>> 1); - } tmp = s.allocate(newSize); tmpLength = newSize; From f8147f37c2ccd5ce58408c846058b1340df2e71c Mon Sep 17 00:00:00 2001 From: Dmitry Erastov Date: Wed, 11 Nov 2015 09:58:41 -0500 Subject: [PATCH 143/149] [SPARK-6990] Fix some Checkstyle issues in tests --- .../map/AbstractBytesToBytesMapSuite.java | 4 +- .../ml/feature/JavaStringIndexerSuite.java | 6 +- .../spark/mllib/clustering/JavaLDASuite.java | 2 +- .../network/sasl/SaslIntegrationSuite.java | 2 +- .../apache/spark/sql/hive/test/Complex.java | 86 +++++++++++++------ .../apache/spark/streaming/JavaAPISuite.java | 4 +- 6 files changed, 69 insertions(+), 35 deletions(-) diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index d87a1d2a56d99..a5c583f9f2844 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -356,8 +356,8 @@ public void iteratingOverDataPagesWithWastedSpace() throws Exception { final java.util.BitSet valuesSeen = new java.util.BitSet(NUM_ENTRIES); final Iterator iter = map.iterator(); - final long key[] = new long[KEY_LENGTH / 8]; - final long value[] = new long[VALUE_LENGTH / 8]; + final long[] key = new long[KEY_LENGTH / 8]; + final long[] value = new long[VALUE_LENGTH / 8]; while (iter.hasNext()) { final BytesToBytesMap.Location loc = iter.next(); Assert.assertTrue(loc.isDefined()); diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java index 6b2c48ef1c342..b2df79ba74feb 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java @@ -58,7 +58,7 @@ public void testStringIndexer() { createStructField("label", StringType, false) }); List data = Arrays.asList( - c(0, "a"), c(1, "b"), c(2, "c"), c(3, "a"), c(4, "a"), c(5, "c")); + cr(0, "a"), cr(1, "b"), cr(2, "c"), cr(3, "a"), cr(4, "a"), cr(5, "c")); DataFrame dataset = sqlContext.createDataFrame(data, schema); StringIndexer indexer = new StringIndexer() @@ -67,12 +67,12 @@ public void testStringIndexer() { DataFrame output = indexer.fit(dataset).transform(dataset); Assert.assertArrayEquals( - new Row[] { c(0, 0.0), c(1, 2.0), c(2, 1.0), c(3, 0.0), c(4, 0.0), c(5, 1.0) }, + new Row[] { cr(0, 0.0), cr(1, 2.0), cr(2, 1.0), cr(3, 0.0), cr(4, 0.0), cr(5, 1.0) }, output.orderBy("id").select("id", "labelIndex").collect()); } /** An alias for RowFactory.create. */ - private Row c(Object... values) { + private Row cr(Object... 
values) { return RowFactory.create(values); } } diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java index 3fea359a3b46c..225a216270b3b 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java @@ -144,7 +144,7 @@ public Boolean call(Tuple2 tuple2) { } @Test - public void OnlineOptimizerCompatibility() { + public void onlineOptimizerCompatibility() { int k = 3; double topicSmoothing = 1.2; double termSmoothing = 1.2; diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index 1c2fa4d0d462c..e5f6598f54fdf 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -59,7 +59,7 @@ public class SaslIntegrationSuite { // Use a long timeout to account for slow / overloaded build machines. In the normal case, // tests should finish way before the timeout expires. - private final static long TIMEOUT_MS = 10_000; + private static final long TIMEOUT_MS = 10_000; static TransportServer server; static TransportConf conf; diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java index e010112bb9327..4ef1f276d1bbb 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java @@ -489,6 +489,7 @@ public void setFieldValue(_Fields field, Object value) { } break; + default: } } @@ -512,6 +513,7 @@ public Object getFieldValue(_Fields field) { case M_STRING_STRING: return getMStringString(); + default: } throw new IllegalStateException(); } @@ -535,75 +537,91 @@ public boolean isSet(_Fields field) { return isSetLintString(); case M_STRING_STRING: return isSetMStringString(); + default: } throw new IllegalStateException(); } @Override public boolean equals(Object that) { - if (that == null) + if (that == null) { return false; - if (that instanceof Complex) + } + if (that instanceof Complex) { return this.equals((Complex)that); + } return false; } public boolean equals(Complex that) { - if (that == null) + if (that == null) { return false; + } boolean this_present_aint = true; boolean that_present_aint = true; if (this_present_aint || that_present_aint) { - if (!(this_present_aint && that_present_aint)) + if (!(this_present_aint && that_present_aint)) { return false; - if (this.aint != that.aint) + } + if (this.aint != that.aint) { return false; + } } boolean this_present_aString = true && this.isSetAString(); boolean that_present_aString = true && that.isSetAString(); if (this_present_aString || that_present_aString) { - if (!(this_present_aString && that_present_aString)) + if (!(this_present_aString && that_present_aString)) { return false; - if (!this.aString.equals(that.aString)) + } + if (!this.aString.equals(that.aString)) { return false; + } } boolean this_present_lint = true && this.isSetLint(); boolean that_present_lint = true && that.isSetLint(); if (this_present_lint || that_present_lint) { - if (!(this_present_lint && that_present_lint)) + if (!(this_present_lint && that_present_lint)) { return false; - if (!this.lint.equals(that.lint)) + } + if 
(!this.lint.equals(that.lint)) { return false; + } } boolean this_present_lString = true && this.isSetLString(); boolean that_present_lString = true && that.isSetLString(); if (this_present_lString || that_present_lString) { - if (!(this_present_lString && that_present_lString)) + if (!(this_present_lString && that_present_lString)) { return false; - if (!this.lString.equals(that.lString)) + } + if (!this.lString.equals(that.lString)) { return false; + } } boolean this_present_lintString = true && this.isSetLintString(); boolean that_present_lintString = true && that.isSetLintString(); if (this_present_lintString || that_present_lintString) { - if (!(this_present_lintString && that_present_lintString)) + if (!(this_present_lintString && that_present_lintString)) { return false; - if (!this.lintString.equals(that.lintString)) + } + if (!this.lintString.equals(that.lintString)) { return false; + } } boolean this_present_mStringString = true && this.isSetMStringString(); boolean that_present_mStringString = true && that.isSetMStringString(); if (this_present_mStringString || that_present_mStringString) { - if (!(this_present_mStringString && that_present_mStringString)) + if (!(this_present_mStringString && that_present_mStringString)) { return false; - if (!this.mStringString.equals(that.mStringString)) + } + if (!this.mStringString.equals(that.mStringString)) { return false; + } } return true; @@ -615,33 +633,39 @@ public int hashCode() { boolean present_aint = true; builder.append(present_aint); - if (present_aint) + if (present_aint) { builder.append(aint); + } boolean present_aString = true && (isSetAString()); builder.append(present_aString); - if (present_aString) + if (present_aString) { builder.append(aString); + } boolean present_lint = true && (isSetLint()); builder.append(present_lint); - if (present_lint) + if (present_lint) { builder.append(lint); + } boolean present_lString = true && (isSetLString()); builder.append(present_lString); - if (present_lString) + if (present_lString) { builder.append(lString); + } boolean present_lintString = true && (isSetLintString()); builder.append(present_lintString); - if (present_lintString) + if (present_lintString) { builder.append(lintString); + } boolean present_mStringString = true && (isSetMStringString()); builder.append(present_mStringString); - if (present_mStringString) + if (present_mStringString) { builder.append(mStringString); + } return builder.toHashCode(); } @@ -737,7 +761,9 @@ public String toString() { sb.append("aint:"); sb.append(this.aint); first = false; - if (!first) sb.append(", "); + if (!first) { + sb.append(", "); + } sb.append("aString:"); if (this.aString == null) { sb.append("null"); @@ -745,7 +771,9 @@ public String toString() { sb.append(this.aString); } first = false; - if (!first) sb.append(", "); + if (!first) { + sb.append(", "); + } sb.append("lint:"); if (this.lint == null) { sb.append("null"); @@ -753,7 +781,9 @@ public String toString() { sb.append(this.lint); } first = false; - if (!first) sb.append(", "); + if (!first) { + sb.append(", "); + } sb.append("lString:"); if (this.lString == null) { sb.append("null"); @@ -761,7 +791,9 @@ public String toString() { sb.append(this.lString); } first = false; - if (!first) sb.append(", "); + if (!first) { + sb.append(", "); + } sb.append("lintString:"); if (this.lintString == null) { sb.append("null"); @@ -769,7 +801,9 @@ public String toString() { sb.append(this.lintString); } first = false; - if (!first) sb.append(", "); + if (!first) { + sb.append(", "); + 
} sb.append("mStringString:"); if (this.mStringString == null) { sb.append("null"); diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index 609bb4413b6b1..9722c60bba1c3 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -1332,12 +1332,12 @@ public Optional call(List values, Optional state) { public void testUpdateStateByKeyWithInitial() { List>> inputData = stringIntKVStream; - List> initial = Arrays.asList ( + List> initial = Arrays.asList( new Tuple2<>("california", 1), new Tuple2<>("new york", 2)); JavaRDD> tmpRDD = ssc.sparkContext().parallelize(initial); - JavaPairRDD initialRDD = JavaPairRDD.fromJavaRDD (tmpRDD); + JavaPairRDD initialRDD = JavaPairRDD.fromJavaRDD(tmpRDD); List>> expected = Arrays.asList( Arrays.asList(new Tuple2<>("california", 5), From 15fe047edca0f025ef75ffab14eb514050ca8417 Mon Sep 17 00:00:00 2001 From: Dmitry Erastov Date: Wed, 11 Nov 2015 10:01:20 -0500 Subject: [PATCH 144/149] [SPARK-6990] Enable Checkstyle for tests --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index cf9bdefed6993..8ac5a0e56c2b4 100644 --- a/pom.xml +++ b/pom.xml @@ -2265,7 +2265,7 @@ false false - false + true false ${basedir}/src/main/java ${basedir}/src/test/java From 6f8e292e0381b8c3de36664828f6879009843d46 Mon Sep 17 00:00:00 2001 From: Dmitry Erastov Date: Wed, 11 Nov 2015 10:01:48 -0500 Subject: [PATCH 145/149] [SPARK-6990] Enable FallThrough check in Checkstyle This makes sure all case statements end with a break. See http://checkstyle.sourceforge.net/config_coding.html#FallThrough --- checkstyle.xml | 1 + 1 file changed, 1 insertion(+) diff --git a/checkstyle.xml b/checkstyle.xml index b273e7be27a11..a63b1b7797201 100644 --- a/checkstyle.xml +++ b/checkstyle.xml @@ -88,6 +88,7 @@ + From b63df72cf193cf5c8dc108af1c5f7447f5239d92 Mon Sep 17 00:00:00 2001 From: Dmitry Erastov Date: Tue, 10 Nov 2015 11:57:20 -0500 Subject: [PATCH 146/149] [SPARK-6990] Fix some Checkstyle warnings --- .../apache/spark/util/collection/TimSort.java | 118 ++++++++++++------ 1 file changed, 77 insertions(+), 41 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/TimSort.java b/core/src/main/java/org/apache/spark/util/collection/TimSort.java index 40b5fb7fe4b49..0d476dc0748fd 100644 --- a/core/src/main/java/org/apache/spark/util/collection/TimSort.java +++ b/core/src/main/java/org/apache/spark/util/collection/TimSort.java @@ -120,8 +120,9 @@ public void sort(Buffer a, int lo, int hi, Comparator c) { assert c != null; int nRemaining = hi - lo; - if (nRemaining < 2) + if (nRemaining < 2) { return; // Arrays of size 0 and 1 are always sorted + } // If array is small, do a "mini-TimSort" with no merges if (nRemaining < MIN_MERGE) { @@ -184,8 +185,9 @@ public void sort(Buffer a, int lo, int hi, Comparator c) { @SuppressWarnings("fallthrough") private void binarySort(Buffer a, int lo, int hi, int start, Comparator c) { assert lo <= start && start <= hi; - if (start == lo) + if (start == lo) { start++; + } K key0 = s.newKey(); K key1 = s.newKey(); @@ -206,10 +208,11 @@ private void binarySort(Buffer a, int lo, int hi, int start, Comparator>> 1; - if (c.compare(pivot, s.getKey(a, mid, key1)) < 0) + if (c.compare(pivot, s.getKey(a, mid, key1)) < 0) { right = mid; - else + } else { left = mid + 1; + } } assert left == right; @@ -260,20 
+263,23 @@ private void binarySort(Buffer a, int lo, int hi, int start, Comparator c) { assert lo < hi; int runHi = lo + 1; - if (runHi == hi) + if (runHi == hi) { return 1; + } K key0 = s.newKey(); K key1 = s.newKey(); // Find end of run, and reverse range if descending if (c.compare(s.getKey(a, runHi++, key0), s.getKey(a, lo, key1)) < 0) { // Descending - while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) < 0) + while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) < 0) { runHi++; + } reverseRange(a, lo, runHi); } else { // Ascending - while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) >= 0) + while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) >= 0) { runHi++; + } } return runHi - lo; @@ -445,8 +451,9 @@ private void mergeCollapse() { int n = stackSize - 2; if ( (n >= 1 && runLen[n-1] <= runLen[n] + runLen[n+1]) || (n >= 2 && runLen[n-2] <= runLen[n] + runLen[n-1])) { - if (runLen[n - 1] < runLen[n + 1]) + if (runLen[n - 1] < runLen[n + 1]) { n--; + } } else if (runLen[n] > runLen[n + 1]) { break; // Invariant is established } @@ -461,8 +468,9 @@ private void mergeCollapse() { private void mergeForceCollapse() { while (stackSize > 1) { int n = stackSize - 2; - if (n > 0 && runLen[n - 1] < runLen[n + 1]) + if (n > 0 && runLen[n - 1] < runLen[n + 1]) { n--; + } mergeAt(n); } } @@ -508,8 +516,9 @@ private void mergeAt(int i) { assert k >= 0; base1 += k; len1 -= k; - if (len1 == 0) + if (len1 == 0) { return; + } /* * Find where the last element of run1 goes in run2. Subsequent elements @@ -517,14 +526,16 @@ private void mergeAt(int i) { */ len2 = gallopLeft(s.getKey(a, base1 + len1 - 1, key0), a, base2, len2, len2 - 1, c); assert len2 >= 0; - if (len2 == 0) + if (len2 == 0) { return; + } // Merge remaining runs, using tmp array with min(len1, len2) elements - if (len1 <= len2) + if (len1 <= len2) { mergeLo(base1, len1, base2, len2); - else + } else { mergeHi(base1, len1, base2, len2); + } } /** @@ -557,11 +568,13 @@ private int gallopLeft(K key, Buffer a, int base, int len, int hint, Comparator< while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs, key0)) > 0) { lastOfs = ofs; ofs = (ofs << 1) + 1; - if (ofs <= 0) // int overflow + if (ofs <= 0) { // int overflow ofs = maxOfs; + } } - if (ofs > maxOfs) + if (ofs > maxOfs) { ofs = maxOfs; + } // Make offsets relative to base lastOfs += hint; @@ -572,11 +585,13 @@ private int gallopLeft(K key, Buffer a, int base, int len, int hint, Comparator< while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs, key0)) <= 0) { lastOfs = ofs; ofs = (ofs << 1) + 1; - if (ofs <= 0) // int overflow + if (ofs <= 0) { // int overflow ofs = maxOfs; + } } - if (ofs > maxOfs) + if (ofs > maxOfs) { ofs = maxOfs; + } // Make offsets relative to base int tmp = lastOfs; @@ -594,10 +609,11 @@ private int gallopLeft(K key, Buffer a, int base, int len, int hint, Comparator< while (lastOfs < ofs) { int m = lastOfs + ((ofs - lastOfs) >>> 1); - if (c.compare(key, s.getKey(a, base + m, key0)) > 0) + if (c.compare(key, s.getKey(a, base + m, key0)) > 0) { lastOfs = m + 1; // a[base + m] < key - else + } else { ofs = m; // key <= a[base + m] + } } assert lastOfs == ofs; // so a[base + ofs - 1] < key <= a[base + ofs] return ofs; @@ -629,11 +645,13 @@ private int gallopRight(K key, Buffer a, int base, int len, int hint, Comparator while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs, key1)) < 0) { 
lastOfs = ofs; ofs = (ofs << 1) + 1; - if (ofs <= 0) // int overflow + if (ofs <= 0) { // int overflow ofs = maxOfs; + } } - if (ofs > maxOfs) + if (ofs > maxOfs) { ofs = maxOfs; + } // Make offsets relative to b int tmp = lastOfs; @@ -645,11 +663,13 @@ private int gallopRight(K key, Buffer a, int base, int len, int hint, Comparator while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs, key1)) >= 0) { lastOfs = ofs; ofs = (ofs << 1) + 1; - if (ofs <= 0) // int overflow + if (ofs <= 0) { // int overflow ofs = maxOfs; + } } - if (ofs > maxOfs) + if (ofs > maxOfs) { ofs = maxOfs; + } // Make offsets relative to b lastOfs += hint; @@ -666,10 +686,11 @@ private int gallopRight(K key, Buffer a, int base, int len, int hint, Comparator while (lastOfs < ofs) { int m = lastOfs + ((ofs - lastOfs) >>> 1); - if (c.compare(key, s.getKey(a, base + m, key1)) < 0) + if (c.compare(key, s.getKey(a, base + m, key1)) < 0) { ofs = m; // key < a[b + m] - else + } else { lastOfs = m + 1; // a[b + m] <= key + } } assert lastOfs == ofs; // so a[b + ofs - 1] <= key < a[b + ofs] return ofs; @@ -735,14 +756,16 @@ private void mergeLo(int base1, int len1, int base2, int len2) { s.copyElement(a, cursor2++, a, dest++); count2++; count1 = 0; - if (--len2 == 0) + if (--len2 == 0) { break outer; + } } else { s.copyElement(tmp, cursor1++, a, dest++); count1++; count2 = 0; - if (--len1 == 1) + if (--len1 == 1) { break outer; + } } } while ((count1 | count2) < minGallop); @@ -759,12 +782,14 @@ private void mergeLo(int base1, int len1, int base2, int len2) { dest += count1; cursor1 += count1; len1 -= count1; - if (len1 <= 1) // len1 == 1 || len1 == 0 + if (len1 <= 1) { // len1 == 1 || len1 == 0 break outer; + } } s.copyElement(a, cursor2++, a, dest++); - if (--len2 == 0) + if (--len2 == 0) { break outer; + } count2 = gallopLeft(s.getKey(tmp, cursor1, key0), a, cursor2, len2, 0, c); if (count2 != 0) { @@ -772,16 +797,19 @@ private void mergeLo(int base1, int len1, int base2, int len2) { dest += count2; cursor2 += count2; len2 -= count2; - if (len2 == 0) + if (len2 == 0) { break outer; + } } s.copyElement(tmp, cursor1++, a, dest++); - if (--len1 == 1) + if (--len1 == 1) { break outer; + } minGallop--; } while (count1 >= MIN_GALLOP | count2 >= MIN_GALLOP); - if (minGallop < 0) + if (minGallop < 0) { minGallop = 0; + } minGallop += 2; // Penalize for leaving gallop mode } // End of "outer" loop this.minGallop = minGallop < 1 ? 
1 : minGallop; // Write back to field @@ -857,14 +885,16 @@ private void mergeHi(int base1, int len1, int base2, int len2) { s.copyElement(a, cursor1--, a, dest--); count1++; count2 = 0; - if (--len1 == 0) + if (--len1 == 0) { break outer; + } } else { s.copyElement(tmp, cursor2--, a, dest--); count2++; count1 = 0; - if (--len2 == 1) + if (--len2 == 1) { break outer; + } } } while ((count1 | count2) < minGallop); @@ -881,12 +911,14 @@ private void mergeHi(int base1, int len1, int base2, int len2) { cursor1 -= count1; len1 -= count1; s.copyRange(a, cursor1 + 1, a, dest + 1, count1); - if (len1 == 0) + if (len1 == 0) { break outer; + } } s.copyElement(tmp, cursor2--, a, dest--); - if (--len2 == 1) + if (--len2 == 1) { break outer; + } count2 = len2 - gallopLeft(s.getKey(a, cursor1, key0), tmp, 0, len2, len2 - 1, c); if (count2 != 0) { @@ -894,16 +926,19 @@ private void mergeHi(int base1, int len1, int base2, int len2) { cursor2 -= count2; len2 -= count2; s.copyRange(tmp, cursor2 + 1, a, dest + 1, count2); - if (len2 <= 1) // len2 == 1 || len2 == 0 + if (len2 <= 1) { // len2 == 1 || len2 == 0 break outer; + } } s.copyElement(a, cursor1--, a, dest--); - if (--len1 == 0) + if (--len1 == 0) { break outer; + } minGallop--; } while (count1 >= MIN_GALLOP | count2 >= MIN_GALLOP); - if (minGallop < 0) + if (minGallop < 0) { minGallop = 0; + } minGallop += 2; // Penalize for leaving gallop mode } // End of "outer" loop this.minGallop = minGallop < 1 ? 1 : minGallop; // Write back to field @@ -943,10 +978,11 @@ private Buffer ensureCapacity(int minCapacity) { newSize |= newSize >> 16; newSize++; - if (newSize < 0) // Not bloody likely! + if (newSize < 0) { // Not bloody likely! newSize = minCapacity; - else + } else { newSize = Math.min(newSize, aLength >>> 1); + } tmp = s.allocate(newSize); tmpLength = newSize; From 8b2d2bf8c19b3c1c49fbb1407328adb062c59767 Mon Sep 17 00:00:00 2001 From: Dmitry Erastov Date: Wed, 11 Nov 2015 09:52:19 -0500 Subject: [PATCH 147/149] [SPARK-6990] Suppress Checkstyle for TimSort The code was copied from a third-party source and needs to be in sync with that, so we shouldn't make our own modifications to it. The file contains some style violations, so suppress the checks. 
--- .../apache/spark/util/collection/TimSort.java | 118 ++++++------------ 1 file changed, 41 insertions(+), 77 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/TimSort.java b/core/src/main/java/org/apache/spark/util/collection/TimSort.java index 0d476dc0748fd..40b5fb7fe4b49 100644 --- a/core/src/main/java/org/apache/spark/util/collection/TimSort.java +++ b/core/src/main/java/org/apache/spark/util/collection/TimSort.java @@ -120,9 +120,8 @@ public void sort(Buffer a, int lo, int hi, Comparator c) { assert c != null; int nRemaining = hi - lo; - if (nRemaining < 2) { + if (nRemaining < 2) return; // Arrays of size 0 and 1 are always sorted - } // If array is small, do a "mini-TimSort" with no merges if (nRemaining < MIN_MERGE) { @@ -185,9 +184,8 @@ public void sort(Buffer a, int lo, int hi, Comparator c) { @SuppressWarnings("fallthrough") private void binarySort(Buffer a, int lo, int hi, int start, Comparator c) { assert lo <= start && start <= hi; - if (start == lo) { + if (start == lo) start++; - } K key0 = s.newKey(); K key1 = s.newKey(); @@ -208,11 +206,10 @@ private void binarySort(Buffer a, int lo, int hi, int start, Comparator>> 1; - if (c.compare(pivot, s.getKey(a, mid, key1)) < 0) { + if (c.compare(pivot, s.getKey(a, mid, key1)) < 0) right = mid; - } else { + else left = mid + 1; - } } assert left == right; @@ -263,23 +260,20 @@ private void binarySort(Buffer a, int lo, int hi, int start, Comparator c) { assert lo < hi; int runHi = lo + 1; - if (runHi == hi) { + if (runHi == hi) return 1; - } K key0 = s.newKey(); K key1 = s.newKey(); // Find end of run, and reverse range if descending if (c.compare(s.getKey(a, runHi++, key0), s.getKey(a, lo, key1)) < 0) { // Descending - while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) < 0) { + while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) < 0) runHi++; - } reverseRange(a, lo, runHi); } else { // Ascending - while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) >= 0) { + while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) >= 0) runHi++; - } } return runHi - lo; @@ -451,9 +445,8 @@ private void mergeCollapse() { int n = stackSize - 2; if ( (n >= 1 && runLen[n-1] <= runLen[n] + runLen[n+1]) || (n >= 2 && runLen[n-2] <= runLen[n] + runLen[n-1])) { - if (runLen[n - 1] < runLen[n + 1]) { + if (runLen[n - 1] < runLen[n + 1]) n--; - } } else if (runLen[n] > runLen[n + 1]) { break; // Invariant is established } @@ -468,9 +461,8 @@ private void mergeCollapse() { private void mergeForceCollapse() { while (stackSize > 1) { int n = stackSize - 2; - if (n > 0 && runLen[n - 1] < runLen[n + 1]) { + if (n > 0 && runLen[n - 1] < runLen[n + 1]) n--; - } mergeAt(n); } } @@ -516,9 +508,8 @@ private void mergeAt(int i) { assert k >= 0; base1 += k; len1 -= k; - if (len1 == 0) { + if (len1 == 0) return; - } /* * Find where the last element of run1 goes in run2. 
Subsequent elements @@ -526,16 +517,14 @@ private void mergeAt(int i) { */ len2 = gallopLeft(s.getKey(a, base1 + len1 - 1, key0), a, base2, len2, len2 - 1, c); assert len2 >= 0; - if (len2 == 0) { + if (len2 == 0) return; - } // Merge remaining runs, using tmp array with min(len1, len2) elements - if (len1 <= len2) { + if (len1 <= len2) mergeLo(base1, len1, base2, len2); - } else { + else mergeHi(base1, len1, base2, len2); - } } /** @@ -568,13 +557,11 @@ private int gallopLeft(K key, Buffer a, int base, int len, int hint, Comparator< while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs, key0)) > 0) { lastOfs = ofs; ofs = (ofs << 1) + 1; - if (ofs <= 0) { // int overflow + if (ofs <= 0) // int overflow ofs = maxOfs; - } } - if (ofs > maxOfs) { + if (ofs > maxOfs) ofs = maxOfs; - } // Make offsets relative to base lastOfs += hint; @@ -585,13 +572,11 @@ private int gallopLeft(K key, Buffer a, int base, int len, int hint, Comparator< while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs, key0)) <= 0) { lastOfs = ofs; ofs = (ofs << 1) + 1; - if (ofs <= 0) { // int overflow + if (ofs <= 0) // int overflow ofs = maxOfs; - } } - if (ofs > maxOfs) { + if (ofs > maxOfs) ofs = maxOfs; - } // Make offsets relative to base int tmp = lastOfs; @@ -609,11 +594,10 @@ private int gallopLeft(K key, Buffer a, int base, int len, int hint, Comparator< while (lastOfs < ofs) { int m = lastOfs + ((ofs - lastOfs) >>> 1); - if (c.compare(key, s.getKey(a, base + m, key0)) > 0) { + if (c.compare(key, s.getKey(a, base + m, key0)) > 0) lastOfs = m + 1; // a[base + m] < key - } else { + else ofs = m; // key <= a[base + m] - } } assert lastOfs == ofs; // so a[base + ofs - 1] < key <= a[base + ofs] return ofs; @@ -645,13 +629,11 @@ private int gallopRight(K key, Buffer a, int base, int len, int hint, Comparator while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs, key1)) < 0) { lastOfs = ofs; ofs = (ofs << 1) + 1; - if (ofs <= 0) { // int overflow + if (ofs <= 0) // int overflow ofs = maxOfs; - } } - if (ofs > maxOfs) { + if (ofs > maxOfs) ofs = maxOfs; - } // Make offsets relative to b int tmp = lastOfs; @@ -663,13 +645,11 @@ private int gallopRight(K key, Buffer a, int base, int len, int hint, Comparator while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs, key1)) >= 0) { lastOfs = ofs; ofs = (ofs << 1) + 1; - if (ofs <= 0) { // int overflow + if (ofs <= 0) // int overflow ofs = maxOfs; - } } - if (ofs > maxOfs) { + if (ofs > maxOfs) ofs = maxOfs; - } // Make offsets relative to b lastOfs += hint; @@ -686,11 +666,10 @@ private int gallopRight(K key, Buffer a, int base, int len, int hint, Comparator while (lastOfs < ofs) { int m = lastOfs + ((ofs - lastOfs) >>> 1); - if (c.compare(key, s.getKey(a, base + m, key1)) < 0) { + if (c.compare(key, s.getKey(a, base + m, key1)) < 0) ofs = m; // key < a[b + m] - } else { + else lastOfs = m + 1; // a[b + m] <= key - } } assert lastOfs == ofs; // so a[b + ofs - 1] <= key < a[b + ofs] return ofs; @@ -756,16 +735,14 @@ private void mergeLo(int base1, int len1, int base2, int len2) { s.copyElement(a, cursor2++, a, dest++); count2++; count1 = 0; - if (--len2 == 0) { + if (--len2 == 0) break outer; - } } else { s.copyElement(tmp, cursor1++, a, dest++); count1++; count2 = 0; - if (--len1 == 1) { + if (--len1 == 1) break outer; - } } } while ((count1 | count2) < minGallop); @@ -782,14 +759,12 @@ private void mergeLo(int base1, int len1, int base2, int len2) { dest += count1; cursor1 += count1; len1 -= count1; - if (len1 <= 1) { // 
+            if (len1 <= 1)  // len1 == 1 || len1 == 0
              break outer;
-            }
          }
          s.copyElement(a, cursor2++, a, dest++);
-          if (--len2 == 0) {
+          if (--len2 == 0)
            break outer;
-          }

          count2 = gallopLeft(s.getKey(tmp, cursor1, key0), a, cursor2, len2, 0, c);
          if (count2 != 0) {
@@ -797,19 +772,16 @@ private void mergeLo(int base1, int len1, int base2, int len2) {
            dest += count2;
            cursor2 += count2;
            len2 -= count2;
-            if (len2 == 0) {
+            if (len2 == 0)
              break outer;
-            }
          }
          s.copyElement(tmp, cursor1++, a, dest++);
-          if (--len1 == 1) {
+          if (--len1 == 1)
            break outer;
-          }
          minGallop--;
        } while (count1 >= MIN_GALLOP | count2 >= MIN_GALLOP);
-        if (minGallop < 0) {
+        if (minGallop < 0)
          minGallop = 0;
-        }
        minGallop += 2;  // Penalize for leaving gallop mode
      } // End of "outer" loop
      this.minGallop = minGallop < 1 ? 1 : minGallop;  // Write back to field
@@ -885,16 +857,14 @@ private void mergeHi(int base1, int len1, int base2, int len2) {
            s.copyElement(a, cursor1--, a, dest--);
            count1++;
            count2 = 0;
-            if (--len1 == 0) {
+            if (--len1 == 0)
              break outer;
-            }
          } else {
            s.copyElement(tmp, cursor2--, a, dest--);
            count2++;
            count1 = 0;
-            if (--len2 == 1) {
+            if (--len2 == 1)
              break outer;
-            }
          }
        } while ((count1 | count2) < minGallop);

@@ -911,14 +881,12 @@ private void mergeHi(int base1, int len1, int base2, int len2) {
            cursor1 -= count1;
            len1 -= count1;
            s.copyRange(a, cursor1 + 1, a, dest + 1, count1);
-            if (len1 == 0) {
+            if (len1 == 0)
              break outer;
-            }
          }
          s.copyElement(tmp, cursor2--, a, dest--);
-          if (--len2 == 1) {
+          if (--len2 == 1)
            break outer;
-          }

          count2 = len2 - gallopLeft(s.getKey(a, cursor1, key0), tmp, 0, len2, len2 - 1, c);
          if (count2 != 0) {
@@ -926,19 +894,16 @@ private void mergeHi(int base1, int len1, int base2, int len2) {
            cursor2 -= count2;
            len2 -= count2;
            s.copyRange(tmp, cursor2 + 1, a, dest + 1, count2);
-            if (len2 <= 1) { // len2 == 1 || len2 == 0
+            if (len2 <= 1)  // len2 == 1 || len2 == 0
              break outer;
-            }
          }
          s.copyElement(a, cursor1--, a, dest--);
-          if (--len1 == 0) {
+          if (--len1 == 0)
            break outer;
-          }
          minGallop--;
        } while (count1 >= MIN_GALLOP | count2 >= MIN_GALLOP);
-        if (minGallop < 0) {
+        if (minGallop < 0)
          minGallop = 0;
-        }
        minGallop += 2;  // Penalize for leaving gallop mode
      } // End of "outer" loop
      this.minGallop = minGallop < 1 ? 1 : minGallop;  // Write back to field
@@ -978,11 +943,10 @@ private Buffer ensureCapacity(int minCapacity) {
      newSize |= newSize >> 16;
      newSize++;

-      if (newSize < 0) { // Not bloody likely!
+      if (newSize < 0)  // Not bloody likely!
        newSize = minCapacity;
-      } else {
+      else
        newSize = Math.min(newSize, aLength >>> 1);
-      }

      tmp = s.allocate(newSize);
      tmpLength = newSize;

From ef4819b1922fa9cd85cbf383e0788c86c90b2869 Mon Sep 17 00:00:00 2001
From: Dmitry Erastov
Date: Thu, 19 Nov 2015 16:20:34 -0500
Subject: [PATCH 148/149] [SPARK-6990] Disable MissingSwitchDefault check

The check fails in UnsafeRowParquetRecordReader.java. Let's enable the check
in a separate change.
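
For reference, a minimal illustrative sketch of what the MissingSwitchDefault rule
flags (the class and names below are hypothetical, not taken from
UnsafeRowParquetRecordReader.java): the first switch has no default branch and fails
the check, the second passes.

// Hypothetical example only -- illustrates Checkstyle's MissingSwitchDefault rule.
class SwitchDefaultExample {
  enum ColumnType { INT32, INT64 }

  // Flagged: the switch statement has no "default" clause.
  int bytesPerValueFlagged(ColumnType type) {
    switch (type) {
      case INT32: return 4;
      case INT64: return 8;
    }
    return -1;  // required by the compiler because the switch may complete normally
  }

  // Passes: every value is handled, and unexpected ones fail fast in the default branch.
  int bytesPerValuePassing(ColumnType type) {
    switch (type) {
      case INT32: return 4;
      case INT64: return 8;
      default: throw new IllegalArgumentException("Unsupported type: " + type);
    }
  }
}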
---
 checkstyle.xml | 1 -
 1 file changed, 1 deletion(-)

diff --git a/checkstyle.xml b/checkstyle.xml
index a63b1b7797201..a493ee443c752 100644
--- a/checkstyle.xml
+++ b/checkstyle.xml
@@ -87,7 +87,6 @@



-        <module name="MissingSwitchDefault"/>



From d14f7536110a3ea1bf7c13769ffc4f478395db23 Mon Sep 17 00:00:00 2001
From: Dmitry Erastov
Date: Thu, 19 Nov 2015 16:25:24 -0500
Subject: [PATCH 149/149] [SPARK-6990] Fix qualifier order in SpecificParquetRecordReaderBase

---
 .../datasources/parquet/SpecificParquetRecordReaderBase.java | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java
index 2ed30c1f5a8d9..842dcb8c93dc2 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java
@@ -195,7 +195,7 @@ protected static final class NullIntIterator extends IntIterator {
   * Creates a reader for definition and repetition levels, returning an optimized one if
   * the levels are not needed.
   */
-  static protected IntIterator createRLEIterator(int maxLevel, BytesInput bytes,
+  protected static IntIterator createRLEIterator(int maxLevel, BytesInput bytes,
      ColumnDescriptor descriptor) throws IOException {
    try {
      if (maxLevel == 0) return new NullIntIterator();
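
For context, the fix above follows the modifier order recommended by the JLS and
enforced by Checkstyle's ModifierOrder check (access modifier first, then abstract,
static, final, and so on). A minimal illustrative sketch, with made-up names rather
than Spark's:

// Illustrative only -- not Spark code. ModifierOrder accepts the declarations below;
// writing "static protected" or "final static" instead would be flagged.
class ModifierOrderExample {
  protected static final int DEFAULT_CAPACITY = 16;

  protected static int capacityFor(int n) {
    return Math.max(n, DEFAULT_CAPACITY);
  }
}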