diff --git a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java new file mode 100644 index 0000000000000..fbc5666959055 --- /dev/null +++ b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark; + +import org.apache.spark.scheduler.*; + +/** + * Class that allows users to receive all SparkListener events. + * Users should override the onEvent method. + * + * This is a concrete Java class in order to ensure that we don't forget to update it when adding + * new methods to SparkListener: forgetting to add a method will result in a compilation error (if + * this was a concrete Scala class, default implementations of new event handlers would be inherited + * from the SparkListener trait). + */ +public class SparkFirehoseListener implements SparkListener { + + public void onEvent(SparkListenerEvent event) { } + + @Override + public final void onStageCompleted(SparkListenerStageCompleted stageCompleted) { + onEvent(stageCompleted); + } + + @Override + public final void onStageSubmitted(SparkListenerStageSubmitted stageSubmitted) { + onEvent(stageSubmitted); + } + + @Override + public final void onTaskStart(SparkListenerTaskStart taskStart) { + onEvent(taskStart); + } + + @Override + public final void onTaskGettingResult(SparkListenerTaskGettingResult taskGettingResult) { + onEvent(taskGettingResult); + } + + @Override + public final void onTaskEnd(SparkListenerTaskEnd taskEnd) { + onEvent(taskEnd); + } + + @Override + public final void onJobStart(SparkListenerJobStart jobStart) { + onEvent(jobStart); + } + + @Override + public final void onJobEnd(SparkListenerJobEnd jobEnd) { + onEvent(jobEnd); + } + + @Override + public final void onEnvironmentUpdate(SparkListenerEnvironmentUpdate environmentUpdate) { + onEvent(environmentUpdate); + } + + @Override + public final void onBlockManagerAdded(SparkListenerBlockManagerAdded blockManagerAdded) { + onEvent(blockManagerAdded); + } + + @Override + public final void onBlockManagerRemoved(SparkListenerBlockManagerRemoved blockManagerRemoved) { + onEvent(blockManagerRemoved); + } + + @Override + public final void onUnpersistRDD(SparkListenerUnpersistRDD unpersistRDD) { + onEvent(unpersistRDD); + } + + @Override + public final void onApplicationStart(SparkListenerApplicationStart applicationStart) { + onEvent(applicationStart); + } + + @Override + public final void onApplicationEnd(SparkListenerApplicationEnd applicationEnd) { + onEvent(applicationEnd); + } + + @Override + public final void onExecutorMetricsUpdate( + SparkListenerExecutorMetricsUpdate executorMetricsUpdate) { + onEvent(executorMetricsUpdate); + } + + @Override + public final void onExecutorAdded(SparkListenerExecutorAdded executorAdded) { + onEvent(executorAdded); + } + + @Override + public final void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) { + onEvent(executorRemoved); + } +} diff --git a/core/src/main/java/org/apache/spark/TaskContext.java b/core/src/main/java/org/apache/spark/TaskContext.java deleted file mode 100644 index 095f9fb94fdf0..0000000000000 --- a/core/src/main/java/org/apache/spark/TaskContext.java +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark; - -import java.io.Serializable; - -import scala.Function0; -import scala.Function1; -import scala.Unit; - -import org.apache.spark.annotation.DeveloperApi; -import org.apache.spark.executor.TaskMetrics; -import org.apache.spark.util.TaskCompletionListener; - -/** - * Contextual information about a task which can be read or mutated during - * execution. To access the TaskContext for a running task use - * TaskContext.get(). - */ -public abstract class TaskContext implements Serializable { - /** - * Return the currently active TaskContext. This can be called inside of - * user functions to access contextual information about running tasks. - */ - public static TaskContext get() { - return taskContext.get(); - } - - private static ThreadLocal taskContext = - new ThreadLocal(); - - static void setTaskContext(TaskContext tc) { - taskContext.set(tc); - } - - static void unset() { - taskContext.remove(); - } - - /** - * Whether the task has completed. - */ - public abstract boolean isCompleted(); - - /** - * Whether the task has been killed. - */ - public abstract boolean isInterrupted(); - - /** @deprecated use {@link #isRunningLocally()} */ - @Deprecated - public abstract boolean runningLocally(); - - public abstract boolean isRunningLocally(); - - /** - * Add a (Java friendly) listener to be executed on task completion. - * This will be called in all situation - success, failure, or cancellation. - * An example use is for HadoopRDD to register a callback to close the input stream. - */ - public abstract TaskContext addTaskCompletionListener(TaskCompletionListener listener); - - /** - * Add a listener in the form of a Scala closure to be executed on task completion. - * This will be called in all situations - success, failure, or cancellation. - * An example use is for HadoopRDD to register a callback to close the input stream. - */ - public abstract TaskContext addTaskCompletionListener(final Function1 f); - - /** - * Add a callback function to be executed on task completion. An example use - * is for HadoopRDD to register a callback to close the input stream. - * Will be called in any situation - success, failure, or cancellation. - * - * @deprecated use {@link #addTaskCompletionListener(scala.Function1)} - * - * @param f Callback function. - */ - @Deprecated - public abstract void addOnCompleteCallback(final Function0 f); - - /** - * The ID of the stage that this task belong to. - */ - public abstract int stageId(); - - /** - * The ID of the RDD partition that is computed by this task. - */ - public abstract int partitionId(); - - /** - * How many times this task has been attempted. The first task attempt will be assigned - * attemptNumber = 0, and subsequent attempts will have increasing attempt numbers. - */ - public abstract int attemptNumber(); - - /** @deprecated use {@link #taskAttemptId()}; it was renamed to avoid ambiguity. */ - @Deprecated - public abstract long attemptId(); - - /** - * An ID that is unique to this task attempt (within the same SparkContext, no two task attempts - * will share the same attempt ID). This is roughly equivalent to Hadoop's TaskAttemptID. - */ - public abstract long taskAttemptId(); - - /** ::DeveloperApi:: */ - @DeveloperApi - public abstract TaskMetrics taskMetrics(); -} diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 228076f01c841..6a16a31654630 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -804,6 +804,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli vClass: Class[V], conf: Configuration = hadoopConfiguration): RDD[(K, V)] = { assertNotStopped() + // The call to new NewHadoopJob automatically adds security credentials to conf, + // so we don't need to explicitly add them ourselves val job = new NewHadoopJob(conf) NewFileInputFormat.addInputPath(job, new Path(path)) val updatedConf = job.getConfiguration @@ -826,7 +828,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli kClass: Class[K], vClass: Class[V]): RDD[(K, V)] = { assertNotStopped() - new NewHadoopRDD(this, fClass, kClass, vClass, conf) + // Add necessary security credentials to the JobConf. Required to access secure HDFS. + val jconf = new JobConf(conf) + SparkHadoopUtil.get.addCredentials(jconf) + new NewHadoopRDD(this, fClass, kClass, vClass, jconf) } /** Get an RDD for a Hadoop SequenceFile with given key and value types. diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala new file mode 100644 index 0000000000000..7d7fe1a446313 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.io.Serializable + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.util.TaskCompletionListener + + +object TaskContext { + /** + * Return the currently active TaskContext. This can be called inside of + * user functions to access contextual information about running tasks. + */ + def get(): TaskContext = taskContext.get + + private val taskContext: ThreadLocal[TaskContext] = new ThreadLocal[TaskContext] + + // Note: protected[spark] instead of private[spark] to prevent the following two from + // showing up in JavaDoc. + /** + * Set the thread local TaskContext. Internal to Spark. + */ + protected[spark] def setTaskContext(tc: TaskContext): Unit = taskContext.set(tc) + + /** + * Unset the thread local TaskContext. Internal to Spark. + */ + protected[spark] def unset(): Unit = taskContext.remove() +} + + +/** + * Contextual information about a task which can be read or mutated during + * execution. To access the TaskContext for a running task, use: + * {{{ + * org.apache.spark.TaskContext.get() + * }}} + */ +abstract class TaskContext extends Serializable { + // Note: TaskContext must NOT define a get method. Otherwise it will prevent the Scala compiler + // from generating a static get method (based on the companion object's get method). + + // Note: Update JavaTaskContextCompileCheck when new methods are added to this class. + + // Note: getters in this class are defined with parentheses to maintain backward compatibility. + + /** + * Returns true if the task has completed. + */ + def isCompleted(): Boolean + + /** + * Returns true if the task has been killed. + */ + def isInterrupted(): Boolean + + @deprecated("use isRunningLocally", "1.2.0") + def runningLocally(): Boolean + + /** + * Returns true if the task is running locally in the driver program. + * @return + */ + def isRunningLocally(): Boolean + + /** + * Adds a (Java friendly) listener to be executed on task completion. + * This will be called in all situation - success, failure, or cancellation. + * An example use is for HadoopRDD to register a callback to close the input stream. + */ + def addTaskCompletionListener(listener: TaskCompletionListener): TaskContext + + /** + * Adds a listener in the form of a Scala closure to be executed on task completion. + * This will be called in all situations - success, failure, or cancellation. + * An example use is for HadoopRDD to register a callback to close the input stream. + */ + def addTaskCompletionListener(f: (TaskContext) => Unit): TaskContext + + /** + * Adds a callback function to be executed on task completion. An example use + * is for HadoopRDD to register a callback to close the input stream. + * Will be called in any situation - success, failure, or cancellation. + * + * @param f Callback function. + */ + @deprecated("use addTaskCompletionListener", "1.2.0") + def addOnCompleteCallback(f: () => Unit) + + /** + * The ID of the stage that this task belong to. + */ + def stageId(): Int + + /** + * The ID of the RDD partition that is computed by this task. + */ + def partitionId(): Int + + /** + * How many times this task has been attempted. The first task attempt will be assigned + * attemptNumber = 0, and subsequent attempts will have increasing attempt numbers. + */ + def attemptNumber(): Int + + @deprecated("use attemptNumber", "1.3.0") + def attemptId(): Long + + /** + * An ID that is unique to this task attempt (within the same SparkContext, no two task attempts + * will share the same attempt ID). This is roughly equivalent to Hadoop's TaskAttemptID. + */ + def taskAttemptId(): Long + + /** ::DeveloperApi:: */ + @DeveloperApi + def taskMetrics(): TaskMetrics +} diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 9bb0c61e441f8..337c8e4ebebcd 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -33,7 +33,7 @@ private[spark] class TaskContextImpl( with Logging { // For backwards-compatibility; this method is now deprecated as of 1.3.0. - override def attemptId: Long = taskAttemptId + override def attemptId(): Long = taskAttemptId // List of callback functions to execute when the task completes. @transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener] @@ -87,10 +87,10 @@ private[spark] class TaskContextImpl( interrupted = true } - override def isCompleted: Boolean = completed + override def isCompleted(): Boolean = completed - override def isRunningLocally: Boolean = runningLocally + override def isRunningLocally(): Boolean = runningLocally - override def isInterrupted: Boolean = interrupted + override def isInterrupted(): Boolean = interrupted } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 33a7aae5d3fcd..79f84e70df9d5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -361,7 +361,7 @@ private[spark] class TaskSchedulerImpl( dagScheduler.executorHeartbeatReceived(execId, metricsWithStageIds, blockManagerId) } - def handleTaskGettingResult(taskSetManager: TaskSetManager, tid: Long) { + def handleTaskGettingResult(taskSetManager: TaskSetManager, tid: Long): Unit = synchronized { taskSetManager.handleTaskGettingResult(tid) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 5c94c6bbcb37b..97c22fe724abd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -542,7 +542,7 @@ private[spark] class TaskSetManager( /** * Check whether has enough quota to fetch the result with `size` bytes */ - def canFetchMoreResults(size: Long): Boolean = synchronized { + def canFetchMoreResults(size: Long): Boolean = sched.synchronized { totalResultSize += size calculatedTasks += 1 if (maxResultSize > 0 && totalResultSize > maxResultSize) { @@ -671,7 +671,7 @@ private[spark] class TaskSetManager( maybeFinishTaskSet() } - def abort(message: String) { + def abort(message: String): Unit = sched.synchronized { // TODO: Kill running tasks if we were not terminated due to a Mesos error sched.dagScheduler.taskSetFailed(taskSet, message) isZombie = true diff --git a/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java b/core/src/test/java/test/org/apache/spark/JavaTaskCompletionListenerImpl.java similarity index 93% rename from core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java rename to core/src/test/java/test/org/apache/spark/JavaTaskCompletionListenerImpl.java index e9ec700e32e15..e38bc38949d7c 100644 --- a/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java +++ b/core/src/test/java/test/org/apache/spark/JavaTaskCompletionListenerImpl.java @@ -15,9 +15,10 @@ * limitations under the License. */ -package org.apache.spark.util; +package test.org.apache.spark; import org.apache.spark.TaskContext; +import org.apache.spark.util.TaskCompletionListener; /** diff --git a/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java b/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java new file mode 100644 index 0000000000000..4a918f725dc91 --- /dev/null +++ b/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark; + +import org.apache.spark.TaskContext; + +/** + * Something to make sure that TaskContext can be used in Java. + */ +public class JavaTaskContextCompileCheck { + + public static void test() { + TaskContext tc = TaskContext.get(); + + tc.isCompleted(); + tc.isInterrupted(); + tc.isRunningLocally(); + + tc.addTaskCompletionListener(new JavaTaskCompletionListenerImpl()); + + tc.attemptNumber(); + tc.partitionId(); + tc.stageId(); + tc.taskAttemptId(); + } +} diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index be8c5c2c1522e..22664b419f5cb 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -580,6 +580,15 @@ Configuration of Parquet can be done using the `setConf` method on SQLContext or flag tells Spark SQL to interpret binary data as a string to provide compatibility with these systems. + + spark.sql.parquet.int96AsTimestamp + true + + Some Parquet-producing systems, in particular Impala, store Timestamp into INT96. Spark would also + store Timestamp as INT96 because we need to avoid precision lost of the nanoseconds field. This + flag tells Spark SQL to interpret INT96 data as a timestamp to provide compatibility with these systems. + + spark.sql.parquet.cacheMetadata true diff --git a/examples/src/main/python/mllib/gaussian_mixture_model.py b/examples/src/main/python/mllib/gaussian_mixture_model.py new file mode 100644 index 0000000000000..a2cd626c9f19d --- /dev/null +++ b/examples/src/main/python/mllib/gaussian_mixture_model.py @@ -0,0 +1,65 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +A Gaussian Mixture Model clustering program using MLlib. +""" +import sys +import random +import argparse +import numpy as np + +from pyspark import SparkConf, SparkContext +from pyspark.mllib.clustering import GaussianMixture + + +def parseVector(line): + return np.array([float(x) for x in line.split(' ')]) + + +if __name__ == "__main__": + """ + Parameters + ---------- + :param inputFile: Input file path which contains data points + :param k: Number of mixture components + :param convergenceTol: Convergence threshold. Default to 1e-3 + :param maxIterations: Number of EM iterations to perform. Default to 100 + :param seed: Random seed + """ + + parser = argparse.ArgumentParser() + parser.add_argument('inputFile', help='Input File') + parser.add_argument('k', type=int, help='Number of clusters') + parser.add_argument('--convergenceTol', default=1e-3, type=float, help='convergence threshold') + parser.add_argument('--maxIterations', default=100, type=int, help='Number of iterations') + parser.add_argument('--seed', default=random.getrandbits(19), + type=long, help='Random seed') + args = parser.parse_args() + + conf = SparkConf().setAppName("GMM") + sc = SparkContext(conf=conf) + + lines = sc.textFile(args.inputFile) + data = lines.map(parseVector) + model = GaussianMixture.train(data, args.k, args.convergenceTol, + args.maxIterations, args.seed) + for i in range(args.k): + print ("weight = ", model.weights[i], "mu = ", model.gaussians[i].mu, + "sigma = ", model.gaussians[i].sigma.toArray()) + print ("Cluster labels (first 100): ", model.predict(data).take(100)) + sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala new file mode 100644 index 0000000000000..f4c545ad70e96 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala @@ -0,0 +1,283 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import java.text.BreakIterator + +import scala.collection.mutable + +import scopt.OptionParser + +import org.apache.log4j.{Level, Logger} + +import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.mllib.clustering.LDA +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.rdd.RDD + + +/** + * An example Latent Dirichlet Allocation (LDA) app. Run with + * {{{ + * ./bin/run-example mllib.LDAExample [options] + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object LDAExample { + + private case class Params( + input: Seq[String] = Seq.empty, + k: Int = 20, + maxIterations: Int = 10, + docConcentration: Double = -1, + topicConcentration: Double = -1, + vocabSize: Int = 10000, + stopwordFile: String = "", + checkpointDir: Option[String] = None, + checkpointInterval: Int = 10) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("LDAExample") { + head("LDAExample: an example LDA app for plain text data.") + opt[Int]("k") + .text(s"number of topics. default: ${defaultParams.k}") + .action((x, c) => c.copy(k = x)) + opt[Int]("maxIterations") + .text(s"number of iterations of learning. default: ${defaultParams.maxIterations}") + .action((x, c) => c.copy(maxIterations = x)) + opt[Double]("docConcentration") + .text(s"amount of topic smoothing to use (> 1.0) (-1=auto)." + + s" default: ${defaultParams.docConcentration}") + .action((x, c) => c.copy(docConcentration = x)) + opt[Double]("topicConcentration") + .text(s"amount of term (word) smoothing to use (> 1.0) (-1=auto)." + + s" default: ${defaultParams.topicConcentration}") + .action((x, c) => c.copy(topicConcentration = x)) + opt[Int]("vocabSize") + .text(s"number of distinct word types to use, chosen by frequency. (-1=all)" + + s" default: ${defaultParams.vocabSize}") + .action((x, c) => c.copy(vocabSize = x)) + opt[String]("stopwordFile") + .text(s"filepath for a list of stopwords. Note: This must fit on a single machine." + + s" default: ${defaultParams.stopwordFile}") + .action((x, c) => c.copy(stopwordFile = x)) + opt[String]("checkpointDir") + .text(s"Directory for checkpointing intermediate results." + + s" Checkpointing helps with recovery and eliminates temporary shuffle files on disk." + + s" default: ${defaultParams.checkpointDir}") + .action((x, c) => c.copy(checkpointDir = Some(x))) + opt[Int]("checkpointInterval") + .text(s"Iterations between each checkpoint. Only used if checkpointDir is set." + + s" default: ${defaultParams.checkpointInterval}") + .action((x, c) => c.copy(checkpointInterval = x)) + arg[String]("...") + .text("input paths (directories) to plain text corpora." + + " Each text file line should hold 1 document.") + .unbounded() + .required() + .action((x, c) => c.copy(input = c.input :+ x)) + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + parser.showUsageAsError + sys.exit(1) + } + } + + private def run(params: Params) { + val conf = new SparkConf().setAppName(s"LDAExample with $params") + val sc = new SparkContext(conf) + + Logger.getRootLogger.setLevel(Level.WARN) + + // Load documents, and prepare them for LDA. + val preprocessStart = System.nanoTime() + val (corpus, vocabArray, actualNumTokens) = + preprocess(sc, params.input, params.vocabSize, params.stopwordFile) + corpus.cache() + val actualCorpusSize = corpus.count() + val actualVocabSize = vocabArray.size + val preprocessElapsed = (System.nanoTime() - preprocessStart) / 1e9 + + println() + println(s"Corpus summary:") + println(s"\t Training set size: $actualCorpusSize documents") + println(s"\t Vocabulary size: $actualVocabSize terms") + println(s"\t Training set size: $actualNumTokens tokens") + println(s"\t Preprocessing time: $preprocessElapsed sec") + println() + + // Run LDA. + val lda = new LDA() + lda.setK(params.k) + .setMaxIterations(params.maxIterations) + .setDocConcentration(params.docConcentration) + .setTopicConcentration(params.topicConcentration) + .setCheckpointInterval(params.checkpointInterval) + if (params.checkpointDir.nonEmpty) { + lda.setCheckpointDir(params.checkpointDir.get) + } + val startTime = System.nanoTime() + val ldaModel = lda.run(corpus) + val elapsed = (System.nanoTime() - startTime) / 1e9 + + println(s"Finished training LDA model. Summary:") + println(s"\t Training time: $elapsed sec") + val avgLogLikelihood = ldaModel.logLikelihood / actualCorpusSize.toDouble + println(s"\t Training data average log likelihood: $avgLogLikelihood") + println() + + // Print the topics, showing the top-weighted terms for each topic. + val topicIndices = ldaModel.describeTopics(maxTermsPerTopic = 10) + val topics = topicIndices.map { case (terms, termWeights) => + terms.zip(termWeights).map { case (term, weight) => (vocabArray(term.toInt), weight) } + } + println(s"${params.k} topics:") + topics.zipWithIndex.foreach { case (topic, i) => + println(s"TOPIC $i") + topic.foreach { case (term, weight) => + println(s"$term\t$weight") + } + println() + } + + } + + /** + * Load documents, tokenize them, create vocabulary, and prepare documents as term count vectors. + * @return (corpus, vocabulary as array, total token count in corpus) + */ + private def preprocess( + sc: SparkContext, + paths: Seq[String], + vocabSize: Int, + stopwordFile: String): (RDD[(Long, Vector)], Array[String], Long) = { + + // Get dataset of document texts + // One document per line in each text file. + val textRDD: RDD[String] = sc.textFile(paths.mkString(",")) + + // Split text into words + val tokenizer = new SimpleTokenizer(sc, stopwordFile) + val tokenized: RDD[(Long, IndexedSeq[String])] = textRDD.zipWithIndex().map { case (text, id) => + id -> tokenizer.getWords(text) + } + tokenized.cache() + + // Counts words: RDD[(word, wordCount)] + val wordCounts: RDD[(String, Long)] = tokenized + .flatMap { case (_, tokens) => tokens.map(_ -> 1L) } + .reduceByKey(_ + _) + wordCounts.cache() + val fullVocabSize = wordCounts.count() + // Select vocab + // (vocab: Map[word -> id], total tokens after selecting vocab) + val (vocab: Map[String, Int], selectedTokenCount: Long) = { + val tmpSortedWC: Array[(String, Long)] = if (vocabSize == -1 || fullVocabSize <= vocabSize) { + // Use all terms + wordCounts.collect().sortBy(-_._2) + } else { + // Sort terms to select vocab + wordCounts.sortBy(_._2, ascending = false).take(vocabSize) + } + (tmpSortedWC.map(_._1).zipWithIndex.toMap, tmpSortedWC.map(_._2).sum) + } + + val documents = tokenized.map { case (id, tokens) => + // Filter tokens by vocabulary, and create word count vector representation of document. + val wc = new mutable.HashMap[Int, Int]() + tokens.foreach { term => + if (vocab.contains(term)) { + val termIndex = vocab(term) + wc(termIndex) = wc.getOrElse(termIndex, 0) + 1 + } + } + val indices = wc.keys.toArray.sorted + val values = indices.map(i => wc(i).toDouble) + + val sb = Vectors.sparse(vocab.size, indices, values) + (id, sb) + } + + val vocabArray = new Array[String](vocab.size) + vocab.foreach { case (term, i) => vocabArray(i) = term } + + (documents, vocabArray, selectedTokenCount) + } +} + +/** + * Simple Tokenizer. + * + * TODO: Formalize the interface, and make this a public class in mllib.feature + */ +private class SimpleTokenizer(sc: SparkContext, stopwordFile: String) extends Serializable { + + private val stopwords: Set[String] = if (stopwordFile.isEmpty) { + Set.empty[String] + } else { + val stopwordText = sc.textFile(stopwordFile).collect() + stopwordText.flatMap(_.stripMargin.split("\\s+")).toSet + } + + // Matches sequences of Unicode letters + private val allWordRegex = "^(\\p{L}*)$".r + + // Ignore words shorter than this length. + private val minWordLength = 3 + + def getWords(text: String): IndexedSeq[String] = { + + val words = new mutable.ArrayBuffer[String]() + + // Use Java BreakIterator to tokenize text into words. + val wb = BreakIterator.getWordInstance + wb.setText(text) + + // current,end index start,end of each word + var current = wb.first() + var end = wb.next() + while (end != BreakIterator.DONE) { + // Convert to lowercase + val word: String = text.substring(current, end).toLowerCase + // Remove short words and strings that aren't only letters + word match { + case allWordRegex(w) if w.length >= minWordLength && !stopwords.contains(w) => + words += w + case _ => + } + + current = end + try { + end = wb.next() + } catch { + case e: Exception => + // Ignore remaining text in line. + // This is a known bug in BreakIterator (for some Java versions), + // which fails when it sees certain characters. + end = BreakIterator.DONE + } + } + words + } + +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala index c5bd5b0b178d9..1a95048bbfe2d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala @@ -35,8 +35,7 @@ import org.apache.spark.streaming.{Seconds, StreamingContext} * * To run on your local machine using the two directories `trainingDir` and `testDir`, * with updates every 5 seconds, and 2 features per data point, call: - * $ bin/run-example \ - * org.apache.spark.examples.mllib.StreamingLinearRegression trainingDir testDir 5 2 + * $ bin/run-example mllib.StreamingLinearRegression trainingDir testDir 5 2 * * As you add text files to `trainingDir` the model will continuously update. * Anytime you add text files to `testDir`, you'll see predictions from the current model. diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala new file mode 100644 index 0000000000000..e1998099c2d78 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.classification.StreamingLogisticRegressionWithSGD +import org.apache.spark.SparkConf +import org.apache.spark.streaming.{Seconds, StreamingContext} + +/** + * Train a logistic regression model on one stream of data and make predictions + * on another stream, where the data streams arrive as text files + * into two different directories. + * + * The rows of the text files must be labeled data points in the form + * `(y,[x1,x2,x3,...,xn])` + * Where n is the number of features, y is a binary label, and + * n must be the same for train and test. + * + * Usage: StreamingLogisticRegression + * + * To run on your local machine using the two directories `trainingDir` and `testDir`, + * with updates every 5 seconds, and 2 features per data point, call: + * $ bin/run-example mllib.StreamingLogisticRegression trainingDir testDir 5 2 + * + * As you add text files to `trainingDir` the model will continuously update. + * Anytime you add text files to `testDir`, you'll see predictions from the current model. + * + */ +object StreamingLogisticRegression { + + def main(args: Array[String]) { + + if (args.length != 4) { + System.err.println( + "Usage: StreamingLogisticRegression ") + System.exit(1) + } + + val conf = new SparkConf().setMaster("local").setAppName("StreamingLogisticRegression") + val ssc = new StreamingContext(conf, Seconds(args(2).toLong)) + + val trainingData = ssc.textFileStream(args(0)).map(LabeledPoint.parse) + val testData = ssc.textFileStream(args(1)).map(LabeledPoint.parse) + + val model = new StreamingLogisticRegressionWithSGD() + .setInitialWeights(Vectors.zeros(args(3).toInt)) + + model.trainOn(trainingData) + model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print() + + ssc.start() + ssc.awaitTermination() + + } + +} diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala index b19c053ebfc44..0817c56d8f39f 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala @@ -164,7 +164,7 @@ abstract class KafkaStreamSuiteBase extends FunSuite with Eventually with Loggin } private def waitUntilMetadataIsPropagated(topic: String, partition: Int) { - eventually(timeout(1000 milliseconds), interval(100 milliseconds)) { + eventually(timeout(10000 milliseconds), interval(100 milliseconds)) { assert( server.apis.leaderCache.keySet.contains(TopicAndPartition(topic, partition)), s"Partition [$topic, $partition] metadata not propagated after timeout" diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 82d21d5e4cb6e..511cb2fe4005e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -22,6 +22,7 @@ import java.{util => ju} import scala.collection.mutable import scala.reflect.ClassTag import scala.util.Sorting +import scala.util.hashing.byteswap64 import com.github.fommil.netlib.BLAS.{getInstance => blas} import com.github.fommil.netlib.LAPACK.{getInstance => lapack} @@ -37,6 +38,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame import org.apache.spark.sql.Dsl._ import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructField, StructType} +import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter} import org.apache.spark.util.random.XORShiftRandom @@ -412,7 +414,7 @@ object ALS extends Logging { /** * Implementation of the ALS algorithm. */ - def train[ID: ClassTag]( + def train[ID: ClassTag]( // scalastyle:ignore ratings: RDD[Rating[ID]], rank: Int = 10, numUserBlocks: Int = 10, @@ -421,34 +423,47 @@ object ALS extends Logging { regParam: Double = 1.0, implicitPrefs: Boolean = false, alpha: Double = 1.0, - nonnegative: Boolean = false)( + nonnegative: Boolean = false, + intermediateRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK, + finalRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK, + seed: Long = 0L)( implicit ord: Ordering[ID]): (RDD[(ID, Array[Float])], RDD[(ID, Array[Float])]) = { + require(intermediateRDDStorageLevel != StorageLevel.NONE, + "ALS is not designed to run without persisting intermediate RDDs.") + val sc = ratings.sparkContext val userPart = new HashPartitioner(numUserBlocks) val itemPart = new HashPartitioner(numItemBlocks) val userLocalIndexEncoder = new LocalIndexEncoder(userPart.numPartitions) val itemLocalIndexEncoder = new LocalIndexEncoder(itemPart.numPartitions) val solver = if (nonnegative) new NNLSSolver else new CholeskySolver - val blockRatings = partitionRatings(ratings, userPart, itemPart).cache() - val (userInBlocks, userOutBlocks) = makeBlocks("user", blockRatings, userPart, itemPart) + val blockRatings = partitionRatings(ratings, userPart, itemPart) + .persist(intermediateRDDStorageLevel) + val (userInBlocks, userOutBlocks) = + makeBlocks("user", blockRatings, userPart, itemPart, intermediateRDDStorageLevel) // materialize blockRatings and user blocks userOutBlocks.count() val swappedBlockRatings = blockRatings.map { case ((userBlockId, itemBlockId), RatingBlock(userIds, itemIds, localRatings)) => ((itemBlockId, userBlockId), RatingBlock(itemIds, userIds, localRatings)) } - val (itemInBlocks, itemOutBlocks) = makeBlocks("item", swappedBlockRatings, itemPart, userPart) + val (itemInBlocks, itemOutBlocks) = + makeBlocks("item", swappedBlockRatings, itemPart, userPart, intermediateRDDStorageLevel) // materialize item blocks itemOutBlocks.count() - var userFactors = initialize(userInBlocks, rank) - var itemFactors = initialize(itemInBlocks, rank) + val seedGen = new XORShiftRandom(seed) + var userFactors = initialize(userInBlocks, rank, seedGen.nextLong()) + var itemFactors = initialize(itemInBlocks, rank, seedGen.nextLong()) if (implicitPrefs) { for (iter <- 1 to maxIter) { - userFactors.setName(s"userFactors-$iter").persist() + userFactors.setName(s"userFactors-$iter").persist(intermediateRDDStorageLevel) val previousItemFactors = itemFactors itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam, userLocalIndexEncoder, implicitPrefs, alpha, solver) previousItemFactors.unpersist() - itemFactors.setName(s"itemFactors-$iter").persist() + if (sc.checkpointDir.isDefined && (iter % 3 == 0)) { + itemFactors.checkpoint() + } + itemFactors.setName(s"itemFactors-$iter").persist(intermediateRDDStorageLevel) val previousUserFactors = userFactors userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam, itemLocalIndexEncoder, implicitPrefs, alpha, solver) @@ -467,21 +482,23 @@ object ALS extends Logging { .join(userFactors) .values .setName("userFactors") - .cache() - userIdAndFactors.count() - itemFactors.unpersist() + .persist(finalRDDStorageLevel) val itemIdAndFactors = itemInBlocks .mapValues(_.srcIds) .join(itemFactors) .values .setName("itemFactors") - .cache() - itemIdAndFactors.count() - userInBlocks.unpersist() - userOutBlocks.unpersist() - itemInBlocks.unpersist() - itemOutBlocks.unpersist() - blockRatings.unpersist() + .persist(finalRDDStorageLevel) + if (finalRDDStorageLevel != StorageLevel.NONE) { + userIdAndFactors.count() + itemFactors.unpersist() + itemIdAndFactors.count() + userInBlocks.unpersist() + userOutBlocks.unpersist() + itemInBlocks.unpersist() + itemOutBlocks.unpersist() + blockRatings.unpersist() + } val userOutput = userIdAndFactors.flatMap { case (ids, factors) => ids.view.zip(factors) } @@ -546,14 +563,15 @@ object ALS extends Logging { */ private def initialize[ID]( inBlocks: RDD[(Int, InBlock[ID])], - rank: Int): RDD[(Int, FactorBlock)] = { + rank: Int, + seed: Long): RDD[(Int, FactorBlock)] = { // Choose a unit vector uniformly at random from the unit sphere, but from the // "first quadrant" where all elements are nonnegative. This can be done by choosing // elements distributed as Normal(0,1) and taking the absolute value, and then normalizing. // This appears to create factorizations that have a slightly better reconstruction // (<1%) compared picking elements uniformly at random in [0,1]. inBlocks.map { case (srcBlockId, inBlock) => - val random = new XORShiftRandom(srcBlockId) + val random = new XORShiftRandom(byteswap64(seed ^ srcBlockId)) val factors = Array.fill(inBlock.srcIds.length) { val factor = Array.fill(rank)(random.nextGaussian().toFloat) val nrm = blas.snrm2(rank, factor, 1) @@ -877,7 +895,8 @@ object ALS extends Logging { prefix: String, ratingBlocks: RDD[((Int, Int), RatingBlock[ID])], srcPart: Partitioner, - dstPart: Partitioner)( + dstPart: Partitioner, + storageLevel: StorageLevel)( implicit srcOrd: Ordering[ID]): (RDD[(Int, InBlock[ID])], RDD[(Int, OutBlock)]) = { val inBlocks = ratingBlocks.map { case ((srcBlockId, dstBlockId), RatingBlock(srcIds, dstIds, ratings)) => @@ -914,7 +933,8 @@ object ALS extends Logging { builder.add(dstBlockId, srcIds, dstLocalIndices, ratings) } builder.build().compress() - }.setName(prefix + "InBlocks").cache() + }.setName(prefix + "InBlocks") + .persist(storageLevel) val outBlocks = inBlocks.mapValues { case InBlock(srcIds, dstPtrs, dstEncodedIndices, _) => val encoder = new LocalIndexEncoder(dstPart.numPartitions) val activeIds = Array.fill(dstPart.numPartitions)(mutable.ArrayBuilder.make[Int]) @@ -936,7 +956,8 @@ object ALS extends Logging { activeIds.map { x => x.result() } - }.setName(prefix + "OutBlocks").cache() + }.setName(prefix + "OutBlocks") + .persist(storageLevel) (inBlocks, outBlocks) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index a66d6f0cf29c7..980980593d194 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -22,6 +22,7 @@ import java.nio.{ByteBuffer, ByteOrder} import java.util.{ArrayList => JArrayList, List => JList, Map => JMap} import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import scala.language.existentials import scala.reflect.ClassTag @@ -40,6 +41,7 @@ import org.apache.spark.mllib.recommendation._ import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics} import org.apache.spark.mllib.stat.correlation.CorrelationNames +import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.stat.test.ChiSqTestResult import org.apache.spark.mllib.tree.{GradientBoostedTrees, RandomForest, DecisionTree} import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo, Strategy} @@ -260,7 +262,7 @@ class PythonMLLibAPI extends Serializable { } /** - * Java stub for Python mllib KMeans.train() + * Java stub for Python mllib KMeans.run() */ def trainKMeansModel( data: JavaRDD[Vector], @@ -284,6 +286,58 @@ class PythonMLLibAPI extends Serializable { } } + /** + * Java stub for Python mllib GaussianMixture.run() + * Returns a list containing weights, mean and covariance of each mixture component. + */ + def trainGaussianMixture( + data: JavaRDD[Vector], + k: Int, + convergenceTol: Double, + maxIterations: Int, + seed: Long): JList[Object] = { + val gmmAlg = new GaussianMixture() + .setK(k) + .setConvergenceTol(convergenceTol) + .setMaxIterations(maxIterations) + + if (seed != null) gmmAlg.setSeed(seed) + + try { + val model = gmmAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK)) + var wt = ArrayBuffer.empty[Double] + var mu = ArrayBuffer.empty[Vector] + var sigma = ArrayBuffer.empty[Matrix] + for (i <- 0 until model.k) { + wt += model.weights(i) + mu += model.gaussians(i).mu + sigma += model.gaussians(i).sigma + } + List(wt.toArray, mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava + } finally { + data.rdd.unpersist(blocking = false) + } + } + + /** + * Java stub for Python mllib GaussianMixtureModel.predictSoft() + */ + def predictSoftGMM( + data: JavaRDD[Vector], + wt: Object, + mu: Array[Object], + si: Array[Object]): RDD[Array[Double]] = { + + val weight = wt.asInstanceOf[Array[Double]] + val mean = mu.map(_.asInstanceOf[DenseVector]) + val sigma = si.map(_.asInstanceOf[DenseMatrix]) + val gaussians = Array.tabulate(weight.length){ + i => new MultivariateGaussian(mean(i), sigma(i)) + } + val model = new GaussianMixtureModel(weight, gaussians) + model.predictSoft(data) + } + /** * A Wrapper of MatrixFactorizationModel to provide helpfer method for Python */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 282fb3ff283f4..a469315a1b5c3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -136,7 +136,7 @@ class LogisticRegressionModel ( * for k classes multi-label classification problem. * Using [[LogisticRegressionWithLBFGS]] is recommended over this. */ -class LogisticRegressionWithSGD private ( +class LogisticRegressionWithSGD private[mllib] ( private var stepSize: Double, private var numIterations: Int, private var regParam: Double, @@ -158,7 +158,7 @@ class LogisticRegressionWithSGD private ( */ def this() = this(1.0, 100, 0.01, 1.0) - override protected def createModel(weights: Vector, intercept: Double) = { + override protected[mllib] def createModel(weights: Vector, intercept: Double) = { new LogisticRegressionModel(weights, intercept) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala new file mode 100644 index 0000000000000..6a3893d0e41d2 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.classification + +import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.regression.StreamingLinearAlgorithm + +/** + * :: Experimental :: + * Train or predict a logistic regression model on streaming data. Training uses + * Stochastic Gradient Descent to update the model based on each new batch of + * incoming data from a DStream (see `LogisticRegressionWithSGD` for model equation) + * + * Each batch of data is assumed to be an RDD of LabeledPoints. + * The number of data points per batch can vary, but the number + * of features must be constant. An initial weight + * vector must be provided. + * + * Use a builder pattern to construct a streaming logistic regression + * analysis in an application, like: + * + * val model = new StreamingLogisticRegressionWithSGD() + * .setStepSize(0.5) + * .setNumIterations(10) + * .setInitialWeights(Vectors.dense(...)) + * .trainOn(DStream) + * + */ +@Experimental +class StreamingLogisticRegressionWithSGD private[mllib] ( + private var stepSize: Double, + private var numIterations: Int, + private var miniBatchFraction: Double, + private var regParam: Double) + extends StreamingLinearAlgorithm[LogisticRegressionModel, LogisticRegressionWithSGD] + with Serializable { + + /** + * Construct a StreamingLogisticRegression object with default parameters: + * {stepSize: 0.1, numIterations: 50, miniBatchFraction: 1.0, regParam: 0.0}. + * Initial weights must be set before using trainOn or predictOn + * (see `StreamingLinearAlgorithm`) + */ + def this() = this(0.1, 50, 1.0, 0.0) + + val algorithm = new LogisticRegressionWithSGD( + stepSize, numIterations, regParam, miniBatchFraction) + + /** Set the step size for gradient descent. Default: 0.1. */ + def setStepSize(stepSize: Double): this.type = { + this.algorithm.optimizer.setStepSize(stepSize) + this + } + + /** Set the number of iterations of gradient descent to run per update. Default: 50. */ + def setNumIterations(numIterations: Int): this.type = { + this.algorithm.optimizer.setNumIterations(numIterations) + this + } + + /** Set the fraction of each batch to use for updates. Default: 1.0. */ + def setMiniBatchFraction(miniBatchFraction: Double): this.type = { + this.algorithm.optimizer.setMiniBatchFraction(miniBatchFraction) + this + } + + /** Set the regularization parameter. Default: 0.0. */ + def setRegParam(regParam: Double): this.type = { + this.algorithm.optimizer.setRegParam(regParam) + this + } + + /** Set the initial weights. Default: [0.0, 0.0]. */ + def setInitialWeights(initialWeights: Vector): this.type = { + this.model = Some(algorithm.createModel(initialWeights, 0.0)) + this + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala new file mode 100644 index 0000000000000..d8f82867a09d2 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -0,0 +1,519 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import java.util.Random + +import breeze.linalg.{DenseVector => BDV, normalize, axpy => brzAxpy} + +import org.apache.spark.Logging +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaPairRDD +import org.apache.spark.graphx._ +import org.apache.spark.graphx.impl.GraphImpl +import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils + + +/** + * :: Experimental :: + * + * Latent Dirichlet Allocation (LDA), a topic model designed for text documents. + * + * Terminology: + * - "word" = "term": an element of the vocabulary + * - "token": instance of a term appearing in a document + * - "topic": multinomial distribution over words representing some concept + * + * Currently, the underlying implementation uses Expectation-Maximization (EM), implemented + * according to the Asuncion et al. (2009) paper referenced below. + * + * References: + * - Original LDA paper (journal version): + * Blei, Ng, and Jordan. "Latent Dirichlet Allocation." JMLR, 2003. + * - This class implements their "smoothed" LDA model. + * - Paper which clearly explains several algorithms, including EM: + * Asuncion, Welling, Smyth, and Teh. + * "On Smoothing and Inference for Topic Models." UAI, 2009. + */ +@Experimental +class LDA private ( + private var k: Int, + private var maxIterations: Int, + private var docConcentration: Double, + private var topicConcentration: Double, + private var seed: Long, + private var checkpointDir: Option[String], + private var checkpointInterval: Int) extends Logging { + + def this() = this(k = 10, maxIterations = 20, docConcentration = -1, topicConcentration = -1, + seed = Utils.random.nextLong(), checkpointDir = None, checkpointInterval = 10) + + /** + * Number of topics to infer. I.e., the number of soft cluster centers. + */ + def getK: Int = k + + /** + * Number of topics to infer. I.e., the number of soft cluster centers. + * (default = 10) + */ + def setK(k: Int): this.type = { + require(k > 0, s"LDA k (number of clusters) must be > 0, but was set to $k") + this.k = k + this + } + + /** + * Concentration parameter (commonly named "alpha") for the prior placed on documents' + * distributions over topics ("theta"). + * + * This is the parameter to a symmetric Dirichlet distribution. + */ + def getDocConcentration: Double = { + if (this.docConcentration == -1) { + (50.0 / k) + 1.0 + } else { + this.docConcentration + } + } + + /** + * Concentration parameter (commonly named "alpha") for the prior placed on documents' + * distributions over topics ("theta"). + * + * This is the parameter to a symmetric Dirichlet distribution. + * + * This value should be > 1.0, where larger values mean more smoothing (more regularization). + * If set to -1, then docConcentration is set automatically. + * (default = -1 = automatic) + * + * Automatic setting of parameter: + * - For EM: default = (50 / k) + 1. + * - The 50/k is common in LDA libraries. + * - The +1 follows Asuncion et al. (2009), who recommend a +1 adjustment for EM. + * + * Note: The restriction > 1.0 may be relaxed in the future (allowing sparse solutions), + * but values in (0,1) are not yet supported. + */ + def setDocConcentration(docConcentration: Double): this.type = { + require(docConcentration > 1.0 || docConcentration == -1.0, + s"LDA docConcentration must be > 1.0 (or -1 for auto), but was set to $docConcentration") + this.docConcentration = docConcentration + this + } + + /** Alias for [[getDocConcentration]] */ + def getAlpha: Double = getDocConcentration + + /** Alias for [[setDocConcentration()]] */ + def setAlpha(alpha: Double): this.type = setDocConcentration(alpha) + + /** + * Concentration parameter (commonly named "beta" or "eta") for the prior placed on topics' + * distributions over terms. + * + * This is the parameter to a symmetric Dirichlet distribution. + * + * Note: The topics' distributions over terms are called "beta" in the original LDA paper + * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009. + */ + def getTopicConcentration: Double = { + if (this.topicConcentration == -1) { + 1.1 + } else { + this.topicConcentration + } + } + + /** + * Concentration parameter (commonly named "beta" or "eta") for the prior placed on topics' + * distributions over terms. + * + * This is the parameter to a symmetric Dirichlet distribution. + * + * Note: The topics' distributions over terms are called "beta" in the original LDA paper + * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009. + * + * This value should be > 0.0. + * If set to -1, then topicConcentration is set automatically. + * (default = -1 = automatic) + * + * Automatic setting of parameter: + * - For EM: default = 0.1 + 1. + * - The 0.1 gives a small amount of smoothing. + * - The +1 follows Asuncion et al. (2009), who recommend a +1 adjustment for EM. + * + * Note: The restriction > 1.0 may be relaxed in the future (allowing sparse solutions), + * but values in (0,1) are not yet supported. + */ + def setTopicConcentration(topicConcentration: Double): this.type = { + require(topicConcentration > 1.0 || topicConcentration == -1.0, + s"LDA topicConcentration must be > 1.0 (or -1 for auto), but was set to $topicConcentration") + this.topicConcentration = topicConcentration + this + } + + /** Alias for [[getTopicConcentration]] */ + def getBeta: Double = getTopicConcentration + + /** Alias for [[setTopicConcentration()]] */ + def setBeta(beta: Double): this.type = setBeta(beta) + + /** + * Maximum number of iterations for learning. + */ + def getMaxIterations: Int = maxIterations + + /** + * Maximum number of iterations for learning. + * (default = 20) + */ + def setMaxIterations(maxIterations: Int): this.type = { + this.maxIterations = maxIterations + this + } + + /** Random seed */ + def getSeed: Long = seed + + /** Random seed */ + def setSeed(seed: Long): this.type = { + this.seed = seed + this + } + + /** + * Directory for storing checkpoint files during learning. + * This is not necessary, but checkpointing helps with recovery (when nodes fail). + * It also helps with eliminating temporary shuffle files on disk, which can be important when + * LDA is run for many iterations. + */ + def getCheckpointDir: Option[String] = checkpointDir + + /** + * Directory for storing checkpoint files during learning. + * This is not necessary, but checkpointing helps with recovery (when nodes fail). + * It also helps with eliminating temporary shuffle files on disk, which can be important when + * LDA is run for many iterations. + * + * NOTE: If the [[org.apache.spark.SparkContext.checkpointDir]] is already set, then the value + * given to LDA is ignored, and the existing directory is kept. + * + * (default = None) + */ + def setCheckpointDir(checkpointDir: String): this.type = { + this.checkpointDir = Some(checkpointDir) + this + } + + /** + * Clear the directory for storing checkpoint files during learning. + * If one is already set in the [[org.apache.spark.SparkContext]], then checkpointing will still + * occur; otherwise, no checkpointing will be used. + */ + def clearCheckpointDir(): this.type = { + this.checkpointDir = None + this + } + + /** + * Period (in iterations) between checkpoints. + * @see [[getCheckpointDir]] + */ + def getCheckpointInterval: Int = checkpointInterval + + /** + * Period (in iterations) between checkpoints. + * (default = 10) + * @see [[getCheckpointDir]] + */ + def setCheckpointInterval(checkpointInterval: Int): this.type = { + this.checkpointInterval = checkpointInterval + this + } + + /** + * Learn an LDA model using the given dataset. + * + * @param documents RDD of documents, which are term (word) count vectors paired with IDs. + * The term count vectors are "bags of words" with a fixed-size vocabulary + * (where the vocabulary size is the length of the vector). + * Document IDs must be unique and >= 0. + * @return Inferred LDA model + */ + def run(documents: RDD[(Long, Vector)]): DistributedLDAModel = { + val state = LDA.initialState(documents, k, getDocConcentration, getTopicConcentration, seed, + checkpointDir, checkpointInterval) + var iter = 0 + val iterationTimes = Array.fill[Double](maxIterations)(0) + while (iter < maxIterations) { + val start = System.nanoTime() + state.next() + val elapsedSeconds = (System.nanoTime() - start) / 1e9 + iterationTimes(iter) = elapsedSeconds + iter += 1 + } + state.graphCheckpointer.deleteAllCheckpoints() + new DistributedLDAModel(state, iterationTimes) + } + + /** Java-friendly version of [[run()]] */ + def run(documents: JavaPairRDD[java.lang.Long, Vector]): DistributedLDAModel = { + run(documents.rdd.asInstanceOf[RDD[(Long, Vector)]]) + } +} + + +private[clustering] object LDA { + + /* + DEVELOPERS NOTE: + + This implementation uses GraphX, where the graph is bipartite with 2 types of vertices: + - Document vertices + - indexed with unique indices >= 0 + - Store vectors of length k (# topics). + - Term vertices + - indexed {-1, -2, ..., -vocabSize} + - Store vectors of length k (# topics). + - Edges correspond to terms appearing in documents. + - Edges are directed Document -> Term. + - Edges are partitioned by documents. + + Info on EM implementation. + - We follow Section 2.2 from Asuncion et al., 2009. We use some of their notation. + - In this implementation, there is one edge for every unique term appearing in a document, + i.e., for every unique (document, term) pair. + - Notation: + - N_{wkj} = count of tokens of term w currently assigned to topic k in document j + - N_{*} where * is missing a subscript w/k/j is the count summed over missing subscript(s) + - gamma_{wjk} = P(z_i = k | x_i = w, d_i = j), + the probability of term x_i in document d_i having topic z_i. + - Data graph + - Document vertices store N_{kj} + - Term vertices store N_{wk} + - Edges store N_{wj}. + - Global data N_k + - Algorithm + - Initial state: + - Document and term vertices store random counts N_{wk}, N_{kj}. + - E-step: For each (document,term) pair i, compute P(z_i | x_i, d_i). + - Aggregate N_k from term vertices. + - Compute gamma_{wjk} for each possible topic k, from each triplet. + using inputs N_{wk}, N_{kj}, N_k. + - M-step: Compute sufficient statistics for hidden parameters phi and theta + (counts N_{wk}, N_{kj}, N_k). + - Document update: + - N_{kj} <- sum_w N_{wj} gamma_{wjk} + - N_j <- sum_k N_{kj} (only needed to output predictions) + - Term update: + - N_{wk} <- sum_j N_{wj} gamma_{wjk} + - N_k <- sum_w N_{wk} + + TODO: Add simplex constraints to allow alpha in (0,1). + See: Vorontsov and Potapenko. "Tutorial on Probabilistic Topic Modeling : Additive + Regularization for Stochastic Matrix Factorization." 2014. + */ + + /** + * Vector over topics (length k) of token counts. + * The meaning of these counts can vary, and it may or may not be normalized to be a distribution. + */ + type TopicCounts = BDV[Double] + + type TokenCount = Double + + /** Term vertex IDs are {-1, -2, ..., -vocabSize} */ + def term2index(term: Int): Long = -(1 + term.toLong) + + def index2term(termIndex: Long): Int = -(1 + termIndex).toInt + + def isDocumentVertex(v: (VertexId, _)): Boolean = v._1 >= 0 + + def isTermVertex(v: (VertexId, _)): Boolean = v._1 < 0 + + /** + * Optimizer for EM algorithm which stores data + parameter graph, plus algorithm parameters. + * + * @param graph EM graph, storing current parameter estimates in vertex descriptors and + * data (token counts) in edge descriptors. + * @param k Number of topics + * @param vocabSize Number of unique terms + * @param docConcentration "alpha" + * @param topicConcentration "beta" or "eta" + */ + class EMOptimizer( + var graph: Graph[TopicCounts, TokenCount], + val k: Int, + val vocabSize: Int, + val docConcentration: Double, + val topicConcentration: Double, + checkpointDir: Option[String], + checkpointInterval: Int) { + + private[LDA] val graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount]( + graph, checkpointDir, checkpointInterval) + + def next(): EMOptimizer = { + val eta = topicConcentration + val W = vocabSize + val alpha = docConcentration + + val N_k = globalTopicTotals + val sendMsg: EdgeContext[TopicCounts, TokenCount, (Boolean, TopicCounts)] => Unit = + (edgeContext) => { + // Compute N_{wj} gamma_{wjk} + val N_wj = edgeContext.attr + // E-STEP: Compute gamma_{wjk} (smoothed topic distributions), scaled by token count + // N_{wj}. + val scaledTopicDistribution: TopicCounts = + computePTopic(edgeContext.srcAttr, edgeContext.dstAttr, N_k, W, eta, alpha) *= N_wj + edgeContext.sendToDst((false, scaledTopicDistribution)) + edgeContext.sendToSrc((false, scaledTopicDistribution)) + } + // This is a hack to detect whether we could modify the values in-place. + // TODO: Add zero/seqOp/combOp option to aggregateMessages. (SPARK-5438) + val mergeMsg: ((Boolean, TopicCounts), (Boolean, TopicCounts)) => (Boolean, TopicCounts) = + (m0, m1) => { + val sum = + if (m0._1) { + m0._2 += m1._2 + } else if (m1._1) { + m1._2 += m0._2 + } else { + m0._2 + m1._2 + } + (true, sum) + } + // M-STEP: Aggregation computes new N_{kj}, N_{wk} counts. + val docTopicDistributions: VertexRDD[TopicCounts] = + graph.aggregateMessages[(Boolean, TopicCounts)](sendMsg, mergeMsg) + .mapValues(_._2) + // Update the vertex descriptors with the new counts. + val newGraph = GraphImpl.fromExistingRDDs(docTopicDistributions, graph.edges) + graph = newGraph + graphCheckpointer.updateGraph(newGraph) + globalTopicTotals = computeGlobalTopicTotals() + this + } + + /** + * Aggregate distributions over topics from all term vertices. + * + * Note: This executes an action on the graph RDDs. + */ + var globalTopicTotals: TopicCounts = computeGlobalTopicTotals() + + private def computeGlobalTopicTotals(): TopicCounts = { + val numTopics = k + graph.vertices.filter(isTermVertex).values.fold(BDV.zeros[Double](numTopics))(_ += _) + } + + } + + /** + * Compute gamma_{wjk}, a distribution over topics k. + */ + private def computePTopic( + docTopicCounts: TopicCounts, + termTopicCounts: TopicCounts, + totalTopicCounts: TopicCounts, + vocabSize: Int, + eta: Double, + alpha: Double): TopicCounts = { + val K = docTopicCounts.length + val N_j = docTopicCounts.data + val N_w = termTopicCounts.data + val N = totalTopicCounts.data + val eta1 = eta - 1.0 + val alpha1 = alpha - 1.0 + val Weta1 = vocabSize * eta1 + var sum = 0.0 + val gamma_wj = new Array[Double](K) + var k = 0 + while (k < K) { + val gamma_wjk = (N_w(k) + eta1) * (N_j(k) + alpha1) / (N(k) + Weta1) + gamma_wj(k) = gamma_wjk + sum += gamma_wjk + k += 1 + } + // normalize + BDV(gamma_wj) /= sum + } + + /** + * Compute bipartite term/doc graph. + */ + private def initialState( + docs: RDD[(Long, Vector)], + k: Int, + docConcentration: Double, + topicConcentration: Double, + randomSeed: Long, + checkpointDir: Option[String], + checkpointInterval: Int): EMOptimizer = { + // For each document, create an edge (Document -> Term) for each unique term in the document. + val edges: RDD[Edge[TokenCount]] = docs.flatMap { case (docID: Long, termCounts: Vector) => + // Add edges for terms with non-zero counts. + termCounts.toBreeze.activeIterator.filter(_._2 != 0.0).map { case (term, cnt) => + Edge(docID, term2index(term), cnt) + } + } + + val vocabSize = docs.take(1).head._2.size + + // Create vertices. + // Initially, we use random soft assignments of tokens to topics (random gamma). + val edgesWithGamma: RDD[(Edge[TokenCount], TopicCounts)] = + edges.mapPartitionsWithIndex { case (partIndex, partEdges) => + val random = new Random(partIndex + randomSeed) + partEdges.map { edge => + // Create a random gamma_{wjk} + (edge, normalize(BDV.fill[Double](k)(random.nextDouble()), 1.0)) + } + } + def createVertices(sendToWhere: Edge[TokenCount] => VertexId): RDD[(VertexId, TopicCounts)] = { + val verticesTMP: RDD[(VertexId, (TokenCount, TopicCounts))] = + edgesWithGamma.map { case (edge, gamma: TopicCounts) => + (sendToWhere(edge), (edge.attr, gamma)) + } + verticesTMP.aggregateByKey(BDV.zeros[Double](k))( + (sum, t) => { + brzAxpy(t._1, t._2, sum) + sum + }, + (sum0, sum1) => { + sum0 += sum1 + } + ) + } + val docVertices = createVertices(_.srcId) + val termVertices = createVertices(_.dstId) + + // Partition such that edges are grouped by document + val graph = Graph(docVertices ++ termVertices, edges) + .partitionBy(PartitionStrategy.EdgePartition1D) + + new EMOptimizer(graph, k, vocabSize, docConcentration, topicConcentration, checkpointDir, + checkpointInterval) + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala new file mode 100644 index 0000000000000..19e8aab6eabd7 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -0,0 +1,351 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import breeze.linalg.{DenseMatrix => BDM, normalize, sum => brzSum} + +import org.apache.spark.annotation.Experimental +import org.apache.spark.graphx.{VertexId, EdgeContext, Graph} +import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix} +import org.apache.spark.rdd.RDD +import org.apache.spark.util.BoundedPriorityQueue + +/** + * :: Experimental :: + * + * Latent Dirichlet Allocation (LDA) model. + * + * This abstraction permits for different underlying representations, + * including local and distributed data structures. + */ +@Experimental +abstract class LDAModel private[clustering] { + + /** Number of topics */ + def k: Int + + /** Vocabulary size (number of terms or terms in the vocabulary) */ + def vocabSize: Int + + /** + * Inferred topics, where each topic is represented by a distribution over terms. + * This is a matrix of size vocabSize x k, where each column is a topic. + * No guarantees are given about the ordering of the topics. + */ + def topicsMatrix: Matrix + + /** + * Return the topics described by weighted terms. + * + * This limits the number of terms per topic. + * This is approximate; it may not return exactly the top-weighted terms for each topic. + * To get a more precise set of top terms, increase maxTermsPerTopic. + * + * @param maxTermsPerTopic Maximum number of terms to collect for each topic. + * @return Array over topics. Each topic is represented as a pair of matching arrays: + * (term indices, term weights in topic). + * Each topic's terms are sorted in order of decreasing weight. + */ + def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] + + /** + * Return the topics described by weighted terms. + * + * WARNING: If vocabSize and k are large, this can return a large object! + * + * @return Array over topics. Each topic is represented as a pair of matching arrays: + * (term indices, term weights in topic). + * Each topic's terms are sorted in order of decreasing weight. + */ + def describeTopics(): Array[(Array[Int], Array[Double])] = describeTopics(vocabSize) + + /* TODO (once LDA can be trained with Strings or given a dictionary) + * Return the topics described by weighted terms. + * + * This is similar to [[describeTopics()]] but returns String values for terms. + * If this model was trained using Strings or was given a dictionary, then this method returns + * terms as text. Otherwise, this method returns terms as term indices. + * + * This limits the number of terms per topic. + * This is approximate; it may not return exactly the top-weighted terms for each topic. + * To get a more precise set of top terms, increase maxTermsPerTopic. + * + * @param maxTermsPerTopic Maximum number of terms to collect for each topic. + * @return Array over topics. Each topic is represented as a pair of matching arrays: + * (terms, term weights in topic) where terms are either the actual term text + * (if available) or the term indices. + * Each topic's terms are sorted in order of decreasing weight. + */ + // def describeTopicsAsStrings(maxTermsPerTopic: Int): Array[(Array[Double], Array[String])] + + /* TODO (once LDA can be trained with Strings or given a dictionary) + * Return the topics described by weighted terms. + * + * This is similar to [[describeTopics()]] but returns String values for terms. + * If this model was trained using Strings or was given a dictionary, then this method returns + * terms as text. Otherwise, this method returns terms as term indices. + * + * WARNING: If vocabSize and k are large, this can return a large object! + * + * @return Array over topics. Each topic is represented as a pair of matching arrays: + * (terms, term weights in topic) where terms are either the actual term text + * (if available) or the term indices. + * Each topic's terms are sorted in order of decreasing weight. + */ + // def describeTopicsAsStrings(): Array[(Array[Double], Array[String])] = + // describeTopicsAsStrings(vocabSize) + + /* TODO + * Compute the log likelihood of the observed tokens, given the current parameter estimates: + * log P(docs | topics, topic distributions for docs, alpha, eta) + * + * Note: + * - This excludes the prior. + * - Even with the prior, this is NOT the same as the data log likelihood given the + * hyperparameters. + * + * @param documents RDD of documents, which are term (word) count vectors paired with IDs. + * The term count vectors are "bags of words" with a fixed-size vocabulary + * (where the vocabulary size is the length of the vector). + * This must use the same vocabulary (ordering of term counts) as in training. + * Document IDs must be unique and >= 0. + * @return Estimated log likelihood of the data under this model + */ + // def logLikelihood(documents: RDD[(Long, Vector)]): Double + + /* TODO + * Compute the estimated topic distribution for each document. + * This is often called “theta” in the literature. + * + * @param documents RDD of documents, which are term (word) count vectors paired with IDs. + * The term count vectors are "bags of words" with a fixed-size vocabulary + * (where the vocabulary size is the length of the vector). + * This must use the same vocabulary (ordering of term counts) as in training. + * Document IDs must be unique and >= 0. + * @return Estimated topic distribution for each document. + * The returned RDD may be zipped with the given RDD, where each returned vector + * is a multinomial distribution over topics. + */ + // def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] + +} + +/** + * :: Experimental :: + * + * Local LDA model. + * This model stores only the inferred topics. + * It may be used for computing topics for new documents, but it may give less accurate answers + * than the [[DistributedLDAModel]]. + * + * @param topics Inferred topics (vocabSize x k matrix). + */ +@Experimental +class LocalLDAModel private[clustering] ( + private val topics: Matrix) extends LDAModel with Serializable { + + override def k: Int = topics.numCols + + override def vocabSize: Int = topics.numRows + + override def topicsMatrix: Matrix = topics + + override def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] = { + val brzTopics = topics.toBreeze.toDenseMatrix + Range(0, k).map { topicIndex => + val topic = normalize(brzTopics(::, topicIndex), 1.0) + val (termWeights, terms) = + topic.toArray.zipWithIndex.sortBy(-_._1).take(maxTermsPerTopic).unzip + (terms.toArray, termWeights.toArray) + }.toArray + } + + // TODO + // override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ??? + + // TODO: + // override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ??? + +} + +/** + * :: Experimental :: + * + * Distributed LDA model. + * This model stores the inferred topics, the full training dataset, and the topic distributions. + * When computing topics for new documents, it may give more accurate answers + * than the [[LocalLDAModel]]. + */ +@Experimental +class DistributedLDAModel private ( + private val graph: Graph[LDA.TopicCounts, LDA.TokenCount], + private val globalTopicTotals: LDA.TopicCounts, + val k: Int, + val vocabSize: Int, + private val docConcentration: Double, + private val topicConcentration: Double, + private[spark] val iterationTimes: Array[Double]) extends LDAModel { + + import LDA._ + + private[clustering] def this(state: LDA.EMOptimizer, iterationTimes: Array[Double]) = { + this(state.graph, state.globalTopicTotals, state.k, state.vocabSize, state.docConcentration, + state.topicConcentration, iterationTimes) + } + + /** + * Convert model to a local model. + * The local model stores the inferred topics but not the topic distributions for training + * documents. + */ + def toLocal: LocalLDAModel = new LocalLDAModel(topicsMatrix) + + /** + * Inferred topics, where each topic is represented by a distribution over terms. + * This is a matrix of size vocabSize x k, where each column is a topic. + * No guarantees are given about the ordering of the topics. + * + * WARNING: This matrix is collected from an RDD. Beware memory usage when vocabSize, k are large. + */ + override lazy val topicsMatrix: Matrix = { + // Collect row-major topics + val termTopicCounts: Array[(Int, TopicCounts)] = + graph.vertices.filter(_._1 < 0).map { case (termIndex, cnts) => + (index2term(termIndex), cnts) + }.collect() + // Convert to Matrix + val brzTopics = BDM.zeros[Double](vocabSize, k) + termTopicCounts.foreach { case (term, cnts) => + var j = 0 + while (j < k) { + brzTopics(term, j) = cnts(j) + j += 1 + } + } + Matrices.fromBreeze(brzTopics) + } + + override def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] = { + val numTopics = k + // Note: N_k is not needed to find the top terms, but it is needed to normalize weights + // to a distribution over terms. + val N_k: TopicCounts = globalTopicTotals + val topicsInQueues: Array[BoundedPriorityQueue[(Double, Int)]] = + graph.vertices.filter(isTermVertex) + .mapPartitions { termVertices => + // For this partition, collect the most common terms for each topic in queues: + // queues(topic) = queue of (term weight, term index). + // Term weights are N_{wk} / N_k. + val queues = + Array.fill(numTopics)(new BoundedPriorityQueue[(Double, Int)](maxTermsPerTopic)) + for ((termId, n_wk) <- termVertices) { + var topic = 0 + while (topic < numTopics) { + queues(topic) += (n_wk(topic) / N_k(topic) -> index2term(termId.toInt)) + topic += 1 + } + } + Iterator(queues) + }.reduce { (q1, q2) => + q1.zip(q2).foreach { case (a, b) => a ++= b} + q1 + } + topicsInQueues.map { q => + val (termWeights, terms) = q.toArray.sortBy(-_._1).unzip + (terms.toArray, termWeights.toArray) + } + } + + // TODO + // override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ??? + + /** + * Log likelihood of the observed tokens in the training set, + * given the current parameter estimates: + * log P(docs | topics, topic distributions for docs, alpha, eta) + * + * Note: + * - This excludes the prior; for that, use [[logPrior]]. + * - Even with [[logPrior]], this is NOT the same as the data log likelihood given the + * hyperparameters. + */ + lazy val logLikelihood: Double = { + val eta = topicConcentration + val alpha = docConcentration + assert(eta > 1.0) + assert(alpha > 1.0) + val N_k = globalTopicTotals + val smoothed_N_k: TopicCounts = N_k + (vocabSize * (eta - 1.0)) + // Edges: Compute token log probability from phi_{wk}, theta_{kj}. + val sendMsg: EdgeContext[TopicCounts, TokenCount, Double] => Unit = (edgeContext) => { + val N_wj = edgeContext.attr + val smoothed_N_wk: TopicCounts = edgeContext.dstAttr + (eta - 1.0) + val smoothed_N_kj: TopicCounts = edgeContext.srcAttr + (alpha - 1.0) + val phi_wk: TopicCounts = smoothed_N_wk :/ smoothed_N_k + val theta_kj: TopicCounts = normalize(smoothed_N_kj, 1.0) + val tokenLogLikelihood = N_wj * math.log(phi_wk.dot(theta_kj)) + edgeContext.sendToDst(tokenLogLikelihood) + } + graph.aggregateMessages[Double](sendMsg, _ + _) + .map(_._2).fold(0.0)(_ + _) + } + + /** + * Log probability of the current parameter estimate: + * log P(topics, topic distributions for docs | alpha, eta) + */ + lazy val logPrior: Double = { + val eta = topicConcentration + val alpha = docConcentration + // Term vertices: Compute phi_{wk}. Use to compute prior log probability. + // Doc vertex: Compute theta_{kj}. Use to compute prior log probability. + val N_k = globalTopicTotals + val smoothed_N_k: TopicCounts = N_k + (vocabSize * (eta - 1.0)) + val seqOp: (Double, (VertexId, TopicCounts)) => Double = { + case (sumPrior: Double, vertex: (VertexId, TopicCounts)) => + if (isTermVertex(vertex)) { + val N_wk = vertex._2 + val smoothed_N_wk: TopicCounts = N_wk + (eta - 1.0) + val phi_wk: TopicCounts = smoothed_N_wk :/ smoothed_N_k + (eta - 1.0) * brzSum(phi_wk.map(math.log)) + } else { + val N_kj = vertex._2 + val smoothed_N_kj: TopicCounts = N_kj + (alpha - 1.0) + val theta_kj: TopicCounts = normalize(smoothed_N_kj, 1.0) + (alpha - 1.0) * brzSum(theta_kj.map(math.log)) + } + } + graph.vertices.aggregate(0.0)(seqOp, _ + _) + } + + /** + * For each document in the training set, return the distribution over topics for that document + * (i.e., "theta_doc"). + * + * @return RDD of (document ID, topic distribution) pairs + */ + def topicDistributions: RDD[(Long, Vector)] = { + graph.vertices.filter(LDA.isDocumentVertex).map { case (docID, topicCounts) => + (docID.toLong, Vectors.fromBreeze(normalize(topicCounts, 1.0))) + } + } + + // TODO: + // override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ??? + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index 9591c7966e06a..1433ee9a0dd5a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -18,14 +18,31 @@ package org.apache.spark.mllib.fpm import java.{util => ju} +import java.lang.{Iterable => JavaIterable} import scala.collection.mutable +import scala.collection.JavaConverters._ +import scala.reflect.ClassTag -import org.apache.spark.{SparkException, HashPartitioner, Logging, Partitioner} +import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException} +import org.apache.spark.api.java.{JavaPairRDD, JavaRDD} +import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel -class FPGrowthModel(val freqItemsets: RDD[(Array[String], Long)]) extends Serializable +/** + * Model trained by [[FPGrowth]], which holds frequent itemsets. + * @param freqItemsets frequent itemset, which is an RDD of (itemset, frequency) pairs + * @tparam Item item type + */ +class FPGrowthModel[Item: ClassTag]( + val freqItemsets: RDD[(Array[Item], Long)]) extends Serializable { + + /** Returns frequent itemsets as a [[org.apache.spark.api.java.JavaPairRDD]]. */ + def javaFreqItemsets(): JavaPairRDD[Array[Item], java.lang.Long] = { + JavaPairRDD.fromRDD(freqItemsets).asInstanceOf[JavaPairRDD[Array[Item], java.lang.Long]] + } +} /** * This class implements Parallel FP-growth algorithm to do frequent pattern matching on input data. @@ -69,7 +86,7 @@ class FPGrowth private ( * @param data input data set, each element contains a transaction * @return an [[FPGrowthModel]] */ - def run(data: RDD[Array[String]]): FPGrowthModel = { + def run[Item: ClassTag](data: RDD[Array[Item]]): FPGrowthModel[Item] = { if (data.getStorageLevel == StorageLevel.NONE) { logWarning("Input data is not cached.") } @@ -82,19 +99,24 @@ class FPGrowth private ( new FPGrowthModel(freqItemsets) } + def run[Item, Basket <: JavaIterable[Item]](data: JavaRDD[Basket]): FPGrowthModel[Item] = { + implicit val tag = fakeClassTag[Item] + run(data.rdd.map(_.asScala.toArray)) + } + /** * Generates frequent items by filtering the input data using minimal support level. * @param minCount minimum count for frequent itemsets * @param partitioner partitioner used to distribute items * @return array of frequent pattern ordered by their frequencies */ - private def genFreqItems( - data: RDD[Array[String]], + private def genFreqItems[Item: ClassTag]( + data: RDD[Array[Item]], minCount: Long, - partitioner: Partitioner): Array[String] = { + partitioner: Partitioner): Array[Item] = { data.flatMap { t => val uniq = t.toSet - if (t.length != uniq.size) { + if (t.size != uniq.size) { throw new SparkException(s"Items in a transaction must be unique but got ${t.toSeq}.") } t @@ -114,11 +136,11 @@ class FPGrowth private ( * @param partitioner partitioner used to distribute transactions * @return an RDD of (frequent itemset, count) */ - private def genFreqItemsets( - data: RDD[Array[String]], + private def genFreqItemsets[Item: ClassTag]( + data: RDD[Array[Item]], minCount: Long, - freqItems: Array[String], - partitioner: Partitioner): RDD[(Array[String], Long)] = { + freqItems: Array[Item], + partitioner: Partitioner): RDD[(Array[Item], Long)] = { val itemToRank = freqItems.zipWithIndex.toMap data.flatMap { transaction => genCondTransactions(transaction, itemToRank, partitioner) @@ -139,9 +161,9 @@ class FPGrowth private ( * @param partitioner partitioner used to distribute transactions * @return a map of (target partition, conditional transaction) */ - private def genCondTransactions( - transaction: Array[String], - itemToRank: Map[String, Int], + private def genCondTransactions[Item: ClassTag]( + transaction: Array[Item], + itemToRank: Map[Item, Int], partitioner: Partitioner): mutable.Map[Int, Array[Int]] = { val output = mutable.Map.empty[Int, Array[Int]] // Filter the basket by frequent items pattern and sort their ranks. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala new file mode 100644 index 0000000000000..76672fe51e834 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.impl + +import scala.collection.mutable + +import org.apache.hadoop.fs.{Path, FileSystem} + +import org.apache.spark.Logging +import org.apache.spark.graphx.Graph +import org.apache.spark.storage.StorageLevel + + +/** + * This class helps with persisting and checkpointing Graphs. + * Specifically, it automatically handles persisting and (optionally) checkpointing, as well as + * unpersisting and removing checkpoint files. + * + * Users should call [[PeriodicGraphCheckpointer.updateGraph()]] when a new graph has been created, + * before the graph has been materialized. After updating [[PeriodicGraphCheckpointer]], users are + * responsible for materializing the graph to ensure that persisting and checkpointing actually + * occur. + * + * When [[PeriodicGraphCheckpointer.updateGraph()]] is called, this does the following: + * - Persist new graph (if not yet persisted), and put in queue of persisted graphs. + * - Unpersist graphs from queue until there are at most 3 persisted graphs. + * - If using checkpointing and the checkpoint interval has been reached, + * - Checkpoint the new graph, and put in a queue of checkpointed graphs. + * - Remove older checkpoints. + * + * WARNINGS: + * - This class should NOT be copied (since copies may conflict on which Graphs should be + * checkpointed). + * - This class removes checkpoint files once later graphs have been checkpointed. + * However, references to the older graphs will still return isCheckpointed = true. + * + * Example usage: + * {{{ + * val (graph1, graph2, graph3, ...) = ... + * val cp = new PeriodicGraphCheckpointer(graph1, dir, 2) + * graph1.vertices.count(); graph1.edges.count() + * // persisted: graph1 + * cp.updateGraph(graph2) + * graph2.vertices.count(); graph2.edges.count() + * // persisted: graph1, graph2 + * // checkpointed: graph2 + * cp.updateGraph(graph3) + * graph3.vertices.count(); graph3.edges.count() + * // persisted: graph1, graph2, graph3 + * // checkpointed: graph2 + * cp.updateGraph(graph4) + * graph4.vertices.count(); graph4.edges.count() + * // persisted: graph2, graph3, graph4 + * // checkpointed: graph4 + * cp.updateGraph(graph5) + * graph5.vertices.count(); graph5.edges.count() + * // persisted: graph3, graph4, graph5 + * // checkpointed: graph4 + * }}} + * + * @param currentGraph Initial graph + * @param checkpointDir The directory for storing checkpoint files + * @param checkpointInterval Graphs will be checkpointed at this interval + * @tparam VD Vertex descriptor type + * @tparam ED Edge descriptor type + * + * TODO: Generalize this for Graphs and RDDs, and move it out of MLlib. + */ +private[mllib] class PeriodicGraphCheckpointer[VD, ED]( + var currentGraph: Graph[VD, ED], + val checkpointDir: Option[String], + val checkpointInterval: Int) extends Logging { + + /** FIFO queue of past checkpointed RDDs */ + private val checkpointQueue = mutable.Queue[Graph[VD, ED]]() + + /** FIFO queue of past persisted RDDs */ + private val persistedQueue = mutable.Queue[Graph[VD, ED]]() + + /** Number of times [[updateGraph()]] has been called */ + private var updateCount = 0 + + /** + * Spark Context for the Graphs given to this checkpointer. + * NOTE: This code assumes that only one SparkContext is used for the given graphs. + */ + private val sc = currentGraph.vertices.sparkContext + + // If a checkpoint directory is given, and there's no prior checkpoint directory, + // then set the checkpoint directory with the given one. + if (checkpointDir.nonEmpty && sc.getCheckpointDir.isEmpty) { + sc.setCheckpointDir(checkpointDir.get) + } + + updateGraph(currentGraph) + + /** + * Update [[currentGraph]] with a new graph. Handle persistence and checkpointing as needed. + * Since this handles persistence and checkpointing, this should be called before the graph + * has been materialized. + * + * @param newGraph New graph created from previous graphs in the lineage. + */ + def updateGraph(newGraph: Graph[VD, ED]): Unit = { + if (newGraph.vertices.getStorageLevel == StorageLevel.NONE) { + newGraph.persist() + } + persistedQueue.enqueue(newGraph) + // We try to maintain 2 Graphs in persistedQueue to support the semantics of this class: + // Users should call [[updateGraph()]] when a new graph has been created, + // before the graph has been materialized. + while (persistedQueue.size > 3) { + val graphToUnpersist = persistedQueue.dequeue() + graphToUnpersist.unpersist(blocking = false) + } + updateCount += 1 + + // Handle checkpointing (after persisting) + if ((updateCount % checkpointInterval) == 0 && sc.getCheckpointDir.nonEmpty) { + // Add new checkpoint before removing old checkpoints. + newGraph.checkpoint() + checkpointQueue.enqueue(newGraph) + // Remove checkpoints before the latest one. + var canDelete = true + while (checkpointQueue.size > 1 && canDelete) { + // Delete the oldest checkpoint only if the next checkpoint exists. + if (checkpointQueue.get(1).get.isCheckpointed) { + removeCheckpointFile() + } else { + canDelete = false + } + } + } + } + + /** + * Call this at the end to delete any remaining checkpoint files. + */ + def deleteAllCheckpoints(): Unit = { + while (checkpointQueue.size > 0) { + removeCheckpointFile() + } + } + + /** + * Dequeue the oldest checkpointed Graph, and remove its checkpoint files. + * This prints a warning but does not fail if the files cannot be removed. + */ + private def removeCheckpointFile(): Unit = { + val old = checkpointQueue.dequeue() + // Since the old checkpoint is not deleted by Spark, we manually delete it. + val fs = FileSystem.get(sc.hadoopConfiguration) + old.getCheckpointFiles.foreach { checkpointFile => + try { + fs.delete(new Path(checkpointFile), true) + } catch { + case e: Exception => + logWarning("PeriodicGraphCheckpointer could not remove old checkpoint file: " + + checkpointFile) + } + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index a5ffe888ca880..f4f51f2ac5210 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -17,46 +17,12 @@ package org.apache.spark.mllib.recommendation -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer -import scala.math.{abs, sqrt} -import scala.util.{Random, Sorting} -import scala.util.hashing.byteswap32 - -import org.jblas.{DoubleMatrix, SimpleBlas, Solve} - -import org.apache.spark.{HashPartitioner, Logging, Partitioner} -import org.apache.spark.SparkContext._ +import org.apache.spark.Logging import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.JavaRDD -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.mllib.optimization.NNLS +import org.apache.spark.ml.recommendation.{ALS => NewALS} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.Utils - -/** - * Out-link information for a user or product block. This includes the original user/product IDs - * of the elements within this block, and the list of destination blocks that each user or - * product will need to send its feature vector to. - */ -private[recommendation] -case class OutLinkBlock(elementIds: Array[Int], shouldSend: Array[mutable.BitSet]) - - -/** - * In-link information for a user (or product) block. This includes the original user/product IDs - * of the elements within this block, as well as an array of indices and ratings that specify - * which user in the block will be rated by which products from each product block (or vice-versa). - * Specifically, if this InLinkBlock is for users, ratingsForBlock(b)(i) will contain two arrays, - * indices and ratings, for the i'th product that will be sent to us by product block b (call this - * P). These arrays represent the users that product P had ratings for (by their index in this - * block), as well as the corresponding rating for each one. We can thus use this information when - * we get product block b's message to update the corresponding users. - */ -private[recommendation] case class InLinkBlock( - elementIds: Array[Int], ratingsForBlock: Array[Array[(Array[Int], Array[Double])]]) - /** * :: Experimental :: @@ -201,6 +167,8 @@ class ALS private ( */ @DeveloperApi def setIntermediateRDDStorageLevel(storageLevel: StorageLevel): this.type = { + require(storageLevel != StorageLevel.NONE, + "ALS is not designed to run without persisting intermediate RDDs.") this.intermediateRDDStorageLevel = storageLevel this } @@ -236,431 +204,39 @@ class ALS private ( this.numProductBlocks } - val userPartitioner = new ALSPartitioner(numUserBlocks) - val productPartitioner = new ALSPartitioner(numProductBlocks) - - val ratingsByUserBlock = ratings.map { rating => - (userPartitioner.getPartition(rating.user), rating) - } - val ratingsByProductBlock = ratings.map { rating => - (productPartitioner.getPartition(rating.product), - Rating(rating.product, rating.user, rating.rating)) - } - - val (userInLinks, userOutLinks) = - makeLinkRDDs(numUserBlocks, numProductBlocks, ratingsByUserBlock, productPartitioner) - val (productInLinks, productOutLinks) = - makeLinkRDDs(numProductBlocks, numUserBlocks, ratingsByProductBlock, userPartitioner) - userInLinks.setName("userInLinks") - userOutLinks.setName("userOutLinks") - productInLinks.setName("productInLinks") - productOutLinks.setName("productOutLinks") - - // Initialize user and product factors randomly, but use a deterministic seed for each - // partition so that fault recovery works - val seedGen = new Random(seed) - val seed1 = seedGen.nextInt() - val seed2 = seedGen.nextInt() - var users = userOutLinks.mapPartitionsWithIndex { (index, itr) => - val rand = new Random(byteswap32(seed1 ^ index)) - itr.map { case (x, y) => - (x, y.elementIds.map(_ => randomFactor(rank, rand))) - } - } - var products = productOutLinks.mapPartitionsWithIndex { (index, itr) => - val rand = new Random(byteswap32(seed2 ^ index)) - itr.map { case (x, y) => - (x, y.elementIds.map(_ => randomFactor(rank, rand))) - } - } - - if (implicitPrefs) { - for (iter <- 1 to iterations) { - // perform ALS update - logInfo("Re-computing I given U (Iteration %d/%d)".format(iter, iterations)) - // Persist users because it will be called twice. - users.setName(s"users-$iter").persist() - val YtY = Some(sc.broadcast(computeYtY(users))) - val previousProducts = products - products = updateFeatures(numProductBlocks, users, userOutLinks, productInLinks, - rank, lambda, alpha, YtY) - previousProducts.unpersist() - logInfo("Re-computing U given I (Iteration %d/%d)".format(iter, iterations)) - if (sc.checkpointDir.isDefined && (iter % 3 == 0)) { - products.checkpoint() - } - products.setName(s"products-$iter").persist() - val XtX = Some(sc.broadcast(computeYtY(products))) - val previousUsers = users - users = updateFeatures(numUserBlocks, products, productOutLinks, userInLinks, - rank, lambda, alpha, XtX) - previousUsers.unpersist() - } - } else { - for (iter <- 1 to iterations) { - // perform ALS update - logInfo("Re-computing I given U (Iteration %d/%d)".format(iter, iterations)) - products = updateFeatures(numProductBlocks, users, userOutLinks, productInLinks, - rank, lambda, alpha, YtY = None) - if (sc.checkpointDir.isDefined && (iter % 3 == 0)) { - products.checkpoint() - } - products.setName(s"products-$iter") - logInfo("Re-computing U given I (Iteration %d/%d)".format(iter, iterations)) - users = updateFeatures(numUserBlocks, products, productOutLinks, userInLinks, - rank, lambda, alpha, YtY = None) - users.setName(s"users-$iter") - } + val (floatUserFactors, floatProdFactors) = NewALS.train[Int]( + ratings = ratings.map(r => NewALS.Rating(r.user, r.product, r.rating.toFloat)), + rank = rank, + numUserBlocks = numUserBlocks, + numItemBlocks = numProductBlocks, + maxIter = iterations, + regParam = lambda, + implicitPrefs = implicitPrefs, + alpha = alpha, + nonnegative = nonnegative, + intermediateRDDStorageLevel = intermediateRDDStorageLevel, + finalRDDStorageLevel = StorageLevel.NONE, + seed = seed) + + val userFactors = floatUserFactors + .mapValues(_.map(_.toDouble)) + .setName("users") + .persist(finalRDDStorageLevel) + val prodFactors = floatProdFactors + .mapValues(_.map(_.toDouble)) + .setName("products") + .persist(finalRDDStorageLevel) + if (finalRDDStorageLevel != StorageLevel.NONE) { + userFactors.count() + prodFactors.count() } - - // The last `products` will be used twice. One to generate the last `users` and the other to - // generate `productsOut`. So we cache it for better performance. - products.setName("products").persist() - - // Flatten and cache the two final RDDs to un-block them - val usersOut = unblockFactors(users, userOutLinks) - val productsOut = unblockFactors(products, productOutLinks) - - usersOut.setName("usersOut").persist(finalRDDStorageLevel) - productsOut.setName("productsOut").persist(finalRDDStorageLevel) - - // Materialize usersOut and productsOut. - usersOut.count() - productsOut.count() - - products.unpersist() - - // Clean up. - userInLinks.unpersist() - userOutLinks.unpersist() - productInLinks.unpersist() - productOutLinks.unpersist() - - new MatrixFactorizationModel(rank, usersOut, productsOut) + new MatrixFactorizationModel(rank, userFactors, prodFactors) } /** * Java-friendly version of [[ALS.run]]. */ def run(ratings: JavaRDD[Rating]): MatrixFactorizationModel = run(ratings.rdd) - - /** - * Computes the (`rank x rank`) matrix `YtY`, where `Y` is the (`nui x rank`) matrix of factors - * for each user (or product), in a distributed fashion. - * - * @param factors the (block-distributed) user or product factor vectors - * @return YtY - whose value is only used in the implicit preference model - */ - private def computeYtY(factors: RDD[(Int, Array[Array[Double]])]) = { - val n = rank * (rank + 1) / 2 - val LYtY = factors.values.aggregate(new DoubleMatrix(n))( seqOp = (L, Y) => { - Y.foreach(y => dspr(1.0, wrapDoubleArray(y), L)) - L - }, combOp = (L1, L2) => { - L1.addi(L2) - }) - val YtY = new DoubleMatrix(rank, rank) - fillFullMatrix(LYtY, YtY) - YtY - } - - /** - * Adds alpha * x * x.t to a matrix in-place. This is the same as BLAS's DSPR. - * - * @param L the lower triangular part of the matrix packed in an array (row major) - */ - private def dspr(alpha: Double, x: DoubleMatrix, L: DoubleMatrix) = { - val n = x.length - var i = 0 - var j = 0 - var idx = 0 - var axi = 0.0 - val xd = x.data - val Ld = L.data - while (i < n) { - axi = alpha * xd(i) - j = 0 - while (j <= i) { - Ld(idx) += axi * xd(j) - j += 1 - idx += 1 - } - i += 1 - } - } - - /** - * Wrap a double array in a DoubleMatrix without creating garbage. - * This is a temporary fix for jblas 1.2.3; it should be safe to move back to the - * DoubleMatrix(double[]) constructor come jblas 1.2.4. - */ - private def wrapDoubleArray(v: Array[Double]): DoubleMatrix = { - new DoubleMatrix(v.length, 1, v: _*) - } - - /** - * Flatten out blocked user or product factors into an RDD of (id, factor vector) pairs - */ - private def unblockFactors( - blockedFactors: RDD[(Int, Array[Array[Double]])], - outLinks: RDD[(Int, OutLinkBlock)]): RDD[(Int, Array[Double])] = { - blockedFactors.join(outLinks).flatMap { case (b, (factors, outLinkBlock)) => - for (i <- 0 until factors.length) yield (outLinkBlock.elementIds(i), factors(i)) - } - } - - /** - * Make the out-links table for a block of the users (or products) dataset given the list of - * (user, product, rating) values for the users in that block (or the opposite for products). - */ - private def makeOutLinkBlock(numProductBlocks: Int, ratings: Array[Rating], - productPartitioner: Partitioner): OutLinkBlock = { - val userIds = ratings.map(_.user).distinct.sorted - val numUsers = userIds.length - val userIdToPos = userIds.zipWithIndex.toMap - val shouldSend = Array.fill(numUsers)(new mutable.BitSet(numProductBlocks)) - for (r <- ratings) { - shouldSend(userIdToPos(r.user))(productPartitioner.getPartition(r.product)) = true - } - OutLinkBlock(userIds, shouldSend) - } - - /** - * Make the in-links table for a block of the users (or products) dataset given a list of - * (user, product, rating) values for the users in that block (or the opposite for products). - */ - private def makeInLinkBlock(numProductBlocks: Int, ratings: Array[Rating], - productPartitioner: Partitioner): InLinkBlock = { - val userIds = ratings.map(_.user).distinct.sorted - val userIdToPos = userIds.zipWithIndex.toMap - // Split out our ratings by product block - val blockRatings = Array.fill(numProductBlocks)(new ArrayBuffer[Rating]) - for (r <- ratings) { - blockRatings(productPartitioner.getPartition(r.product)) += r - } - val ratingsForBlock = new Array[Array[(Array[Int], Array[Double])]](numProductBlocks) - for (productBlock <- 0 until numProductBlocks) { - // Create an array of (product, Seq(Rating)) ratings - val groupedRatings = blockRatings(productBlock).groupBy(_.product).toArray - // Sort them by product ID - val ordering = new Ordering[(Int, ArrayBuffer[Rating])] { - def compare(a: (Int, ArrayBuffer[Rating]), b: (Int, ArrayBuffer[Rating])): Int = - a._1 - b._1 - } - Sorting.quickSort(groupedRatings)(ordering) - // Translate the user IDs to indices based on userIdToPos - ratingsForBlock(productBlock) = groupedRatings.map { case (p, rs) => - (rs.view.map(r => userIdToPos(r.user)).toArray, rs.view.map(_.rating).toArray) - } - } - InLinkBlock(userIds, ratingsForBlock) - } - - /** - * Make RDDs of InLinkBlocks and OutLinkBlocks given an RDD of (blockId, (u, p, r)) values for - * the users (or (blockId, (p, u, r)) for the products). We create these simultaneously to avoid - * having to shuffle the (blockId, (u, p, r)) RDD twice, or to cache it. - */ - private def makeLinkRDDs( - numUserBlocks: Int, - numProductBlocks: Int, - ratingsByUserBlock: RDD[(Int, Rating)], - productPartitioner: Partitioner): (RDD[(Int, InLinkBlock)], RDD[(Int, OutLinkBlock)]) = { - val grouped = ratingsByUserBlock.partitionBy(new HashPartitioner(numUserBlocks)) - val links = grouped.mapPartitionsWithIndex((blockId, elements) => { - val ratings = elements.map(_._2).toArray - val inLinkBlock = makeInLinkBlock(numProductBlocks, ratings, productPartitioner) - val outLinkBlock = makeOutLinkBlock(numProductBlocks, ratings, productPartitioner) - Iterator.single((blockId, (inLinkBlock, outLinkBlock))) - }, preservesPartitioning = true) - val inLinks = links.mapValues(_._1) - val outLinks = links.mapValues(_._2) - inLinks.persist(intermediateRDDStorageLevel) - outLinks.persist(intermediateRDDStorageLevel) - (inLinks, outLinks) - } - - /** - * Make a random factor vector with the given random. - */ - private def randomFactor(rank: Int, rand: Random): Array[Double] = { - // Choose a unit vector uniformly at random from the unit sphere, but from the - // "first quadrant" where all elements are nonnegative. This can be done by choosing - // elements distributed as Normal(0,1) and taking the absolute value, and then normalizing. - // This appears to create factorizations that have a slightly better reconstruction - // (<1%) compared picking elements uniformly at random in [0,1]. - val factor = Array.fill(rank)(abs(rand.nextGaussian())) - val norm = sqrt(factor.map(x => x * x).sum) - factor.map(x => x / norm) - } - - /** - * Compute the user feature vectors given the current products (or vice-versa). This first joins - * the products with their out-links to generate a set of messages to each destination block - * (specifically, the features for the products that user block cares about), then groups these - * by destination and joins them with the in-link info to figure out how to update each user. - * It returns an RDD of new feature vectors for each user block. - */ - private def updateFeatures( - numUserBlocks: Int, - products: RDD[(Int, Array[Array[Double]])], - productOutLinks: RDD[(Int, OutLinkBlock)], - userInLinks: RDD[(Int, InLinkBlock)], - rank: Int, - lambda: Double, - alpha: Double, - YtY: Option[Broadcast[DoubleMatrix]]): RDD[(Int, Array[Array[Double]])] = { - productOutLinks.join(products).flatMap { case (bid, (outLinkBlock, factors)) => - val toSend = Array.fill(numUserBlocks)(new ArrayBuffer[Array[Double]]) - for (p <- 0 until outLinkBlock.elementIds.length; userBlock <- 0 until numUserBlocks) { - if (outLinkBlock.shouldSend(p)(userBlock)) { - toSend(userBlock) += factors(p) - } - } - toSend.zipWithIndex.map{ case (buf, idx) => (idx, (bid, buf.toArray)) } - }.groupByKey(new HashPartitioner(numUserBlocks)) - .join(userInLinks) - .mapValues{ case (messages, inLinkBlock) => - updateBlock(messages, inLinkBlock, rank, lambda, alpha, YtY) - } - } - - /** - * Compute the new feature vectors for a block of the users matrix given the list of factors - * it received from each product and its InLinkBlock. - */ - private def updateBlock(messages: Iterable[(Int, Array[Array[Double]])], inLinkBlock: InLinkBlock, - rank: Int, lambda: Double, alpha: Double, YtY: Option[Broadcast[DoubleMatrix]]) - : Array[Array[Double]] = - { - // Sort the incoming block factor messages by block ID and make them an array - val blockFactors = messages.toSeq.sortBy(_._1).map(_._2).toArray // Array[Array[Double]] - val numProductBlocks = blockFactors.length - val numUsers = inLinkBlock.elementIds.length - - // We'll sum up the XtXes using vectors that represent only the lower-triangular part, since - // the matrices are symmetric - val triangleSize = rank * (rank + 1) / 2 - val userXtX = Array.fill(numUsers)(DoubleMatrix.zeros(triangleSize)) - val userXy = Array.fill(numUsers)(DoubleMatrix.zeros(rank)) - - // Some temp variables to avoid memory allocation - val tempXtX = DoubleMatrix.zeros(triangleSize) - val fullXtX = DoubleMatrix.zeros(rank, rank) - - // Count the number of ratings each user gives to provide user-specific regularization - val numRatings = Array.fill(numUsers)(0) - - // Compute the XtX and Xy values for each user by adding products it rated in each product - // block - for (productBlock <- 0 until numProductBlocks) { - var p = 0 - while (p < blockFactors(productBlock).length) { - val x = wrapDoubleArray(blockFactors(productBlock)(p)) - tempXtX.fill(0.0) - dspr(1.0, x, tempXtX) - val (us, rs) = inLinkBlock.ratingsForBlock(productBlock)(p) - if (implicitPrefs) { - var i = 0 - while (i < us.length) { - numRatings(us(i)) += 1 - // Extension to the original paper to handle rs(i) < 0. confidence is a function - // of |rs(i)| instead so that it is never negative: - val confidence = 1 + alpha * abs(rs(i)) - SimpleBlas.axpy(confidence - 1.0, tempXtX, userXtX(us(i))) - // For rs(i) < 0, the corresponding entry in P is 0 now, not 1 -- negative rs(i) - // means we try to reconstruct 0. We add terms only where P = 1, so, term below - // is now only added for rs(i) > 0: - if (rs(i) > 0) { - SimpleBlas.axpy(confidence, x, userXy(us(i))) - } - i += 1 - } - } else { - var i = 0 - while (i < us.length) { - numRatings(us(i)) += 1 - userXtX(us(i)).addi(tempXtX) - SimpleBlas.axpy(rs(i), x, userXy(us(i))) - i += 1 - } - } - p += 1 - } - } - - val ws = if (nonnegative) NNLS.createWorkspace(rank) else null - - // Solve the least-squares problem for each user and return the new feature vectors - Array.range(0, numUsers).map { index => - // Compute the full XtX matrix from the lower-triangular part we got above - fillFullMatrix(userXtX(index), fullXtX) - // Add regularization - val regParam = numRatings(index) * lambda - var i = 0 - while (i < rank) { - fullXtX.data(i * rank + i) += regParam - i += 1 - } - // Solve the resulting matrix, which is symmetric and positive-definite - if (implicitPrefs) { - solveLeastSquares(fullXtX.addi(YtY.get.value), userXy(index), ws) - } else { - solveLeastSquares(fullXtX, userXy(index), ws) - } - } - } - - /** - * Given A^T A and A^T b, find the x minimising ||Ax - b||_2, possibly subject - * to nonnegativity constraints if `nonnegative` is true. - */ - private def solveLeastSquares(ata: DoubleMatrix, atb: DoubleMatrix, - ws: NNLS.Workspace): Array[Double] = { - if (!nonnegative) { - Solve.solvePositive(ata, atb).data - } else { - NNLS.solve(ata, atb, ws) - } - } - - /** - * Given a triangular matrix in the order of fillXtX above, compute the full symmetric square - * matrix that it represents, storing it into destMatrix. - */ - private def fillFullMatrix(triangularMatrix: DoubleMatrix, destMatrix: DoubleMatrix) { - val rank = destMatrix.rows - var i = 0 - var pos = 0 - while (i < rank) { - var j = 0 - while (j <= i) { - destMatrix.data(i*rank + j) = triangularMatrix.data(pos) - destMatrix.data(j*rank + i) = triangularMatrix.data(pos) - pos += 1 - j += 1 - } - i += 1 - } - } -} - -/** - * Partitioner for ALS. - */ -private[recommendation] class ALSPartitioner(override val numPartitions: Int) extends Partitioner { - override def getPartition(key: Any): Int = { - Utils.nonNegativeMod(byteswap32(key.asInstanceOf[Int]), numPartitions) - } - - override def equals(obj: Any): Boolean = { - obj match { - case p: ALSPartitioner => - this.numPartitions == p.numPartitions - case _ => - false - } - } } /** @@ -834,120 +410,4 @@ object ALS { : MatrixFactorizationModel = { trainImplicit(ratings, rank, iterations, 0.01, -1, 1.0) } - - /** - * :: DeveloperApi :: - * Statistics of a block in ALS computation. - * - * @param category type of this block, "user" or "product" - * @param index index of this block - * @param count number of users or products inside this block, the same as the number of - * least-squares problems to solve on this block in each iteration - * @param numRatings total number of ratings inside this block, the same as the number of outer - * products we need to make on this block in each iteration - * @param numInLinks total number of incoming links, the same as the number of vectors to retrieve - * before each iteration - * @param numOutLinks total number of outgoing links, the same as the number of vectors to send - * for the next iteration - */ - @DeveloperApi - case class BlockStats( - category: String, - index: Int, - count: Long, - numRatings: Long, - numInLinks: Long, - numOutLinks: Long) - - /** - * :: DeveloperApi :: - * Given an RDD of ratings, number of user blocks, and number of product blocks, computes the - * statistics of each block in ALS computation. This is useful for estimating cost and diagnosing - * load balance. - * - * @param ratings an RDD of ratings - * @param numUserBlocks number of user blocks - * @param numProductBlocks number of product blocks - * @return statistics of user blocks and product blocks - */ - @DeveloperApi - def analyzeBlocks( - ratings: RDD[Rating], - numUserBlocks: Int, - numProductBlocks: Int): Array[BlockStats] = { - - val userPartitioner = new ALSPartitioner(numUserBlocks) - val productPartitioner = new ALSPartitioner(numProductBlocks) - - val ratingsByUserBlock = ratings.map { rating => - (userPartitioner.getPartition(rating.user), rating) - } - val ratingsByProductBlock = ratings.map { rating => - (productPartitioner.getPartition(rating.product), - Rating(rating.product, rating.user, rating.rating)) - } - - val als = new ALS() - val (userIn, userOut) = - als.makeLinkRDDs(numUserBlocks, numProductBlocks, ratingsByUserBlock, userPartitioner) - val (prodIn, prodOut) = - als.makeLinkRDDs(numProductBlocks, numUserBlocks, ratingsByProductBlock, productPartitioner) - - def sendGrid(outLinks: RDD[(Int, OutLinkBlock)]): Map[(Int, Int), Long] = { - outLinks.map { x => - val grid = new mutable.HashMap[(Int, Int), Long]() - val uPartition = x._1 - x._2.shouldSend.foreach { ss => - ss.foreach { pPartition => - val pair = (uPartition, pPartition) - grid.put(pair, grid.getOrElse(pair, 0L) + 1L) - } - } - grid - }.reduce { (grid1, grid2) => - grid2.foreach { x => - grid1.put(x._1, grid1.getOrElse(x._1, 0L) + x._2) - } - grid1 - }.toMap - } - - val userSendGrid = sendGrid(userOut) - val prodSendGrid = sendGrid(prodOut) - - val userInbound = new Array[Long](numUserBlocks) - val prodInbound = new Array[Long](numProductBlocks) - val userOutbound = new Array[Long](numUserBlocks) - val prodOutbound = new Array[Long](numProductBlocks) - - for (u <- 0 until numUserBlocks; p <- 0 until numProductBlocks) { - userOutbound(u) += userSendGrid.getOrElse((u, p), 0L) - prodInbound(p) += userSendGrid.getOrElse((u, p), 0L) - userInbound(u) += prodSendGrid.getOrElse((p, u), 0L) - prodOutbound(p) += prodSendGrid.getOrElse((p, u), 0L) - } - - val userCounts = userOut.mapValues(x => x.elementIds.length).collectAsMap() - val prodCounts = prodOut.mapValues(x => x.elementIds.length).collectAsMap() - - val userRatings = countRatings(userIn) - val prodRatings = countRatings(prodIn) - - val userStats = Array.tabulate(numUserBlocks)( - u => BlockStats("user", u, userCounts(u), userRatings(u), userInbound(u), userOutbound(u))) - val productStatus = Array.tabulate(numProductBlocks)( - p => BlockStats("product", p, prodCounts(p), prodRatings(p), prodInbound(p), prodOutbound(p))) - - (userStats ++ productStatus).toArray - } - - private def countRatings(inLinks: RDD[(Int, InLinkBlock)]): Map[Int, Long] = { - inLinks.mapValues { ilb => - var numRatings = 0L - ilb.ratingsForBlock.foreach { ar => - ar.foreach { p => numRatings += p._1.length } - } - numRatings - }.collectAsMap().toMap - } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala index b549b7c475fc3..44a8dbb994cfb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala @@ -21,7 +21,7 @@ import scala.reflect.ClassTag import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.streaming.dstream.DStream /** @@ -58,14 +58,14 @@ abstract class StreamingLinearAlgorithm[ A <: GeneralizedLinearAlgorithm[M]] extends Logging { /** The model to be updated and used for prediction. */ - protected var model: M + protected var model: Option[M] = None /** The algorithm to use for updating. */ protected val algorithm: A /** Return the latest model. */ def latestModel(): M = { - model + model.get } /** @@ -77,18 +77,25 @@ abstract class StreamingLinearAlgorithm[ * @param data DStream containing labeled data */ def trainOn(data: DStream[LabeledPoint]) { - if (Option(model.weights) == None) { - logError("Initial weights must be set before starting training") - throw new IllegalArgumentException + if (model.isEmpty) { + throw new IllegalArgumentException("Model must be initialized before starting training.") } data.foreachRDD { (rdd, time) => - model = algorithm.run(rdd, model.weights) - logInfo("Model updated at time %s".format(time.toString)) - val display = model.weights.size match { - case x if x > 100 => model.weights.toArray.take(100).mkString("[", ",", "...") - case _ => model.weights.toArray.mkString("[", ",", "]") + val initialWeights = + model match { + case Some(m) => + m.weights + case None => + val numFeatures = rdd.first().features.size + Vectors.dense(numFeatures) } - logInfo("Current model: weights, %s".format (display)) + model = Some(algorithm.run(rdd, initialWeights)) + logInfo("Model updated at time %s".format(time.toString)) + val display = model.get.weights.size match { + case x if x > 100 => model.get.weights.toArray.take(100).mkString("[", ",", "...") + case _ => model.get.weights.toArray.mkString("[", ",", "]") + } + logInfo("Current model: weights, %s".format (display)) } } @@ -99,12 +106,10 @@ abstract class StreamingLinearAlgorithm[ * @return DStream containing predictions */ def predictOn(data: DStream[Vector]): DStream[Double] = { - if (Option(model.weights) == None) { - val msg = "Initial weights must be set before starting prediction" - logError(msg) - throw new IllegalArgumentException(msg) + if (model.isEmpty) { + throw new IllegalArgumentException("Model must be initialized before starting prediction.") } - data.map(model.predict) + data.map(model.get.predict) } /** @@ -114,11 +119,9 @@ abstract class StreamingLinearAlgorithm[ * @return DStream containing the input keys and the predictions as values */ def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Double)] = { - if (Option(model.weights) == None) { - val msg = "Initial weights must be set before starting prediction" - logError(msg) - throw new IllegalArgumentException(msg) + if (model.isEmpty) { + throw new IllegalArgumentException("Model must be initialized before starting prediction") } - data.mapValues(model.predict) + data.mapValues(model.get.predict) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala index 1d11fde24712c..e5e6301127a28 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala @@ -21,6 +21,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg.Vector /** + * :: Experimental :: * Train or predict a linear regression model on streaming data. Training uses * Stochastic Gradient Descent to update the model based on each new batch of * incoming data from a DStream (see `LinearRegressionWithSGD` for model equation) @@ -41,13 +42,12 @@ import org.apache.spark.mllib.linalg.Vector * */ @Experimental -class StreamingLinearRegressionWithSGD ( +class StreamingLinearRegressionWithSGD private[mllib] ( private var stepSize: Double, private var numIterations: Int, - private var miniBatchFraction: Double, - private var initialWeights: Vector) - extends StreamingLinearAlgorithm[ - LinearRegressionModel, LinearRegressionWithSGD] with Serializable { + private var miniBatchFraction: Double) + extends StreamingLinearAlgorithm[LinearRegressionModel, LinearRegressionWithSGD] + with Serializable { /** * Construct a StreamingLinearRegression object with default parameters: @@ -55,12 +55,10 @@ class StreamingLinearRegressionWithSGD ( * Initial weights must be set before using trainOn or predictOn * (see `StreamingLinearAlgorithm`) */ - def this() = this(0.1, 50, 1.0, null) + def this() = this(0.1, 50, 1.0) val algorithm = new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction) - var model = algorithm.createModel(initialWeights, 0.0) - /** Set the step size for gradient descent. Default: 0.1. */ def setStepSize(stepSize: Double): this.type = { this.algorithm.optimizer.setStepSize(stepSize) @@ -81,7 +79,7 @@ class StreamingLinearRegressionWithSGD ( /** Set the initial weights. Default: [0.0, 0.0]. */ def setInitialWeights(initialWeights: Vector): this.type = { - this.model = algorithm.createModel(initialWeights, 0.0) + this.model = Some(algorithm.createModel(initialWeights, 0.0)) this } diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java new file mode 100644 index 0000000000000..dc10aa67c7c1f --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering; + +import java.io.Serializable; +import java.util.ArrayList; + +import org.apache.spark.api.java.JavaRDD; +import scala.Tuple2; + +import org.junit.After; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertArrayEquals; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Matrix; +import org.apache.spark.mllib.linalg.Vector; + + +public class JavaLDASuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaLDA"); + ArrayList> tinyCorpus = new ArrayList>(); + for (int i = 0; i < LDASuite$.MODULE$.tinyCorpus().length; i++) { + tinyCorpus.add(new Tuple2((Long)LDASuite$.MODULE$.tinyCorpus()[i]._1(), + LDASuite$.MODULE$.tinyCorpus()[i]._2())); + } + JavaRDD> tmpCorpus = sc.parallelize(tinyCorpus, 2); + corpus = JavaPairRDD.fromJavaRDD(tmpCorpus); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void localLDAModel() { + LocalLDAModel model = new LocalLDAModel(LDASuite$.MODULE$.tinyTopics()); + + // Check: basic parameters + assertEquals(model.k(), tinyK); + assertEquals(model.vocabSize(), tinyVocabSize); + assertEquals(model.topicsMatrix(), tinyTopics); + + // Check: describeTopics() with all terms + Tuple2[] fullTopicSummary = model.describeTopics(); + assertEquals(fullTopicSummary.length, tinyK); + for (int i = 0; i < fullTopicSummary.length; i++) { + assertArrayEquals(fullTopicSummary[i]._1(), tinyTopicDescription[i]._1()); + assertArrayEquals(fullTopicSummary[i]._2(), tinyTopicDescription[i]._2(), 1e-5); + } + } + + @Test + public void distributedLDAModel() { + int k = 3; + double topicSmoothing = 1.2; + double termSmoothing = 1.2; + + // Train a model + LDA lda = new LDA(); + lda.setK(k) + .setDocConcentration(topicSmoothing) + .setTopicConcentration(termSmoothing) + .setMaxIterations(5) + .setSeed(12345); + + DistributedLDAModel model = lda.run(corpus); + + // Check: basic parameters + LocalLDAModel localModel = model.toLocal(); + assertEquals(model.k(), k); + assertEquals(localModel.k(), k); + assertEquals(model.vocabSize(), tinyVocabSize); + assertEquals(localModel.vocabSize(), tinyVocabSize); + assertEquals(model.topicsMatrix(), localModel.topicsMatrix()); + + // Check: topic summaries + Tuple2[] roundedTopicSummary = model.describeTopics(); + assertEquals(roundedTopicSummary.length, k); + Tuple2[] roundedLocalTopicSummary = localModel.describeTopics(); + assertEquals(roundedLocalTopicSummary.length, k); + + // Check: log probabilities + assert(model.logLikelihood() < 0.0); + assert(model.logPrior() < 0.0); + } + + private static int tinyK = LDASuite$.MODULE$.tinyK(); + private static int tinyVocabSize = LDASuite$.MODULE$.tinyVocabSize(); + private static Matrix tinyTopics = LDASuite$.MODULE$.tinyTopics(); + private static Tuple2[] tinyTopicDescription = + LDASuite$.MODULE$.tinyTopicDescription(); + JavaPairRDD corpus; + +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java new file mode 100644 index 0000000000000..851707c8a19c4 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.fpm; + +import java.io.Serializable; +import java.util.ArrayList; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import com.google.common.collect.Lists; +import static org.junit.Assert.*; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; + +public class JavaFPGrowthSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaFPGrowth"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void runFPGrowth() { + + @SuppressWarnings("unchecked") + JavaRDD> rdd = sc.parallelize(Lists.newArrayList( + Lists.newArrayList("r z h k p".split(" ")), + Lists.newArrayList("z y x w v u t s".split(" ")), + Lists.newArrayList("s x o n r".split(" ")), + Lists.newArrayList("x z y m t s q e".split(" ")), + Lists.newArrayList("z".split(" ")), + Lists.newArrayList("x z y r q t p".split(" "))), 2); + + FPGrowth fpg = new FPGrowth(); + + FPGrowthModel model6 = fpg + .setMinSupport(0.9) + .setNumPartitions(1) + .run(rdd); + assertEquals(0, model6.javaFreqItemsets().count()); + + FPGrowthModel model3 = fpg + .setMinSupport(0.5) + .setNumPartitions(2) + .run(rdd); + assertEquals(18, model3.javaFreqItemsets().count()); + + FPGrowthModel model2 = fpg + .setMinSupport(0.3) + .setNumPartitions(4) + .run(rdd); + assertEquals(54, model2.javaFreqItemsets().count()); + + FPGrowthModel model1 = fpg + .setMinSupport(0.1) + .setNumPartitions(8) + .run(rdd); + assertEquals(625, model1.javaFreqItemsets().count()); + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index ee08c3c32760e..acc447742bad0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -414,7 +414,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { val (training, test) = genExplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01) for ((numUserBlocks, numItemBlocks) <- Seq((1, 1), (1, 2), (2, 1), (2, 2))) { - testALS(training, test, maxIter = 4, rank = 2, regParam = 0.01, targetRMSE = 0.03, + testALS(training, test, maxIter = 4, rank = 3, regParam = 0.01, targetRMSE = 0.03, numUserBlocks = numUserBlocks, numItemBlocks = numItemBlocks) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala new file mode 100644 index 0000000000000..8b3e6e5ce9249 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.classification + +import scala.collection.mutable.ArrayBuffer + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.streaming.TestSuiteBase + +class StreamingLogisticRegressionSuite extends FunSuite with TestSuiteBase { + + // use longer wait time to ensure job completion + override def maxWaitTimeMillis = 30000 + + // Test if we can accurately learn B for Y = logistic(BX) on streaming data + test("parameter accuracy") { + + val nPoints = 100 + val B = 1.5 + + // create model + val model = new StreamingLogisticRegressionWithSGD() + .setInitialWeights(Vectors.dense(0.0)) + .setStepSize(0.2) + .setNumIterations(25) + + // generate sequence of simulated data + val numBatches = 20 + val input = (0 until numBatches).map { i => + LogisticRegressionSuite.generateLogisticInput(0.0, B, nPoints, 42 * (i + 1)) + } + + // apply model training to input stream + val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { + model.trainOn(inputDStream) + inputDStream.count() + }) + runStreams(ssc, numBatches, numBatches) + + // check accuracy of final parameter estimates + assert(model.latestModel().weights(0) ~== B relTol 0.1) + + } + + // Test that parameter estimates improve when learning Y = logistic(BX) on streaming data + test("parameter convergence") { + + val B = 1.5 + val nPoints = 100 + + // create model + val model = new StreamingLogisticRegressionWithSGD() + .setInitialWeights(Vectors.dense(0.0)) + .setStepSize(0.2) + .setNumIterations(25) + + // generate sequence of simulated data + val numBatches = 20 + val input = (0 until numBatches).map { i => + LogisticRegressionSuite.generateLogisticInput(0.0, B, nPoints, 42 * (i + 1)) + } + + // create buffer to store intermediate fits + val history = new ArrayBuffer[Double](numBatches) + + // apply model training to input stream, storing the intermediate results + // (we add a count to ensure the result is a DStream) + val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { + model.trainOn(inputDStream) + inputDStream.foreachRDD(x => history.append(math.abs(model.latestModel().weights(0) - B))) + inputDStream.count() + }) + runStreams(ssc, numBatches, numBatches) + + // compute change in error + val deltas = history.drop(1).zip(history.dropRight(1)) + // check error stability (it always either shrinks, or increases with small tol) + assert(deltas.forall(x => (x._1 - x._2) <= 0.1)) + // check that error shrunk on at least 2 batches + assert(deltas.map(x => if ((x._1 - x._2) < 0) 1 else 0).sum > 1) + } + + // Test predictions on a stream + test("predictions") { + + val B = 1.5 + val nPoints = 100 + + // create model initialized with true weights + val model = new StreamingLogisticRegressionWithSGD() + .setInitialWeights(Vectors.dense(1.5)) + .setStepSize(0.2) + .setNumIterations(25) + + // generate sequence of simulated data for testing + val numBatches = 10 + val testInput = (0 until numBatches).map { i => + LogisticRegressionSuite.generateLogisticInput(0.0, B, nPoints, 42 * (i + 1)) + } + + // apply model predictions to test stream + val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { + model.predictOnValues(inputDStream.map(x => (x.label, x.features))) + }) + + // collect the output as (true, estimated) tuples + val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches) + + // check that at least 60% of predictions are correct on all batches + val errors = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints) + + assert(errors.forall(x => x <= 0.4)) + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala new file mode 100644 index 0000000000000..302d751eb8a94 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.{DenseMatrix, Matrix, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + +class LDASuite extends FunSuite with MLlibTestSparkContext { + + import LDASuite._ + + test("LocalLDAModel") { + val model = new LocalLDAModel(tinyTopics) + + // Check: basic parameters + assert(model.k === tinyK) + assert(model.vocabSize === tinyVocabSize) + assert(model.topicsMatrix === tinyTopics) + + // Check: describeTopics() with all terms + val fullTopicSummary = model.describeTopics() + assert(fullTopicSummary.size === tinyK) + fullTopicSummary.zip(tinyTopicDescription).foreach { + case ((algTerms, algTermWeights), (terms, termWeights)) => + assert(algTerms === terms) + assert(algTermWeights === termWeights) + } + + // Check: describeTopics() with some terms + val smallNumTerms = 3 + val smallTopicSummary = model.describeTopics(maxTermsPerTopic = smallNumTerms) + smallTopicSummary.zip(tinyTopicDescription).foreach { + case ((algTerms, algTermWeights), (terms, termWeights)) => + assert(algTerms === terms.slice(0, smallNumTerms)) + assert(algTermWeights === termWeights.slice(0, smallNumTerms)) + } + } + + test("running and DistributedLDAModel") { + val k = 3 + val topicSmoothing = 1.2 + val termSmoothing = 1.2 + + // Train a model + val lda = new LDA() + lda.setK(k) + .setDocConcentration(topicSmoothing) + .setTopicConcentration(termSmoothing) + .setMaxIterations(5) + .setSeed(12345) + val corpus = sc.parallelize(tinyCorpus, 2) + + val model: DistributedLDAModel = lda.run(corpus) + + // Check: basic parameters + val localModel = model.toLocal + assert(model.k === k) + assert(localModel.k === k) + assert(model.vocabSize === tinyVocabSize) + assert(localModel.vocabSize === tinyVocabSize) + assert(model.topicsMatrix === localModel.topicsMatrix) + + // Check: topic summaries + // The odd decimal formatting and sorting is a hack to do a robust comparison. + val roundedTopicSummary = model.describeTopics().map { case (terms, termWeights) => + // cut values to 3 digits after the decimal place + terms.zip(termWeights).map { case (term, weight) => + ("%.3f".format(weight).toDouble, term.toInt) + } + }.sortBy(_.mkString("")) + val roundedLocalTopicSummary = localModel.describeTopics().map { case (terms, termWeights) => + // cut values to 3 digits after the decimal place + terms.zip(termWeights).map { case (term, weight) => + ("%.3f".format(weight).toDouble, term.toInt) + } + }.sortBy(_.mkString("")) + roundedTopicSummary.zip(roundedLocalTopicSummary).foreach { case (t1, t2) => + assert(t1 === t2) + } + + // Check: per-doc topic distributions + val topicDistributions = model.topicDistributions.collect() + // Ensure all documents are covered. + assert(topicDistributions.size === tinyCorpus.size) + assert(tinyCorpus.map(_._1).toSet === topicDistributions.map(_._1).toSet) + // Ensure we have proper distributions + topicDistributions.foreach { case (docId, topicDistribution) => + assert(topicDistribution.size === tinyK) + assert(topicDistribution.toArray.sum ~== 1.0 absTol 1e-5) + } + + // Check: log probabilities + assert(model.logLikelihood < 0.0) + assert(model.logPrior < 0.0) + } + + test("vertex indexing") { + // Check vertex ID indexing and conversions. + val docIds = Array(0, 1, 2) + val docVertexIds = docIds + val termIds = Array(0, 1, 2) + val termVertexIds = Array(-1, -2, -3) + assert(docVertexIds.forall(i => !LDA.isTermVertex((i.toLong, 0)))) + assert(termIds.map(LDA.term2index) === termVertexIds) + assert(termVertexIds.map(i => LDA.index2term(i.toLong)) === termIds) + assert(termVertexIds.forall(i => LDA.isTermVertex((i.toLong, 0)))) + } +} + +private[clustering] object LDASuite { + + def tinyK: Int = 3 + def tinyVocabSize: Int = 5 + def tinyTopicsAsArray: Array[Array[Double]] = Array( + Array[Double](0.1, 0.2, 0.3, 0.4, 0.0), // topic 0 + Array[Double](0.5, 0.05, 0.05, 0.1, 0.3), // topic 1 + Array[Double](0.2, 0.2, 0.05, 0.05, 0.5) // topic 2 + ) + def tinyTopics: Matrix = new DenseMatrix(numRows = tinyVocabSize, numCols = tinyK, + values = tinyTopicsAsArray.fold(Array.empty[Double])(_ ++ _)) + def tinyTopicDescription: Array[(Array[Int], Array[Double])] = tinyTopicsAsArray.map { topic => + val (termWeights, terms) = topic.zipWithIndex.sortBy(-_._1).unzip + (terms.toArray, termWeights.toArray) + } + + def tinyCorpus = Array( + Vectors.dense(1, 3, 0, 2, 8), + Vectors.dense(0, 2, 1, 0, 4), + Vectors.dense(2, 3, 12, 3, 1), + Vectors.dense(0, 3, 1, 9, 8), + Vectors.dense(1, 1, 4, 2, 6) + ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) } + assert(tinyCorpus.forall(_._2.size == tinyVocabSize)) // sanity check for test data + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala index 71ef60da6dd32..68128284b8608 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala @@ -22,7 +22,8 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext class FPGrowthSuite extends FunSuite with MLlibTestSparkContext { - test("FP-Growth") { + + test("FP-Growth using String type") { val transactions = Seq( "r z h k p", "z y x w v u t s", @@ -70,4 +71,52 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext { .run(rdd) assert(model1.freqItemsets.count() === 625) } + + test("FP-Growth using Int type") { + val transactions = Seq( + "1 2 3", + "1 2 3 4", + "5 4 3 2 1", + "6 5 4 3 2 1", + "2 4", + "1 3", + "1 7") + .map(_.split(" ").map(_.toInt).toArray) + val rdd = sc.parallelize(transactions, 2).cache() + + val fpg = new FPGrowth() + + val model6 = fpg + .setMinSupport(0.9) + .setNumPartitions(1) + .run(rdd) + assert(model6.freqItemsets.count() === 0) + + val model3 = fpg + .setMinSupport(0.5) + .setNumPartitions(2) + .run(rdd) + assert(model3.freqItemsets.first()._1.getClass === Array(1).getClass, + "frequent itemsets should use primitive arrays") + val freqItemsets3 = model3.freqItemsets.collect().map { case (items, count) => + (items.toSet, count) + } + val expected = Set( + (Set(1), 6L), (Set(2), 5L), (Set(3), 5L), (Set(4), 4L), + (Set(1, 2), 4L), (Set(1, 3), 5L), (Set(2, 3), 4L), + (Set(2, 4), 4L), (Set(1, 2, 3), 4L)) + assert(freqItemsets3.toSet === expected) + + val model2 = fpg + .setMinSupport(0.3) + .setNumPartitions(4) + .run(rdd) + assert(model2.freqItemsets.count() === 15) + + val model1 = fpg + .setMinSupport(0.1) + .setNumPartitions(8) + .run(rdd) + assert(model1.freqItemsets.count() === 65) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala new file mode 100644 index 0000000000000..dac28a369b5b2 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.impl + +import org.scalatest.FunSuite + +import org.apache.hadoop.fs.{FileSystem, Path} + +import org.apache.spark.SparkContext +import org.apache.spark.graphx.{Edge, Graph} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.Utils + + +class PeriodicGraphCheckpointerSuite extends FunSuite with MLlibTestSparkContext { + + import PeriodicGraphCheckpointerSuite._ + + // TODO: Do I need to call count() on the graphs' RDDs? + + test("Persisting") { + var graphsToCheck = Seq.empty[GraphToCheck] + + val graph1 = createGraph(sc) + val checkpointer = new PeriodicGraphCheckpointer(graph1, None, 10) + graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) + checkPersistence(graphsToCheck, 1) + + var iteration = 2 + while (iteration < 9) { + val graph = createGraph(sc) + checkpointer.updateGraph(graph) + graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration) + checkPersistence(graphsToCheck, iteration) + iteration += 1 + } + } + + test("Checkpointing") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + val checkpointInterval = 2 + var graphsToCheck = Seq.empty[GraphToCheck] + + val graph1 = createGraph(sc) + val checkpointer = new PeriodicGraphCheckpointer(graph1, Some(path), checkpointInterval) + graph1.edges.count() + graph1.vertices.count() + graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) + checkCheckpoint(graphsToCheck, 1, checkpointInterval) + + var iteration = 2 + while (iteration < 9) { + val graph = createGraph(sc) + checkpointer.updateGraph(graph) + graph.vertices.count() + graph.edges.count() + graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration) + checkCheckpoint(graphsToCheck, iteration, checkpointInterval) + iteration += 1 + } + + checkpointer.deleteAllCheckpoints() + graphsToCheck.foreach { graph => + confirmCheckpointRemoved(graph.graph) + } + + Utils.deleteRecursively(tempDir) + } +} + +private object PeriodicGraphCheckpointerSuite { + + case class GraphToCheck(graph: Graph[Double, Double], gIndex: Int) + + val edges = Seq( + Edge[Double](0, 1, 0), + Edge[Double](1, 2, 0), + Edge[Double](2, 3, 0), + Edge[Double](3, 4, 0)) + + def createGraph(sc: SparkContext): Graph[Double, Double] = { + Graph.fromEdges[Double, Double](sc.parallelize(edges), 0) + } + + def checkPersistence(graphs: Seq[GraphToCheck], iteration: Int): Unit = { + graphs.foreach { g => + checkPersistence(g.graph, g.gIndex, iteration) + } + } + + /** + * Check storage level of graph. + * @param gIndex Index of graph in order inserted into checkpointer (from 1). + * @param iteration Total number of graphs inserted into checkpointer. + */ + def checkPersistence(graph: Graph[_, _], gIndex: Int, iteration: Int): Unit = { + try { + if (gIndex + 2 < iteration) { + assert(graph.vertices.getStorageLevel == StorageLevel.NONE) + assert(graph.edges.getStorageLevel == StorageLevel.NONE) + } else { + assert(graph.vertices.getStorageLevel != StorageLevel.NONE) + assert(graph.edges.getStorageLevel != StorageLevel.NONE) + } + } catch { + case _: AssertionError => + throw new Exception(s"PeriodicGraphCheckpointerSuite.checkPersistence failed with:\n" + + s"\t gIndex = $gIndex\n" + + s"\t iteration = $iteration\n" + + s"\t graph.vertices.getStorageLevel = ${graph.vertices.getStorageLevel}\n" + + s"\t graph.edges.getStorageLevel = ${graph.edges.getStorageLevel}\n") + } + } + + def checkCheckpoint(graphs: Seq[GraphToCheck], iteration: Int, checkpointInterval: Int): Unit = { + graphs.reverse.foreach { g => + checkCheckpoint(g.graph, g.gIndex, iteration, checkpointInterval) + } + } + + def confirmCheckpointRemoved(graph: Graph[_, _]): Unit = { + // Note: We cannot check graph.isCheckpointed since that value is never updated. + // Instead, we check for the presence of the checkpoint files. + // This test should continue to work even after this graph.isCheckpointed issue + // is fixed (though it can then be simplified and not look for the files). + val fs = FileSystem.get(graph.vertices.sparkContext.hadoopConfiguration) + graph.getCheckpointFiles.foreach { checkpointFile => + assert(!fs.exists(new Path(checkpointFile)), + "Graph checkpoint file should have been removed") + } + } + + /** + * Check checkpointed status of graph. + * @param gIndex Index of graph in order inserted into checkpointer (from 1). + * @param iteration Total number of graphs inserted into checkpointer. + */ + def checkCheckpoint( + graph: Graph[_, _], + gIndex: Int, + iteration: Int, + checkpointInterval: Int): Unit = { + try { + if (gIndex % checkpointInterval == 0) { + // We allow 2 checkpoint intervals since we perform an action (checkpointing a second graph) + // only AFTER PeriodicGraphCheckpointer decides whether to remove the previous checkpoint. + if (iteration - 2 * checkpointInterval < gIndex && gIndex <= iteration) { + assert(graph.isCheckpointed, "Graph should be checkpointed") + assert(graph.getCheckpointFiles.length == 2, "Graph should have 2 checkpoint files") + } else { + confirmCheckpointRemoved(graph) + } + } else { + // Graph should never be checkpointed + assert(!graph.isCheckpointed, "Graph should never have been checkpointed") + assert(graph.getCheckpointFiles.length == 0, "Graph should not have any checkpoint files") + } + } catch { + case e: AssertionError => + throw new Exception(s"PeriodicGraphCheckpointerSuite.checkCheckpoint failed with:\n" + + s"\t gIndex = $gIndex\n" + + s"\t iteration = $iteration\n" + + s"\t checkpointInterval = $checkpointInterval\n" + + s"\t graph.isCheckpointed = ${graph.isCheckpointed}\n" + + s"\t graph.getCheckpointFiles = ${graph.getCheckpointFiles.mkString(", ")}\n" + + s" AssertionError message: ${e.getMessage}") + } + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala index e9fc37e000526..8775c0ca9df84 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala @@ -24,9 +24,7 @@ import scala.util.Random import org.scalatest.FunSuite import org.jblas.DoubleMatrix -import org.apache.spark.SparkContext._ import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.mllib.recommendation.ALS.BlockStats import org.apache.spark.storage.StorageLevel object ALSSuite { @@ -189,22 +187,6 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext { testALS(100, 200, 2, 15, 0.7, 0.4, false, false, false, -1, -1, false) } - test("analyze one user block and one product block") { - val localRatings = Seq( - Rating(0, 100, 1.0), - Rating(0, 101, 2.0), - Rating(0, 102, 3.0), - Rating(1, 102, 4.0), - Rating(2, 103, 5.0)) - val ratings = sc.makeRDD(localRatings, 2) - val stats = ALS.analyzeBlocks(ratings, 1, 1) - assert(stats.size === 2) - assert(stats(0) === BlockStats("user", 0, 3, 5, 4, 3)) - assert(stats(1) === BlockStats("product", 0, 4, 5, 3, 4)) - } - - // TODO: add tests for analyzing multiple user/product blocks - /** * Test if we can correctly factorize R = U * P where U and P are of known rank. * diff --git a/pom.xml b/pom.xml index e25eced877578..542efbaf06eb0 100644 --- a/pom.xml +++ b/pom.xml @@ -149,6 +149,7 @@ 2.10 ${scala.version} org.scala-lang + 3.6.3 1.8.8 1.1.1.6 diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 78de1f0652741..b17532c1d814c 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -69,7 +69,12 @@ object MimaExcludes { ) ++ Seq( // SPARK-5540 ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.solveLeastSquares") + "org.apache.spark.mllib.recommendation.ALS.solveLeastSquares"), + // SPARK-5536 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateBlock") ) ++ Seq( // SPARK-3325 ProblemFilters.exclude[MissingMethodProblem]( diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 6b713aa39374e..f6b97abb1723c 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -15,19 +15,22 @@ # limitations under the License. # +from numpy import array + +from pyspark import RDD from pyspark import SparkContext from pyspark.mllib.common import callMLlibFunc, callJavaFunc -from pyspark.mllib.linalg import SparseVector, _convert_to_vector +from pyspark.mllib.linalg import DenseVector, SparseVector, _convert_to_vector +from pyspark.mllib.stat.distribution import MultivariateGaussian -__all__ = ['KMeansModel', 'KMeans'] +__all__ = ['KMeansModel', 'KMeans', 'GaussianMixtureModel', 'GaussianMixture'] class KMeansModel(object): """A clustering model derived from the k-means method. - >>> from numpy import array - >>> data = array([0.0,0.0, 1.0,1.0, 9.0,8.0, 8.0,9.0]).reshape(4,2) + >>> data = array([0.0,0.0, 1.0,1.0, 9.0,8.0, 8.0,9.0]).reshape(4, 2) >>> model = KMeans.train( ... sc.parallelize(data), 2, maxIterations=10, runs=30, initializationMode="random") >>> model.predict(array([0.0, 0.0])) == model.predict(array([1.0, 1.0])) @@ -86,6 +89,87 @@ def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||" return KMeansModel([c.toArray() for c in centers]) +class GaussianMixtureModel(object): + + """A clustering model derived from the Gaussian Mixture Model method. + + >>> clusterdata_1 = sc.parallelize(array([-0.1,-0.05,-0.01,-0.1, + ... 0.9,0.8,0.75,0.935, + ... -0.83,-0.68,-0.91,-0.76 ]).reshape(6, 2)) + >>> model = GaussianMixture.train(clusterdata_1, 3, convergenceTol=0.0001, + ... maxIterations=50, seed=10) + >>> labels = model.predict(clusterdata_1).collect() + >>> labels[0]==labels[1] + False + >>> labels[1]==labels[2] + True + >>> labels[4]==labels[5] + True + >>> clusterdata_2 = sc.parallelize(array([-5.1971, -2.5359, -3.8220, + ... -5.2211, -5.0602, 4.7118, + ... 6.8989, 3.4592, 4.6322, + ... 5.7048, 4.6567, 5.5026, + ... 4.5605, 5.2043, 6.2734]).reshape(5, 3)) + >>> model = GaussianMixture.train(clusterdata_2, 2, convergenceTol=0.0001, + ... maxIterations=150, seed=10) + >>> labels = model.predict(clusterdata_2).collect() + >>> labels[0]==labels[1]==labels[2] + True + >>> labels[3]==labels[4] + True + """ + + def __init__(self, weights, gaussians): + self.weights = weights + self.gaussians = gaussians + self.k = len(self.weights) + + def predict(self, x): + """ + Find the cluster to which the points in 'x' has maximum membership + in this model. + + :param x: RDD of data points. + :return: cluster_labels. RDD of cluster labels. + """ + if isinstance(x, RDD): + cluster_labels = self.predictSoft(x).map(lambda z: z.index(max(z))) + return cluster_labels + + def predictSoft(self, x): + """ + Find the membership of each point in 'x' to all mixture components. + + :param x: RDD of data points. + :return: membership_matrix. RDD of array of double values. + """ + if isinstance(x, RDD): + means, sigmas = zip(*[(g.mu, g.sigma) for g in self.gaussians]) + membership_matrix = callMLlibFunc("predictSoftGMM", x.map(_convert_to_vector), + self.weights, means, sigmas) + return membership_matrix + + +class GaussianMixture(object): + """ + Estimate model parameters with the expectation-maximization algorithm. + + :param data: RDD of data points + :param k: Number of components + :param convergenceTol: Threshold value to check the convergence criteria. Defaults to 1e-3 + :param maxIterations: Number of iterations. Default to 100 + :param seed: Random Seed + """ + @classmethod + def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None): + """Train a Gaussian Mixture clustering model.""" + weight, mu, sigma = callMLlibFunc("trainGaussianMixture", + rdd.map(_convert_to_vector), k, + convergenceTol, maxIterations, seed) + mvg_obj = [MultivariateGaussian(mu[i], sigma[i]) for i in range(k)] + return GaussianMixtureModel(weight, mvg_obj) + + def _test(): import doctest globs = globals().copy() diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 97ec74eda0b71..0d99e6dedfad9 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -49,17 +49,17 @@ class MatrixFactorizationModel(JavaModelWrapper): >>> r3 = (2, 1, 2.0) >>> ratings = sc.parallelize([r1, r2, r3]) >>> model = ALS.trainImplicit(ratings, 1, seed=10) - >>> model.predict(2,2) - 0.4473... + >>> model.predict(2, 2) + 0.43... >>> testset = sc.parallelize([(1, 2), (1, 1)]) - >>> model = ALS.train(ratings, 1, seed=10) + >>> model = ALS.train(ratings, 2, seed=0) >>> model.predictAll(testset).collect() - [Rating(user=1, product=1, rating=1.0471...), Rating(user=1, product=2, rating=1.9679...)] + [Rating(user=1, product=1, rating=1.0...), Rating(user=1, product=2, rating=1.9...)] >>> model = ALS.train(ratings, 4, seed=10) >>> model.userFeatures().collect() - [(2, array('d', [...])), (1, array('d', [...]))] + [(1, array('d', [...])), (2, array('d', [...]))] >>> first_user = model.userFeatures().take(1)[0] >>> latents = first_user[1] @@ -67,7 +67,7 @@ class MatrixFactorizationModel(JavaModelWrapper): True >>> model.productFeatures().collect() - [(2, array('d', [...])), (1, array('d', [...]))] + [(1, array('d', [...])), (2, array('d', [...]))] >>> first_product = model.productFeatures().take(1)[0] >>> latents = first_product[1] @@ -76,11 +76,11 @@ class MatrixFactorizationModel(JavaModelWrapper): >>> model = ALS.train(ratings, 1, nonnegative=True, seed=10) >>> model.predict(2,2) - 3.735... + 3.8... >>> model = ALS.trainImplicit(ratings, 1, nonnegative=True, seed=10) >>> model.predict(2,2) - 0.4473... + 0.43... """ def predict(self, user, product): return self._java_model.predict(int(user), int(product)) diff --git a/python/pyspark/mllib/stat/__init__.py b/python/pyspark/mllib/stat/__init__.py index 799d260c096b1..b686d955a0080 100644 --- a/python/pyspark/mllib/stat/__init__.py +++ b/python/pyspark/mllib/stat/__init__.py @@ -20,5 +20,6 @@ """ from pyspark.mllib.stat._statistics import * +from pyspark.mllib.stat.distribution import MultivariateGaussian -__all__ = ["Statistics", "MultivariateStatisticalSummary"] +__all__ = ["Statistics", "MultivariateStatisticalSummary", "MultivariateGaussian"] diff --git a/python/pyspark/mllib/stat/distribution.py b/python/pyspark/mllib/stat/distribution.py new file mode 100644 index 0000000000000..07792e1532046 --- /dev/null +++ b/python/pyspark/mllib/stat/distribution.py @@ -0,0 +1,31 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import namedtuple + +__all__ = ['MultivariateGaussian'] + + +class MultivariateGaussian(namedtuple('MultivariateGaussian', ['mu', 'sigma'])): + + """ Represents a (mu, sigma) tuple + >>> m = MultivariateGaussian(Vectors.dense([11,12]),DenseMatrix(2, 2, (1.0, 3.0, 5.0, 2.0))) + >>> (m.mu, m.sigma.toArray()) + (DenseVector([11.0, 12.0]), array([[ 1., 5.],[ 3., 2.]])) + >>> (m[0], m[1]) + (DenseVector([11.0, 12.0]), array([[ 1., 5.],[ 3., 2.]])) + """ diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 61e0cf5d90bd0..42aa22873772d 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -167,6 +167,32 @@ def test_kmeans_deterministic(self): # TODO: Allow small numeric difference. self.assertTrue(array_equal(c1, c2)) + def test_gmm(self): + from pyspark.mllib.clustering import GaussianMixture + data = self.sc.parallelize([ + [1, 2], + [8, 9], + [-4, -3], + [-6, -7], + ]) + clusters = GaussianMixture.train(data, 2, convergenceTol=0.001, + maxIterations=100, seed=56) + labels = clusters.predict(data).collect() + self.assertEquals(labels[0], labels[1]) + self.assertEquals(labels[2], labels[3]) + + def test_gmm_deterministic(self): + from pyspark.mllib.clustering import GaussianMixture + x = range(0, 100, 10) + y = range(0, 100, 10) + data = self.sc.parallelize([[a, b] for a, b in zip(x, y)]) + clusters1 = GaussianMixture.train(data, 5, convergenceTol=0.001, + maxIterations=100, seed=63) + clusters2 = GaussianMixture.train(data, 5, convergenceTol=0.001, + maxIterations=100, seed=63) + for c1, c2 in zip(clusters1.weights, clusters2.weights): + self.assertEquals(round(c1, 7), round(c2, 7)) + def test_classification(self): from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes from pyspark.mllib.tree import DecisionTree, RandomForest, GradientBoostedTrees diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 32bff0c7e8c55..268c7ef97cffc 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -62,7 +62,7 @@ "StringType", "BinaryType", "BooleanType", "DateType", "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType", "LongType", "ShortType", "ArrayType", "MapType", "StructField", "StructType", - "SQLContext", "HiveContext", "DataFrame", "GroupedDataFrame", "Column", "Row", + "SQLContext", "HiveContext", "DataFrame", "GroupedDataFrame", "Column", "Row", "Dsl", "SchemaRDD"] @@ -1804,7 +1804,7 @@ class DataFrame(object): people = sqlContext.parquetFile("...") Once created, it can be manipulated using the various domain-specific-language - (DSL) functions defined in: [[DataFrame]], [[Column]]. + (DSL) functions defined in: :class:`DataFrame`, :class:`Column`. To select a column from the data frame, use the apply method:: @@ -1835,8 +1835,10 @@ def __init__(self, jdf, sql_ctx): @property def rdd(self): - """Return the content of the :class:`DataFrame` as an :class:`RDD` - of :class:`Row`s. """ + """ + Return the content of the :class:`DataFrame` as an :class:`RDD` + of :class:`Row` s. + """ if not hasattr(self, '_lazy_rdd'): jrdd = self._jdf.javaToPython() rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer())) @@ -1850,18 +1852,6 @@ def applySchema(it): return self._lazy_rdd - def limit(self, num): - """Limit the result count to the number specified. - - >>> df = sqlCtx.inferSchema(rdd) - >>> df.limit(2).collect() - [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')] - >>> df.limit(0).collect() - [] - """ - jdf = self._jdf.limit(num) - return DataFrame(jdf, self.sql_ctx) - def toJSON(self, use_unicode=False): """Convert a DataFrame into a MappedRDD of JSON documents; one document per row. @@ -1886,7 +1876,6 @@ def saveAsParquetFile(self, path): >>> import tempfile, shutil >>> parquetFile = tempfile.mkdtemp() >>> shutil.rmtree(parquetFile) - >>> df = sqlCtx.inferSchema(rdd) >>> df.saveAsParquetFile(parquetFile) >>> df2 = sqlCtx.parquetFile(parquetFile) >>> sorted(df2.collect()) == sorted(df.collect()) @@ -1900,9 +1889,8 @@ def registerTempTable(self, name): The lifetime of this temporary table is tied to the L{SQLContext} that was used to create this DataFrame. - >>> df = sqlCtx.inferSchema(rdd) - >>> df.registerTempTable("test") - >>> df2 = sqlCtx.sql("select * from test") + >>> df.registerTempTable("people") + >>> df2 = sqlCtx.sql("select * from people") >>> sorted(df.collect()) == sorted(df2.collect()) True """ @@ -1926,11 +1914,22 @@ def saveAsTable(self, tableName): def schema(self): """Returns the schema of this DataFrame (represented by - a L{StructType}).""" + a L{StructType}). + + >>> df.schema() + StructType(List(StructField(age,IntegerType,true),StructField(name,StringType,true))) + """ return _parse_datatype_json_string(self._jdf.schema().json()) def printSchema(self): - """Prints out the schema in the tree format.""" + """Prints out the schema in the tree format. + + >>> df.printSchema() + root + |-- age: integer (nullable = true) + |-- name: string (nullable = true) + + """ print (self._jdf.schema().treeString()) def count(self): @@ -1940,11 +1939,8 @@ def count(self): leverages the query optimizer to compute the count on the DataFrame, which supports features such as filter pushdown. - >>> df = sqlCtx.inferSchema(rdd) >>> df.count() - 3L - >>> df.count() == df.map(lambda x: x).count() - True + 2L """ return self._jdf.count() @@ -1954,13 +1950,11 @@ def collect(self): Each object in the list is a Row, the fields can be accessed as attributes. - >>> df = sqlCtx.inferSchema(rdd) >>> df.collect() - [Row(field1=1, field2=u'row1'), ..., Row(field1=3, field2=u'row3')] + [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] """ with SCCallSiteSync(self._sc) as css: bytesInJava = self._jdf.javaToPython().collect().iterator() - cls = _create_cls(self.schema()) tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir) tempFile.close() self._sc._writeToFile(bytesInJava, tempFile.name) @@ -1968,23 +1962,37 @@ def collect(self): with open(tempFile.name, 'rb') as tempFile: rs = list(BatchedSerializer(PickleSerializer()).load_stream(tempFile)) os.unlink(tempFile.name) + cls = _create_cls(self.schema()) return [cls(r) for r in rs] + def limit(self, num): + """Limit the result count to the number specified. + + >>> df.limit(1).collect() + [Row(age=2, name=u'Alice')] + >>> df.limit(0).collect() + [] + """ + jdf = self._jdf.limit(num) + return DataFrame(jdf, self.sql_ctx) + def take(self, num): """Take the first num rows of the RDD. Each object in the list is a Row, the fields can be accessed as attributes. - >>> df = sqlCtx.inferSchema(rdd) >>> df.take(2) - [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')] + [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] """ return self.limit(num).collect() def map(self, f): """ Return a new RDD by applying a function to each Row, it's a shorthand for df.rdd.map() + + >>> df.map(lambda p: p.name).collect() + [u'Alice', u'Bob'] """ return self.rdd.map(f) @@ -2067,140 +2075,167 @@ def sample(self, withReplacement, fraction, seed=None): @property def dtypes(self): """Return all column names and their data types as a list. + + >>> df.dtypes + [(u'age', 'IntegerType'), (u'name', 'StringType')] """ return [(f.name, str(f.dataType)) for f in self.schema().fields] @property def columns(self): """ Return all column names as a list. + + >>> df.columns + [u'age', u'name'] """ return [f.name for f in self.schema().fields] - def show(self): - raise NotImplemented - def join(self, other, joinExprs=None, joinType=None): """ Join with another DataFrame, using the given join expression. The following performs a full outer join between `df1` and `df2`:: - df1.join(df2, df1.key == df2.key, "outer") - :param other: Right side of the join :param joinExprs: Join expression - :param joinType: One of `inner`, `outer`, `left_outer`, `right_outer`, - `semijoin`. + :param joinType: One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`. + + >>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect() + [Row(name=None, height=80), Row(name=u'Bob', height=85), Row(name=u'Alice', height=None)] """ - if joinType is None: - if joinExprs is None: - jdf = self._jdf.join(other._jdf) - else: - jdf = self._jdf.join(other._jdf, joinExprs) + + if joinExprs is None: + jdf = self._jdf.join(other._jdf) else: - jdf = self._jdf.join(other._jdf, joinExprs, joinType) + assert isinstance(joinExprs, Column), "joinExprs should be Column" + if joinType is None: + jdf = self._jdf.join(other._jdf, joinExprs._jc) + else: + assert isinstance(joinType, basestring), "joinType should be basestring" + jdf = self._jdf.join(other._jdf, joinExprs._jc, joinType) return DataFrame(jdf, self.sql_ctx) def sort(self, *cols): - """ Return a new [[DataFrame]] sorted by the specified column, - in ascending column. + """ Return a new :class:`DataFrame` sorted by the specified column. :param cols: The columns or expressions used for sorting + + >>> df.sort(df.age.desc()).collect() + [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] + >>> df.sortBy(df.age.desc()).collect() + [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] """ if not cols: raise ValueError("should sort by at least one column") - for i, c in enumerate(cols): - if isinstance(c, basestring): - cols[i] = Column(c) - jcols = [c._jc for c in cols] - jdf = self._jdf.join(*jcols) + jcols = ListConverter().convert([_to_java_column(c) for c in cols[1:]], + self._sc._gateway._gateway_client) + jdf = self._jdf.sort(_to_java_column(cols[0]), + self._sc._jvm.Dsl.toColumns(jcols)) return DataFrame(jdf, self.sql_ctx) sortBy = sort def head(self, n=None): - """ Return the first `n` rows or the first row if n is None. """ + """ Return the first `n` rows or the first row if n is None. + + >>> df.head() + Row(age=2, name=u'Alice') + >>> df.head(1) + [Row(age=2, name=u'Alice')] + """ if n is None: rs = self.head(1) return rs[0] if rs else None return self.take(n) def first(self): - """ Return the first row. """ - return self.head() + """ Return the first row. - def tail(self): - raise NotImplemented + >>> df.first() + Row(age=2, name=u'Alice') + """ + return self.head() def __getitem__(self, item): + """ Return the column by given name + + >>> df['age'].collect() + [Row(age=2), Row(age=5)] + """ if isinstance(item, basestring): - return Column(self._jdf.apply(item)) + jc = self._jdf.apply(item) + return Column(jc, self.sql_ctx) # TODO projection raise IndexError def __getattr__(self, name): - """ Return the column by given name """ + """ Return the column by given name + + >>> df.age.collect() + [Row(age=2), Row(age=5)] + """ if name.startswith("__"): raise AttributeError(name) - return Column(self._jdf.apply(name)) - - def alias(self, name): - """ Alias the current DataFrame """ - return DataFrame(getattr(self._jdf, "as")(name), self.sql_ctx) + jc = self._jdf.apply(name) + return Column(jc, self.sql_ctx) def select(self, *cols): - """ Selecting a set of expressions.:: - - df.select() - df.select('colA', 'colB') - df.select(df.colA, df.colB + 1) - + """ Selecting a set of expressions. + + >>> df.select().collect() + [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] + >>> df.select('*').collect() + [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] + >>> df.select('name', 'age').collect() + [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)] + >>> df.select(df.name, (df.age + 10).As('age')).collect() + [Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)] """ if not cols: cols = ["*"] - if isinstance(cols[0], basestring): - cols = [_create_column_from_name(n) for n in cols] - else: - cols = [c._jc for c in cols] - jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client) + jcols = ListConverter().convert([_to_java_column(c) for c in cols], + self._sc._gateway._gateway_client) jdf = self._jdf.select(self.sql_ctx._sc._jvm.Dsl.toColumns(jcols)) return DataFrame(jdf, self.sql_ctx) def filter(self, condition): - """ Filtering rows using the given condition:: - - df.filter(df.age > 15) - df.where(df.age > 15) + """ Filtering rows using the given condition. + >>> df.filter(df.age > 3).collect() + [Row(age=5, name=u'Bob')] + >>> df.where(df.age == 2).collect() + [Row(age=2, name=u'Alice')] """ return DataFrame(self._jdf.filter(condition._jc), self.sql_ctx) where = filter def groupBy(self, *cols): - """ Group the [[DataFrame]] using the specified columns, + """ Group the :class:`DataFrame` using the specified columns, so we can run aggregation on them. See :class:`GroupedDataFrame` - for all the available aggregate functions:: - - df.groupBy(df.department).avg() - df.groupBy("department", "gender").agg({ - "salary": "avg", - "age": "max", - }) + for all the available aggregate functions. + + >>> df.groupBy().avg().collect() + [Row(AVG(age#0)=3.5)] + >>> df.groupBy('name').agg({'age': 'mean'}).collect() + [Row(name=u'Bob', AVG(age#0)=5.0), Row(name=u'Alice', AVG(age#0)=2.0)] + >>> df.groupBy(df.name).avg().collect() + [Row(name=u'Bob', AVG(age#0)=5.0), Row(name=u'Alice', AVG(age#0)=2.0)] """ - if cols and isinstance(cols[0], basestring): - cols = [_create_column_from_name(n) for n in cols] - else: - cols = [c._jc for c in cols] - jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client) + jcols = ListConverter().convert([_to_java_column(c) for c in cols], + self._sc._gateway._gateway_client) jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.Dsl.toColumns(jcols)) return GroupedDataFrame(jdf, self.sql_ctx) def agg(self, *exprs): - """ Aggregate on the entire [[DataFrame]] without groups - (shorthand for df.groupBy.agg()):: - - df.agg({"age": "max", "salary": "avg"}) + """ Aggregate on the entire :class:`DataFrame` without groups + (shorthand for df.groupBy.agg()). + + >>> df.agg({"age": "max"}).collect() + [Row(MAX(age#0)=5)] + >>> from pyspark.sql import Dsl + >>> df.agg(Dsl.min(df.age)).collect() + [Row(MIN(age#0)=2)] """ return self.groupBy().agg(*exprs) @@ -2213,7 +2248,7 @@ def unionAll(self, other): return DataFrame(self._jdf.unionAll(other._jdf), self.sql_ctx) def intersect(self, other): - """ Return a new [[DataFrame]] containing rows only in + """ Return a new :class:`DataFrame` containing rows only in both this frame and another frame. This is equivalent to `INTERSECT` in SQL. @@ -2221,7 +2256,7 @@ def intersect(self, other): return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx) def subtract(self, other): - """ Return a new [[DataFrame]] containing rows in this frame + """ Return a new :class:`DataFrame` containing rows in this frame but not in another frame. This is equivalent to `EXCEPT` in SQL. @@ -2229,7 +2264,11 @@ def subtract(self, other): return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx) def sample(self, withReplacement, fraction, seed=None): - """ Return a new DataFrame by sampling a fraction of rows. """ + """ Return a new DataFrame by sampling a fraction of rows. + + >>> df.sample(False, 0.5, 10).collect() + [Row(age=2, name=u'Alice')] + """ if seed is None: jdf = self._jdf.sample(withReplacement, fraction) else: @@ -2237,11 +2276,12 @@ def sample(self, withReplacement, fraction, seed=None): return DataFrame(jdf, self.sql_ctx) def addColumn(self, colName, col): - """ Return a new [[DataFrame]] by adding a column. """ - return self.select('*', col.alias(colName)) + """ Return a new :class:`DataFrame` by adding a column. - def removeColumn(self, colName): - raise NotImplemented + >>> df.addColumn('age2', df.age + 2).collect() + [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)] + """ + return self.select('*', col.As(colName)) # Having SchemaRDD for backward compatibility (for docs) @@ -2280,7 +2320,14 @@ def agg(self, *exprs): `sum`, `count`. :param exprs: list or aggregate columns or a map from column - name to agregate methods. + name to aggregate methods. + + >>> gdf = df.groupBy(df.name) + >>> gdf.agg({"age": "max"}).collect() + [Row(name=u'Bob', MAX(age#0)=5), Row(name=u'Alice', MAX(age#0)=2)] + >>> from pyspark.sql import Dsl + >>> gdf.agg(Dsl.min(df.age)).collect() + [Row(MIN(age#0)=5), Row(MIN(age#0)=2)] """ assert exprs, "exprs should not be empty" if len(exprs) == 1 and isinstance(exprs[0], dict): @@ -2297,7 +2344,11 @@ def agg(self, *exprs): @dfapi def count(self): - """ Count the number of rows for each group. """ + """ Count the number of rows for each group. + + >>> df.groupBy(df.age).count().collect() + [Row(age=2, count=1), Row(age=5, count=1)] + """ @dfapi def mean(self): @@ -2349,18 +2400,25 @@ def sum(self): def _create_column_from_literal(literal): sc = SparkContext._active_spark_context - return sc._jvm.org.apache.spark.sql.Dsl.lit(literal) + return sc._jvm.Dsl.lit(literal) def _create_column_from_name(name): sc = SparkContext._active_spark_context - return sc._jvm.IncomputableColumn(name) + return sc._jvm.Dsl.col(name) + + +def _to_java_column(col): + if isinstance(col, Column): + jcol = col._jc + else: + jcol = _create_column_from_name(col) + return jcol def _scalaMethod(name): """ Translate operators into methodName in Scala - For example: >>> _scalaMethod('+') '$plus' >>> _scalaMethod('>=') @@ -2371,37 +2429,34 @@ def _scalaMethod(name): return ''.join(SCALA_METHOD_MAPPINGS.get(c, c) for c in name) -def _unary_op(name): +def _unary_op(name, doc="unary operator"): """ Create a method for given unary operator """ def _(self): - return Column(getattr(self._jc, _scalaMethod(name))(), self._jdf, self.sql_ctx) + jc = getattr(self._jc, _scalaMethod(name))() + return Column(jc, self.sql_ctx) + _.__doc__ = doc return _ -def _bin_op(name, pass_literal_through=True): +def _bin_op(name, doc="binary operator"): """ Create a method for given binary operator - - Keyword arguments: - pass_literal_through -- whether to pass literal value directly through to the JVM. """ def _(self, other): - if isinstance(other, Column): - jc = other._jc - else: - if pass_literal_through: - jc = other - else: - jc = _create_column_from_literal(other) - return Column(getattr(self._jc, _scalaMethod(name))(jc), self._jdf, self.sql_ctx) + jc = other._jc if isinstance(other, Column) else other + njc = getattr(self._jc, _scalaMethod(name))(jc) + return Column(njc, self.sql_ctx) + _.__doc__ = doc return _ -def _reverse_op(name): +def _reverse_op(name, doc="binary operator"): """ Create a method for binary operator (this object is on right side) """ def _(self, other): - return Column(getattr(_create_column_from_literal(other), _scalaMethod(name))(self._jc), - self._jdf, self.sql_ctx) + jother = _create_column_from_literal(other) + jc = getattr(jother, _scalaMethod(name))(self._jc) + return Column(jc, self.sql_ctx) + _.__doc__ = doc return _ @@ -2410,20 +2465,20 @@ class Column(DataFrame): """ A column in a DataFrame. - `Column` instances can be created by: - {{{ - // 1. Select a column out of a DataFrame - df.colName - df["colName"] + `Column` instances can be created by:: + + # 1. Select a column out of a DataFrame + df.colName + df["colName"] - // 2. Create from an expression - df["colName"] + 1 - }}} + # 2. Create from an expression + df.colName + 1 + 1 / df.colName """ - def __init__(self, jc, jdf=None, sql_ctx=None): + def __init__(self, jc, sql_ctx=None): self._jc = jc - super(Column, self).__init__(jdf, sql_ctx) + super(Column, self).__init__(jc, sql_ctx) # arithmetic operators __neg__ = _unary_op("unary_-") @@ -2438,8 +2493,6 @@ def __init__(self, jc, jdf=None, sql_ctx=None): __rdiv__ = _reverse_op("/") __rmod__ = _reverse_op("%") __abs__ = _unary_op("abs") - abs = _unary_op("abs") - sqrt = _unary_op("sqrt") # logistic operators __eq__ = _bin_op("===") @@ -2448,47 +2501,45 @@ def __init__(self, jc, jdf=None, sql_ctx=None): __le__ = _bin_op("<=") __ge__ = _bin_op(">=") __gt__ = _bin_op(">") - # `and`, `or`, `not` cannot be overloaded in Python - And = _bin_op('&&') - Or = _bin_op('||') - Not = _unary_op('unary_!') - - # bitwise operators - __and__ = _bin_op("&") - __or__ = _bin_op("|") - __invert__ = _unary_op("unary_~") - __xor__ = _bin_op("^") - # __lshift__ = _bin_op("<<") - # __rshift__ = _bin_op(">>") - __rand__ = _bin_op("&") - __ror__ = _bin_op("|") - __rxor__ = _bin_op("^") - # __rlshift__ = _reverse_op("<<") - # __rrshift__ = _reverse_op(">>") + + # `and`, `or`, `not` cannot be overloaded in Python, + # so use bitwise operators as boolean operators + __and__ = _bin_op('&&') + __or__ = _bin_op('||') + __invert__ = _unary_op('unary_!') + __rand__ = _bin_op("&&") + __ror__ = _bin_op("||") # container operators __contains__ = _bin_op("contains") __getitem__ = _bin_op("getItem") - # __getattr__ = _bin_op("getField") + getField = _bin_op("getField", "An expression that gets a field by name in a StructField.") # string methods rlike = _bin_op("rlike") like = _bin_op("like") startswith = _bin_op("startsWith") endswith = _bin_op("endsWith") - upper = _unary_op("upper") - lower = _unary_op("lower") - def substr(self, startPos, pos): - if type(startPos) != type(pos): + def substr(self, startPos, length): + """ + Return a Column which is a substring of the column + + :param startPos: start position (int or Column) + :param length: length of the substring (int or Column) + + >>> df.name.substr(1, 3).collect() + [Row(col=u'Ali'), Row(col=u'Bob')] + """ + if type(startPos) != type(length): raise TypeError("Can not mix the type") if isinstance(startPos, (int, long)): - jc = self._jc.substr(startPos, pos) + jc = self._jc.substr(startPos, length) elif isinstance(startPos, Column): - jc = self._jc.substr(startPos._jc, pos._jc) + jc = self._jc.substr(startPos._jc, length._jc) else: raise TypeError("Unexpected type: %s" % type(startPos)) - return Column(jc, self._jdf, self.sql_ctx) + return Column(jc, self.sql_ctx) __getslice__ = substr @@ -2496,55 +2547,89 @@ def substr(self, startPos, pos): asc = _unary_op("asc") desc = _unary_op("desc") - isNull = _unary_op("isNull") - isNotNull = _unary_op("isNotNull") + isNull = _unary_op("isNull", "True if the current expression is null.") + isNotNull = _unary_op("isNotNull", "True if the current expression is not null.") # `as` is keyword def alias(self, alias): - return Column(getattr(self._jsc, "as")(alias), self._jdf, self.sql_ctx) + """Return a alias for this column + + >>> df.age.As("age2").collect() + [Row(age2=2), Row(age2=5)] + >>> df.age.alias("age2").collect() + [Row(age2=2), Row(age2=5)] + """ + return Column(getattr(self._jc, "as")(alias), self.sql_ctx) + As = alias def cast(self, dataType): + """ Convert the column into type `dataType` + + >>> df.select(df.age.cast("string").As('ages')).collect() + [Row(ages=u'2'), Row(ages=u'5')] + >>> df.select(df.age.cast(StringType()).As('ages')).collect() + [Row(ages=u'2'), Row(ages=u'5')] + """ if self.sql_ctx is None: sc = SparkContext._active_spark_context ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) else: ssql_ctx = self.sql_ctx._ssql_ctx - jdt = ssql_ctx.parseDataType(dataType.json()) - return Column(self._jc.cast(jdt), self._jdf, self.sql_ctx) + if isinstance(dataType, basestring): + jc = self._jc.cast(dataType) + elif isinstance(dataType, DataType): + jdt = ssql_ctx.parseDataType(dataType.json()) + jc = self._jc.cast(jdt) + return Column(jc, self.sql_ctx) -def _to_java_column(col): - if isinstance(col, Column): - jcol = col._jc - else: - jcol = _create_column_from_name(col) - return jcol - - -def _aggregate_func(name): +def _aggregate_func(name, doc=""): """ Create a function for aggregator by name""" def _(col): sc = SparkContext._active_spark_context jc = getattr(sc._jvm.Dsl, name)(_to_java_column(col)) return Column(jc) - + _.__name__ = name + _.__doc__ = doc return staticmethod(_) -class Aggregator(object): +class Dsl(object): """ A collections of builtin aggregators """ - AGGS = [ - 'lit', 'col', 'column', 'upper', 'lower', 'sqrt', 'abs', - 'min', 'max', 'first', 'last', 'count', 'avg', 'mean', 'sum', 'sumDistinct', - ] - for _name in AGGS: - locals()[_name] = _aggregate_func(_name) - del _name + DSLS = { + 'lit': 'Creates a :class:`Column` of literal value.', + 'col': 'Returns a :class:`Column` based on the given column name.', + 'column': 'Returns a :class:`Column` based on the given column name.', + 'upper': 'Converts a string expression to upper case.', + 'lower': 'Converts a string expression to upper case.', + 'sqrt': 'Computes the square root of the specified float value.', + 'abs': 'Computes the absolutle value.', + + 'max': 'Aggregate function: returns the maximum value of the expression in a group.', + 'min': 'Aggregate function: returns the minimum value of the expression in a group.', + 'first': 'Aggregate function: returns the first value in a group.', + 'last': 'Aggregate function: returns the last value in a group.', + 'count': 'Aggregate function: returns the number of items in a group.', + 'sum': 'Aggregate function: returns the sum of all values in the expression.', + 'avg': 'Aggregate function: returns the average of the values in a group.', + 'mean': 'Aggregate function: returns the average of the values in a group.', + 'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.', + } + + for _name, _doc in DSLS.items(): + locals()[_name] = _aggregate_func(_name, _doc) + del _name, _doc @staticmethod def countDistinct(col, *cols): + """ Return a new Column for distinct count of (col, *cols) + + >>> from pyspark.sql import Dsl + >>> df.agg(Dsl.countDistinct(df.age, df.name).As('c')).collect() + [Row(c=2)] + """ sc = SparkContext._active_spark_context jcols = ListConverter().convert([_to_java_column(c) for c in cols], sc._gateway._gateway_client) @@ -2554,6 +2639,12 @@ def countDistinct(col, *cols): @staticmethod def approxCountDistinct(col, rsd=None): + """ Return a new Column for approxiate distinct count of (col, *cols) + + >>> from pyspark.sql import Dsl + >>> df.agg(Dsl.approxCountDistinct(df.age).As('c')).collect() + [Row(c=2)] + """ sc = SparkContext._active_spark_context if rsd is None: jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col)) @@ -2568,16 +2659,20 @@ def _test(): # let doctest run in pyspark.sql, so DataTypes can be picklable import pyspark.sql from pyspark.sql import Row, SQLContext - from pyspark.tests import ExamplePoint, ExamplePointUDT + from pyspark.sql_tests import ExamplePoint, ExamplePointUDT globs = pyspark.sql.__dict__.copy() sc = SparkContext('local[4]', 'PythonTest') globs['sc'] = sc - globs['sqlCtx'] = SQLContext(sc) + globs['sqlCtx'] = sqlCtx = SQLContext(sc) globs['rdd'] = sc.parallelize( [Row(field1=1, field2="row1"), Row(field1=2, field2="row2"), Row(field1=3, field2="row3")] ) + rdd2 = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]) + rdd3 = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]) + globs['df'] = sqlCtx.inferSchema(rdd2) + globs['df2'] = sqlCtx.inferSchema(rdd3) globs['ExamplePoint'] = ExamplePoint globs['ExamplePointUDT'] = ExamplePointUDT jsonStrings = [ diff --git a/python/pyspark/sql_tests.py b/python/pyspark/sql_tests.py new file mode 100644 index 0000000000000..d314f46e8d2d5 --- /dev/null +++ b/python/pyspark/sql_tests.py @@ -0,0 +1,299 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Unit tests for pyspark.sql; additional tests are implemented as doctests in +individual modules. +""" +import os +import sys +import pydoc +import shutil +import tempfile + +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \ + UserDefinedType, DoubleType +from pyspark.tests import ReusedPySparkTestCase + + +class ExamplePointUDT(UserDefinedType): + """ + User-defined type (UDT) for ExamplePoint. + """ + + @classmethod + def sqlType(self): + return ArrayType(DoubleType(), False) + + @classmethod + def module(cls): + return 'pyspark.tests' + + @classmethod + def scalaUDT(cls): + return 'org.apache.spark.sql.test.ExamplePointUDT' + + def serialize(self, obj): + return [obj.x, obj.y] + + def deserialize(self, datum): + return ExamplePoint(datum[0], datum[1]) + + +class ExamplePoint: + """ + An example class to demonstrate UDT in Scala, Java, and Python. + """ + + __UDT__ = ExamplePointUDT() + + def __init__(self, x, y): + self.x = x + self.y = y + + def __repr__(self): + return "ExamplePoint(%s,%s)" % (self.x, self.y) + + def __str__(self): + return "(%s,%s)" % (self.x, self.y) + + def __eq__(self, other): + return isinstance(other, ExamplePoint) and \ + other.x == self.x and other.y == self.y + + +class SQLTests(ReusedPySparkTestCase): + + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.tempdir = tempfile.NamedTemporaryFile(delete=False) + os.unlink(cls.tempdir.name) + cls.sqlCtx = SQLContext(cls.sc) + cls.testData = [Row(key=i, value=str(i)) for i in range(100)] + rdd = cls.sc.parallelize(cls.testData) + cls.df = cls.sqlCtx.inferSchema(rdd) + + @classmethod + def tearDownClass(cls): + ReusedPySparkTestCase.tearDownClass() + shutil.rmtree(cls.tempdir.name, ignore_errors=True) + + def test_udf(self): + self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) + [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect() + self.assertEqual(row[0], 5) + + def test_udf2(self): + self.sqlCtx.registerFunction("strlen", lambda string: len(string), IntegerType()) + self.sqlCtx.inferSchema(self.sc.parallelize([Row(a="test")])).registerTempTable("test") + [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() + self.assertEqual(4, res[0]) + + def test_udf_with_array_type(self): + d = [Row(l=range(3), d={"key": range(5)})] + rdd = self.sc.parallelize(d) + self.sqlCtx.inferSchema(rdd).registerTempTable("test") + self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType())) + self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType()) + [(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect() + self.assertEqual(range(3), l1) + self.assertEqual(1, l2) + + def test_broadcast_in_udf(self): + bar = {"a": "aa", "b": "bb", "c": "abc"} + foo = self.sc.broadcast(bar) + self.sqlCtx.registerFunction("MYUDF", lambda x: foo.value[x] if x else '') + [res] = self.sqlCtx.sql("SELECT MYUDF('c')").collect() + self.assertEqual("abc", res[0]) + [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect() + self.assertEqual("", res[0]) + + def test_basic_functions(self): + rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) + df = self.sqlCtx.jsonRDD(rdd) + df.count() + df.collect() + df.schema() + + # cache and checkpoint + self.assertFalse(df.is_cached) + df.persist() + df.unpersist() + df.cache() + self.assertTrue(df.is_cached) + self.assertEqual(2, df.count()) + + df.registerTempTable("temp") + df = self.sqlCtx.sql("select foo from temp") + df.count() + df.collect() + + def test_apply_schema_to_row(self): + df = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""])) + df2 = self.sqlCtx.applySchema(df.map(lambda x: x), df.schema()) + self.assertEqual(df.collect(), df2.collect()) + + rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x)) + df3 = self.sqlCtx.applySchema(rdd, df.schema()) + self.assertEqual(10, df3.count()) + + def test_serialize_nested_array_and_map(self): + d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})] + rdd = self.sc.parallelize(d) + df = self.sqlCtx.inferSchema(rdd) + row = df.head() + self.assertEqual(1, len(row.l)) + self.assertEqual(1, row.l[0].a) + self.assertEqual("2", row.d["key"].d) + + l = df.map(lambda x: x.l).first() + self.assertEqual(1, len(l)) + self.assertEqual('s', l[0].b) + + d = df.map(lambda x: x.d).first() + self.assertEqual(1, len(d)) + self.assertEqual(1.0, d["key"].c) + + row = df.map(lambda x: x.d["key"]).first() + self.assertEqual(1.0, row.c) + self.assertEqual("2", row.d) + + def test_infer_schema(self): + d = [Row(l=[], d={}), + Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")] + rdd = self.sc.parallelize(d) + df = self.sqlCtx.inferSchema(rdd) + self.assertEqual([], df.map(lambda r: r.l).first()) + self.assertEqual([None, ""], df.map(lambda r: r.s).collect()) + df.registerTempTable("test") + result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'") + self.assertEqual(1, result.head()[0]) + + df2 = self.sqlCtx.inferSchema(rdd, 1.0) + self.assertEqual(df.schema(), df2.schema()) + self.assertEqual({}, df2.map(lambda r: r.d).first()) + self.assertEqual([None, ""], df2.map(lambda r: r.s).collect()) + df2.registerTempTable("test2") + result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'") + self.assertEqual(1, result.head()[0]) + + def test_struct_in_map(self): + d = [Row(m={Row(i=1): Row(s="")})] + rdd = self.sc.parallelize(d) + df = self.sqlCtx.inferSchema(rdd) + k, v = df.head().m.items()[0] + self.assertEqual(1, k.i) + self.assertEqual("", v.s) + + def test_convert_row_to_dict(self): + row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}) + self.assertEqual(1, row.asDict()['l'][0].a) + rdd = self.sc.parallelize([row]) + df = self.sqlCtx.inferSchema(rdd) + df.registerTempTable("test") + row = self.sqlCtx.sql("select l, d from test").head() + self.assertEqual(1, row.asDict()["l"][0].a) + self.assertEqual(1.0, row.asDict()['d']['key'].c) + + def test_infer_schema_with_udt(self): + from pyspark.sql_tests import ExamplePoint, ExamplePointUDT + row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) + rdd = self.sc.parallelize([row]) + df = self.sqlCtx.inferSchema(rdd) + schema = df.schema() + field = [f for f in schema.fields if f.name == "point"][0] + self.assertEqual(type(field.dataType), ExamplePointUDT) + df.registerTempTable("labeled_point") + point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point + self.assertEqual(point, ExamplePoint(1.0, 2.0)) + + def test_apply_schema_with_udt(self): + from pyspark.sql_tests import ExamplePoint, ExamplePointUDT + row = (1.0, ExamplePoint(1.0, 2.0)) + rdd = self.sc.parallelize([row]) + schema = StructType([StructField("label", DoubleType(), False), + StructField("point", ExamplePointUDT(), False)]) + df = self.sqlCtx.applySchema(rdd, schema) + point = df.head().point + self.assertEquals(point, ExamplePoint(1.0, 2.0)) + + def test_parquet_with_udt(self): + from pyspark.sql_tests import ExamplePoint + row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) + rdd = self.sc.parallelize([row]) + df0 = self.sqlCtx.inferSchema(rdd) + output_dir = os.path.join(self.tempdir.name, "labeled_point") + df0.saveAsParquetFile(output_dir) + df1 = self.sqlCtx.parquetFile(output_dir) + point = df1.head().point + self.assertEquals(point, ExamplePoint(1.0, 2.0)) + + def test_column_operators(self): + from pyspark.sql import Column, LongType + ci = self.df.key + cs = self.df.value + c = ci == cs + self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column)) + rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci) + self.assertTrue(all(isinstance(c, Column) for c in rcc)) + cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7, ci and cs, ci or cs] + self.assertTrue(all(isinstance(c, Column) for c in cb)) + cbool = (ci & ci), (ci | ci), (~ci) + self.assertTrue(all(isinstance(c, Column) for c in cbool)) + css = cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(), cs.startswith('a'), cs.endswith('a') + self.assertTrue(all(isinstance(c, Column) for c in css)) + self.assertTrue(isinstance(ci.cast(LongType()), Column)) + + def test_column_select(self): + df = self.df + self.assertEqual(self.testData, df.select("*").collect()) + self.assertEqual(self.testData, df.select(df.key, df.value).collect()) + self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect()) + + def test_aggregator(self): + df = self.df + g = df.groupBy() + self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0])) + self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect()) + + from pyspark.sql import Dsl + self.assertEqual((0, u'99'), tuple(g.agg(Dsl.first(df.key), Dsl.last(df.value)).first())) + self.assertTrue(95 < g.agg(Dsl.approxCountDistinct(df.key)).first()[0]) + self.assertEqual(100, g.agg(Dsl.countDistinct(df.value)).first()[0]) + + def test_help_command(self): + # Regression test for SPARK-5464 + rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) + df = self.sqlCtx.jsonRDD(rdd) + # render_doc() reproduces the help() exception without printing output + pydoc.render_doc(df) + pydoc.render_doc(df.foo) + pydoc.render_doc(df.take(1)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index c7d0622d65f25..b5e28c498040b 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -23,7 +23,6 @@ from fileinput import input from glob import glob import os -import pydoc import re import shutil import subprocess @@ -52,8 +51,6 @@ from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \ CloudPickleSerializer, CompressedSerializer, UTF8Deserializer, NoOpSerializer from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter -from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \ - UserDefinedType, DoubleType from pyspark import shuffle from pyspark.profiler import BasicProfiler @@ -795,264 +792,6 @@ def heavy_foo(x): rdd.foreach(heavy_foo) -class ExamplePointUDT(UserDefinedType): - """ - User-defined type (UDT) for ExamplePoint. - """ - - @classmethod - def sqlType(self): - return ArrayType(DoubleType(), False) - - @classmethod - def module(cls): - return 'pyspark.tests' - - @classmethod - def scalaUDT(cls): - return 'org.apache.spark.sql.test.ExamplePointUDT' - - def serialize(self, obj): - return [obj.x, obj.y] - - def deserialize(self, datum): - return ExamplePoint(datum[0], datum[1]) - - -class ExamplePoint: - """ - An example class to demonstrate UDT in Scala, Java, and Python. - """ - - __UDT__ = ExamplePointUDT() - - def __init__(self, x, y): - self.x = x - self.y = y - - def __repr__(self): - return "ExamplePoint(%s,%s)" % (self.x, self.y) - - def __str__(self): - return "(%s,%s)" % (self.x, self.y) - - def __eq__(self, other): - return isinstance(other, ExamplePoint) and \ - other.x == self.x and other.y == self.y - - -class SQLTests(ReusedPySparkTestCase): - - @classmethod - def setUpClass(cls): - ReusedPySparkTestCase.setUpClass() - cls.tempdir = tempfile.NamedTemporaryFile(delete=False) - os.unlink(cls.tempdir.name) - - @classmethod - def tearDownClass(cls): - ReusedPySparkTestCase.tearDownClass() - shutil.rmtree(cls.tempdir.name, ignore_errors=True) - - def setUp(self): - self.sqlCtx = SQLContext(self.sc) - self.testData = [Row(key=i, value=str(i)) for i in range(100)] - rdd = self.sc.parallelize(self.testData) - self.df = self.sqlCtx.inferSchema(rdd) - - def test_udf(self): - self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) - [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect() - self.assertEqual(row[0], 5) - - def test_udf2(self): - self.sqlCtx.registerFunction("strlen", lambda string: len(string), IntegerType()) - self.sqlCtx.inferSchema(self.sc.parallelize([Row(a="test")])).registerTempTable("test") - [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() - self.assertEqual(4, res[0]) - - def test_udf_with_array_type(self): - d = [Row(l=range(3), d={"key": range(5)})] - rdd = self.sc.parallelize(d) - self.sqlCtx.inferSchema(rdd).registerTempTable("test") - self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType())) - self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType()) - [(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect() - self.assertEqual(range(3), l1) - self.assertEqual(1, l2) - - def test_broadcast_in_udf(self): - bar = {"a": "aa", "b": "bb", "c": "abc"} - foo = self.sc.broadcast(bar) - self.sqlCtx.registerFunction("MYUDF", lambda x: foo.value[x] if x else '') - [res] = self.sqlCtx.sql("SELECT MYUDF('c')").collect() - self.assertEqual("abc", res[0]) - [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect() - self.assertEqual("", res[0]) - - def test_basic_functions(self): - rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) - df = self.sqlCtx.jsonRDD(rdd) - df.count() - df.collect() - df.schema() - - # cache and checkpoint - self.assertFalse(df.is_cached) - df.persist() - df.unpersist() - df.cache() - self.assertTrue(df.is_cached) - self.assertEqual(2, df.count()) - - df.registerTempTable("temp") - df = self.sqlCtx.sql("select foo from temp") - df.count() - df.collect() - - def test_apply_schema_to_row(self): - df = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""])) - df2 = self.sqlCtx.applySchema(df.map(lambda x: x), df.schema()) - self.assertEqual(df.collect(), df2.collect()) - - rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x)) - df3 = self.sqlCtx.applySchema(rdd, df.schema()) - self.assertEqual(10, df3.count()) - - def test_serialize_nested_array_and_map(self): - d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})] - rdd = self.sc.parallelize(d) - df = self.sqlCtx.inferSchema(rdd) - row = df.head() - self.assertEqual(1, len(row.l)) - self.assertEqual(1, row.l[0].a) - self.assertEqual("2", row.d["key"].d) - - l = df.map(lambda x: x.l).first() - self.assertEqual(1, len(l)) - self.assertEqual('s', l[0].b) - - d = df.map(lambda x: x.d).first() - self.assertEqual(1, len(d)) - self.assertEqual(1.0, d["key"].c) - - row = df.map(lambda x: x.d["key"]).first() - self.assertEqual(1.0, row.c) - self.assertEqual("2", row.d) - - def test_infer_schema(self): - d = [Row(l=[], d={}), - Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")] - rdd = self.sc.parallelize(d) - df = self.sqlCtx.inferSchema(rdd) - self.assertEqual([], df.map(lambda r: r.l).first()) - self.assertEqual([None, ""], df.map(lambda r: r.s).collect()) - df.registerTempTable("test") - result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'") - self.assertEqual(1, result.head()[0]) - - df2 = self.sqlCtx.inferSchema(rdd, 1.0) - self.assertEqual(df.schema(), df2.schema()) - self.assertEqual({}, df2.map(lambda r: r.d).first()) - self.assertEqual([None, ""], df2.map(lambda r: r.s).collect()) - df2.registerTempTable("test2") - result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'") - self.assertEqual(1, result.head()[0]) - - def test_struct_in_map(self): - d = [Row(m={Row(i=1): Row(s="")})] - rdd = self.sc.parallelize(d) - df = self.sqlCtx.inferSchema(rdd) - k, v = df.head().m.items()[0] - self.assertEqual(1, k.i) - self.assertEqual("", v.s) - - def test_convert_row_to_dict(self): - row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}) - self.assertEqual(1, row.asDict()['l'][0].a) - rdd = self.sc.parallelize([row]) - df = self.sqlCtx.inferSchema(rdd) - df.registerTempTable("test") - row = self.sqlCtx.sql("select l, d from test").head() - self.assertEqual(1, row.asDict()["l"][0].a) - self.assertEqual(1.0, row.asDict()['d']['key'].c) - - def test_infer_schema_with_udt(self): - from pyspark.tests import ExamplePoint, ExamplePointUDT - row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) - rdd = self.sc.parallelize([row]) - df = self.sqlCtx.inferSchema(rdd) - schema = df.schema() - field = [f for f in schema.fields if f.name == "point"][0] - self.assertEqual(type(field.dataType), ExamplePointUDT) - df.registerTempTable("labeled_point") - point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point - self.assertEqual(point, ExamplePoint(1.0, 2.0)) - - def test_apply_schema_with_udt(self): - from pyspark.tests import ExamplePoint, ExamplePointUDT - row = (1.0, ExamplePoint(1.0, 2.0)) - rdd = self.sc.parallelize([row]) - schema = StructType([StructField("label", DoubleType(), False), - StructField("point", ExamplePointUDT(), False)]) - df = self.sqlCtx.applySchema(rdd, schema) - point = df.head().point - self.assertEquals(point, ExamplePoint(1.0, 2.0)) - - def test_parquet_with_udt(self): - from pyspark.tests import ExamplePoint - row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) - rdd = self.sc.parallelize([row]) - df0 = self.sqlCtx.inferSchema(rdd) - output_dir = os.path.join(self.tempdir.name, "labeled_point") - df0.saveAsParquetFile(output_dir) - df1 = self.sqlCtx.parquetFile(output_dir) - point = df1.head().point - self.assertEquals(point, ExamplePoint(1.0, 2.0)) - - def test_column_operators(self): - from pyspark.sql import Column, LongType - ci = self.df.key - cs = self.df.value - c = ci == cs - self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column)) - rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci) - self.assertTrue(all(isinstance(c, Column) for c in rcc)) - cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7, ci and cs, ci or cs] - self.assertTrue(all(isinstance(c, Column) for c in cb)) - cbit = (ci & ci), (ci | ci), (ci ^ ci), (~ci) - self.assertTrue(all(isinstance(c, Column) for c in cbit)) - css = cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(), cs.startswith('a'), cs.endswith('a') - self.assertTrue(all(isinstance(c, Column) for c in css)) - self.assertTrue(isinstance(ci.cast(LongType()), Column)) - - def test_column_select(self): - df = self.df - self.assertEqual(self.testData, df.select("*").collect()) - self.assertEqual(self.testData, df.select(df.key, df.value).collect()) - self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect()) - - def test_aggregator(self): - df = self.df - g = df.groupBy() - self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0])) - self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect()) - - from pyspark.sql import Aggregator as Agg - self.assertEqual((0, u'99'), tuple(g.agg(Agg.first(df.key), Agg.last(df.value)).first())) - self.assertTrue(95 < g.agg(Agg.approxCountDistinct(df.key)).first()[0]) - self.assertEqual(100, g.agg(Agg.countDistinct(df.value)).first()[0]) - - def test_help_command(self): - # Regression test for SPARK-5464 - rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) - df = self.sqlCtx.jsonRDD(rdd) - # render_doc() reproduces the help() exception without printing output - pydoc.render_doc(df) - pydoc.render_doc(df.foo) - pydoc.render_doc(df.take(1)) - - class InputFormatTests(ReusedPySparkTestCase): @classmethod diff --git a/python/run-tests b/python/run-tests index e91f1a875d356..649a2c44d187b 100755 --- a/python/run-tests +++ b/python/run-tests @@ -65,6 +65,7 @@ function run_core_tests() { function run_sql_tests() { echo "Run sql tests ..." run_test "pyspark/sql.py" + run_test "pyspark/sql_tests.py" } function run_mllib_tests() { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index 41bb4f012f2e1..3a70d25534968 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import scala.util.hashing.MurmurHash3 import org.apache.spark.sql.catalyst.expressions.GenericRow - +import org.apache.spark.sql.types.DateUtils object Row { /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index e0db587efb08d..8e79e532ca564 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -17,14 +17,13 @@ package org.apache.spark.sql.catalyst -import java.sql.{Date, Timestamp} +import java.sql.Timestamp import org.apache.spark.util.Utils import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference, Row} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.types._ - /** * A default version of ScalaReflection that uses the runtime universe. */ @@ -72,6 +71,7 @@ trait ScalaReflection { }.toArray) case (d: BigDecimal, _) => Decimal(d) case (d: java.math.BigDecimal, _) => Decimal(d) + case (d: java.sql.Date, _) => DateUtils.fromJavaDate(d) case (other, _) => other } @@ -85,6 +85,7 @@ trait ScalaReflection { } case (r: Row, s: StructType) => convertRowToScala(r, s) case (d: Decimal, _: DecimalType) => d.toJavaBigDecimal + case (i: Int, DateType) => DateUtils.toJavaDate(i) case (other, _) => other } @@ -159,7 +160,7 @@ trait ScalaReflection { valueDataType, valueContainsNull = valueNullable), nullable = true) case t if t <:< typeOf[String] => Schema(StringType, nullable = true) case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true) - case t if t <:< typeOf[Date] => Schema(DateType, nullable = true) + case t if t <:< typeOf[java.sql.Date] => Schema(DateType, nullable = true) case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true) case t if t <:< typeOf[java.math.BigDecimal] => Schema(DecimalType.Unlimited, nullable = true) case t if t <:< typeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true) @@ -191,7 +192,7 @@ trait ScalaReflection { case obj: LongType.JvmType => LongType case obj: FloatType.JvmType => FloatType case obj: DoubleType.JvmType => DoubleType - case obj: DateType.JvmType => DateType + case obj: java.sql.Date => DateType case obj: java.math.BigDecimal => DecimalType.Unlimited case obj: Decimal => DecimalType.Unlimited case obj: TimestampType.JvmType => TimestampType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 594a423146d77..5c006e9d4c6f5 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -52,6 +52,7 @@ class SqlParser extends AbstractSparkSQLParser { protected val CAST = Keyword("CAST") protected val COALESCE = Keyword("COALESCE") protected val COUNT = Keyword("COUNT") + protected val DATE = Keyword("DATE") protected val DECIMAL = Keyword("DECIMAL") protected val DESC = Keyword("DESC") protected val DISTINCT = Keyword("DISTINCT") @@ -147,8 +148,8 @@ class SqlParser extends AbstractSparkSQLParser { } protected lazy val insert: Parser[LogicalPlan] = - INSERT ~> OVERWRITE.? ~ (INTO ~> relation) ~ select ^^ { - case o ~ r ~ s => InsertIntoTable(r, Map.empty[String, Option[String]], s, o.isDefined) + INSERT ~> (OVERWRITE ^^^ true | INTO ^^^ false) ~ (TABLE ~> relation) ~ select ^^ { + case o ~ r ~ s => InsertIntoTable(r, Map.empty[String, Option[String]], s, o) } protected lazy val projection: Parser[Expression] = @@ -383,6 +384,7 @@ class SqlParser extends AbstractSparkSQLParser { | DOUBLE ^^^ DoubleType | fixedDecimalType | DECIMAL ^^^ DecimalType.Unlimited + | DATE ^^^ DateType ) protected lazy val fixedDecimalType: Parser[DataType] = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index cefd70acf3931..ae7f7b9feb5fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -69,8 +69,9 @@ class Analyzer(catalog: Catalog, typeCoercionRules ++ extendedRules : _*), Batch("Check Analysis", Once, - CheckResolution, - CheckAggregation), + CheckResolution :: + CheckAggregation :: + Nil: _*), Batch("AnalysisOperators", fixedPoint, EliminateAnalysisOperators) ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 760c49fbca4a5..9f334f6d42ad1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -27,23 +27,25 @@ trait FunctionRegistry { def registerFunction(name: String, builder: FunctionBuilder): Unit def lookupFunction(name: String, children: Seq[Expression]): Expression + + def caseSensitive: Boolean } trait OverrideFunctionRegistry extends FunctionRegistry { - val functionBuilders = new mutable.HashMap[String, FunctionBuilder]() + val functionBuilders = StringKeyHashMap[FunctionBuilder](caseSensitive) def registerFunction(name: String, builder: FunctionBuilder) = { functionBuilders.put(name, builder) } abstract override def lookupFunction(name: String, children: Seq[Expression]): Expression = { - functionBuilders.get(name).map(_(children)).getOrElse(super.lookupFunction(name,children)) + functionBuilders.get(name).map(_(children)).getOrElse(super.lookupFunction(name, children)) } } -class SimpleFunctionRegistry extends FunctionRegistry { - val functionBuilders = new mutable.HashMap[String, FunctionBuilder]() +class SimpleFunctionRegistry(val caseSensitive: Boolean) extends FunctionRegistry { + val functionBuilders = StringKeyHashMap[FunctionBuilder](caseSensitive) def registerFunction(name: String, builder: FunctionBuilder) = { functionBuilders.put(name, builder) @@ -64,4 +66,30 @@ object EmptyFunctionRegistry extends FunctionRegistry { def lookupFunction(name: String, children: Seq[Expression]): Expression = { throw new UnsupportedOperationException } + + def caseSensitive: Boolean = ??? +} + +/** + * Build a map with String type of key, and it also supports either key case + * sensitive or insensitive. + * TODO move this into util folder? + */ +object StringKeyHashMap { + def apply[T](caseSensitive: Boolean) = caseSensitive match { + case false => new StringKeyHashMap[T](_.toLowerCase) + case true => new StringKeyHashMap[T](identity) + } +} + +class StringKeyHashMap[T](normalizer: (String) => String) { + private val base = new collection.mutable.HashMap[String, T]() + + def apply(key: String): T = base(normalizer(key)) + + def get(key: String): Option[T] = base.get(normalizer(key)) + def put(key: String, value: T): Option[T] = base.put(normalizer(key), value) + def remove(key: String): Option[T] = base.remove(normalizer(key)) + def iterator: Iterator[(String, T)] = base.toIterator } + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 66060289189ef..f35921e2a772c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -128,6 +128,44 @@ case class UnresolvedStar(table: Option[String]) extends Star { override def toString = table.map(_ + ".").getOrElse("") + "*" } +/** + * Used to assign new names to Generator's output, such as hive udtf. + * For example the SQL expression "stack(2, key, value, key, value) as (a, b)" could be represented + * as follows: + * MultiAlias(stack_function, Seq(a, b)) + + * @param child the computation being performed + * @param names the names to be associated with each output of computing [[child]]. + */ +case class MultiAlias(child: Expression, names: Seq[String]) + extends Attribute with trees.UnaryNode[Expression] { + + override def name = throw new UnresolvedException(this, "name") + + override def exprId = throw new UnresolvedException(this, "exprId") + + override def dataType = throw new UnresolvedException(this, "dataType") + + override def nullable = throw new UnresolvedException(this, "nullable") + + override def qualifiers = throw new UnresolvedException(this, "qualifiers") + + override lazy val resolved = false + + override def newInstance = this + + override def withNullability(newNullability: Boolean) = this + + override def withQualifiers(newQualifiers: Seq[String]) = this + + override def withName(newName: String) = this + + override def eval(input: Row = null): EvaluatedType = + throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") + + override def toString: String = s"$child AS $names" + +} /** * Represents all the resolved input attributes to a given relational operator. This is used diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index ece5ee73618cb..b1bc858478ee1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -113,7 +113,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // UDFToString private[this] def castToString(from: DataType): Any => Any = from match { case BinaryType => buildCast[Array[Byte]](_, new String(_, "UTF-8")) - case DateType => buildCast[Date](_, dateToString) + case DateType => buildCast[Int](_, d => DateUtils.toString(d)) case TimestampType => buildCast[Timestamp](_, timestampToString) case _ => buildCast[Any](_, _.toString) } @@ -131,7 +131,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w buildCast[Timestamp](_, t => t.getTime() != 0 || t.getNanos() != 0) case DateType => // Hive would return null when cast from date to boolean - buildCast[Date](_, d => null) + buildCast[Int](_, d => null) case LongType => buildCast[Long](_, _ != 0) case IntegerType => @@ -171,7 +171,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case ByteType => buildCast[Byte](_, b => new Timestamp(b)) case DateType => - buildCast[Date](_, d => new Timestamp(d.getTime)) + buildCast[Int](_, d => new Timestamp(DateUtils.toJavaDate(d).getTime)) // TimestampWritable.decimalToTimestamp case DecimalType() => buildCast[Decimal](_, d => decimalToTimestamp(d)) @@ -224,37 +224,24 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w } } - // Converts Timestamp to string according to Hive TimestampWritable convention - private[this] def timestampToDateString(ts: Timestamp): String = { - Cast.threadLocalDateFormat.get.format(ts) - } - // DateConverter private[this] def castToDate(from: DataType): Any => Any = from match { case StringType => buildCast[String](_, s => - try Date.valueOf(s) catch { case _: java.lang.IllegalArgumentException => null }) + try DateUtils.fromJavaDate(Date.valueOf(s)) + catch { case _: java.lang.IllegalArgumentException => null } + ) case TimestampType => // throw valid precision more than seconds, according to Hive. // Timestamp.nanos is in 0 to 999,999,999, no more than a second. - buildCast[Timestamp](_, t => new Date(Math.floor(t.getTime / 1000.0).toLong * 1000)) + buildCast[Timestamp](_, t => DateUtils.millisToDays(t.getTime)) // Hive throws this exception as a Semantic Exception - // It is never possible to compare result when hive return with exception, so we can return null + // It is never possible to compare result when hive return with exception, + // so we can return null // NULL is more reasonable here, since the query itself obeys the grammar. case _ => _ => null } - // Date cannot be cast to long, according to hive - private[this] def dateToLong(d: Date) = null - - // Date cannot be cast to double, according to hive - private[this] def dateToDouble(d: Date) = null - - // Converts Date to string according to Hive DateWritable convention - private[this] def dateToString(d: Date): String = { - Cast.threadLocalDateFormat.get.format(d) - } - // LongConverter private[this] def castToLong(from: DataType): Any => Any = from match { case StringType => @@ -264,7 +251,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case BooleanType => buildCast[Boolean](_, b => if (b) 1L else 0L) case DateType => - buildCast[Date](_, d => dateToLong(d)) + buildCast[Int](_, d => null) case TimestampType => buildCast[Timestamp](_, t => timestampToLong(t)) case x: NumericType => @@ -280,7 +267,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case BooleanType => buildCast[Boolean](_, b => if (b) 1 else 0) case DateType => - buildCast[Date](_, d => dateToLong(d)) + buildCast[Int](_, d => null) case TimestampType => buildCast[Timestamp](_, t => timestampToLong(t).toInt) case x: NumericType => @@ -296,7 +283,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case BooleanType => buildCast[Boolean](_, b => if (b) 1.toShort else 0.toShort) case DateType => - buildCast[Date](_, d => dateToLong(d)) + buildCast[Int](_, d => null) case TimestampType => buildCast[Timestamp](_, t => timestampToLong(t).toShort) case x: NumericType => @@ -312,7 +299,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case BooleanType => buildCast[Boolean](_, b => if (b) 1.toByte else 0.toByte) case DateType => - buildCast[Date](_, d => dateToLong(d)) + buildCast[Int](_, d => null) case TimestampType => buildCast[Timestamp](_, t => timestampToLong(t).toByte) case x: NumericType => @@ -342,7 +329,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case BooleanType => buildCast[Boolean](_, b => changePrecision(if (b) Decimal(1) else Decimal(0), target)) case DateType => - buildCast[Date](_, d => null) // date can't cast to decimal in Hive + buildCast[Int](_, d => null) // date can't cast to decimal in Hive case TimestampType => // Note that we lose precision here. buildCast[Timestamp](_, t => changePrecision(Decimal(timestampToDouble(t)), target)) @@ -367,7 +354,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case BooleanType => buildCast[Boolean](_, b => if (b) 1d else 0d) case DateType => - buildCast[Date](_, d => dateToDouble(d)) + buildCast[Int](_, d => null) case TimestampType => buildCast[Timestamp](_, t => timestampToDouble(t)) case x: NumericType => @@ -383,7 +370,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case BooleanType => buildCast[Boolean](_, b => if (b) 1f else 0f) case DateType => - buildCast[Date](_, d => dateToDouble(d)) + buildCast[Int](_, d => null) case TimestampType => buildCast[Timestamp](_, t => timestampToDouble(t).toFloat) case x: NumericType => @@ -442,16 +429,16 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w object Cast { // `SimpleDateFormat` is not thread-safe. - private[sql] val threadLocalDateFormat = new ThreadLocal[DateFormat] { + private[sql] val threadLocalTimestampFormat = new ThreadLocal[DateFormat] { override def initialValue() = { - new SimpleDateFormat("yyyy-MM-dd") + new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") } } // `SimpleDateFormat` is not thread-safe. - private[sql] val threadLocalTimestampFormat = new ThreadLocal[DateFormat] { + private[sql] val threadLocalDateFormat = new ThreadLocal[DateFormat] { override def initialValue() = { - new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + new SimpleDateFormat("yyyy-MM-dd") } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 4cae5c4718683..1f80d84b744a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -246,6 +246,9 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin new String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]]) """.children + case Cast(child @ DateType(), StringType) => + child.castOrNull(c => q"org.apache.spark.sql.types.DateUtils.toString($c)", StringType) + case Cast(child @ NumericType(), IntegerType) => child.castOrNull(c => q"$c.toInt", IntegerType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 5b389aad7a85d..97bb96f48e2c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -35,7 +35,7 @@ object Literal { case d: java.math.BigDecimal => Literal(Decimal(d), DecimalType.Unlimited) case d: Decimal => Literal(d, DecimalType.Unlimited) case t: Timestamp => Literal(t, TimestampType) - case d: Date => Literal(d, DateType) + case d: Date => Literal(DateUtils.fromJavaDate(d), DateType) case a: Array[Byte] => Literal(a, BinaryType) case null => Literal(null, NullType) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index f388cd5972bac..e6ab1fd8d7939 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -75,7 +75,7 @@ abstract class Attribute extends NamedExpression { /** * Used to assign a new name to a computation. * For example the SQL expression "1 + 1 AS a" could be represented as follows: - * Alias(Add(Literal(1), Literal(1), "a")() + * Alias(Add(Literal(1), Literal(1)), "a")() * * Note that exprId and qualifiers are in a separate parameter list because * we only pattern match on child and name. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala new file mode 100644 index 0000000000000..8a1a3b81b3d2c --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import java.sql.Date +import java.util.{Calendar, TimeZone} + +import org.apache.spark.sql.catalyst.expressions.Cast + +/** + * helper function to convert between Int value of days since 1970-01-01 and java.sql.Date + */ +object DateUtils { + private val MILLIS_PER_DAY = 86400000 + + // Java TimeZone has no mention of thread safety. Use thread local instance to be safe. + private val LOCAL_TIMEZONE = new ThreadLocal[TimeZone] { + override protected def initialValue: TimeZone = { + Calendar.getInstance.getTimeZone + } + } + + private def javaDateToDays(d: Date): Int = { + millisToDays(d.getTime) + } + + def millisToDays(millisLocal: Long): Int = { + ((millisLocal + LOCAL_TIMEZONE.get().getOffset(millisLocal)) / MILLIS_PER_DAY).toInt + } + + private def toMillisSinceEpoch(days: Int): Long = { + val millisUtc = days.toLong * MILLIS_PER_DAY + millisUtc - LOCAL_TIMEZONE.get().getOffset(millisUtc) + } + + def fromJavaDate(date: java.sql.Date): Int = { + javaDateToDays(date) + } + + def toJavaDate(daysSinceEpoch: Int): java.sql.Date = { + new java.sql.Date(toMillisSinceEpoch(daysSinceEpoch)) + } + + def toString(days: Int): String = Cast.threadLocalDateFormat.get.format(toJavaDate(days)) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala index defdcb2b706f5..4825d1ff81402 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.types -import java.sql.{Date, Timestamp} +import java.sql.Timestamp import scala.math.Numeric.{FloatAsIfIntegral, DoubleAsIfIntegral} import scala.reflect.ClassTag @@ -387,18 +387,16 @@ case object TimestampType extends NativeType { */ @DeveloperApi case object DateType extends NativeType { - private[sql] type JvmType = Date + private[sql] type JvmType = Int @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - private[sql] val ordering = new Ordering[JvmType] { - def compare(x: Date, y: Date) = x.compareTo(y) - } + private[sql] val ordering = implicitly[Ordering[JvmType]] /** - * The default size of a value of the DateType is 8 bytes. + * The default size of a value of the DateType is 4 bytes. */ - override def defaultSize: Int = 8 + override def defaultSize: Int = 4 } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 37e64adeea853..25d1c105a00a6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -303,6 +303,7 @@ class ExpressionEvaluationSuite extends FunSuite { val sd = "1970-01-01" val d = Date.valueOf(sd) + val zts = sd + " 00:00:00" val sts = sd + " 00:00:02" val nts = sts + ".1" val ts = Timestamp.valueOf(nts) @@ -319,14 +320,14 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Cast(Literal(1.toDouble) cast TimestampType, DoubleType), 1.toDouble) checkEvaluation(Cast(Literal(sd) cast DateType, StringType), sd) - checkEvaluation(Cast(Literal(d) cast StringType, DateType), d) + checkEvaluation(Cast(Literal(d) cast StringType, DateType), 0) checkEvaluation(Cast(Literal(nts) cast TimestampType, StringType), nts) checkEvaluation(Cast(Literal(ts) cast StringType, TimestampType), ts) // all convert to string type to check checkEvaluation( Cast(Cast(Literal(nts) cast TimestampType, DateType), StringType), sd) checkEvaluation( - Cast(Cast(Literal(ts) cast DateType, TimestampType), StringType), sts) + Cast(Cast(Literal(ts) cast DateType, TimestampType), StringType), zts) checkEvaluation(Cast("abdef" cast BinaryType, StringType), "abdef") @@ -377,8 +378,8 @@ class ExpressionEvaluationSuite extends FunSuite { } test("date") { - val d1 = Date.valueOf("1970-01-01") - val d2 = Date.valueOf("1970-01-02") + val d1 = DateUtils.fromJavaDate(Date.valueOf("1970-01-01")) + val d2 = DateUtils.fromJavaDate(Date.valueOf("1970-01-02")) checkEvaluation(Literal(d1) < Literal(d2), true) } @@ -459,22 +460,21 @@ class ExpressionEvaluationSuite extends FunSuite { test("date casting") { val d = Date.valueOf("1970-01-01") - checkEvaluation(Cast(d, ShortType), null) - checkEvaluation(Cast(d, IntegerType), null) - checkEvaluation(Cast(d, LongType), null) - checkEvaluation(Cast(d, FloatType), null) - checkEvaluation(Cast(d, DoubleType), null) - checkEvaluation(Cast(d, DecimalType.Unlimited), null) - checkEvaluation(Cast(d, DecimalType(10, 2)), null) - checkEvaluation(Cast(d, StringType), "1970-01-01") - checkEvaluation(Cast(Cast(d, TimestampType), StringType), "1970-01-01 00:00:00") + checkEvaluation(Cast(Literal(d), ShortType), null) + checkEvaluation(Cast(Literal(d), IntegerType), null) + checkEvaluation(Cast(Literal(d), LongType), null) + checkEvaluation(Cast(Literal(d), FloatType), null) + checkEvaluation(Cast(Literal(d), DoubleType), null) + checkEvaluation(Cast(Literal(d), DecimalType.Unlimited), null) + checkEvaluation(Cast(Literal(d), DecimalType(10, 2)), null) + checkEvaluation(Cast(Literal(d), StringType), "1970-01-01") + checkEvaluation(Cast(Cast(Literal(d), TimestampType), StringType), "1970-01-01 00:00:00") } test("timestamp casting") { val millis = 15 * 1000 + 2 val seconds = millis * 1000 + 2 val ts = new Timestamp(millis) - val ts1 = new Timestamp(15 * 1000) // a timestamp without the milliseconds part val tss = new Timestamp(seconds) checkEvaluation(Cast(ts, ShortType), 15) checkEvaluation(Cast(ts, IntegerType), 15) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index c147be9f6b1ae..7bcd6687d11a1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -106,7 +106,7 @@ class DataTypeSuite extends FunSuite { checkDefaultSize(DoubleType, 8) checkDefaultSize(DecimalType(10, 5), 4096) checkDefaultSize(DecimalType.Unlimited, 4096) - checkDefaultSize(DateType, 8) + checkDefaultSize(DateType, 4) checkDefaultSize(TimestampType, 8) checkDefaultSize(StringType, 4096) checkDefaultSize(BinaryType, 4096) diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 1a0c77d282307..03a5c9e7c24a0 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -66,6 +66,11 @@ jackson-databind 2.3.0 + + org.jodd + jodd-core + ${jodd.version} + junit junit diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 6f48d7c3fe1b6..ddce77deb83e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -56,7 +56,7 @@ private[sql] object Column { * */ // TODO: Improve documentation. -trait Column extends DataFrame with ExpressionApi { +trait Column extends DataFrame { protected[sql] def expr: Expression @@ -65,7 +65,15 @@ trait Column extends DataFrame with ExpressionApi { */ def isComputable: Boolean - private def constructColumn(other: Column)(newExpr: Expression): Column = { + private def computableCol(baseCol: ComputableColumn, expr: Expression) = { + val plan = Project(Seq(expr match { + case named: NamedExpression => named + case unnamed: Expression => Alias(unnamed, "col")() + }), baseCol.plan) + Column(baseCol.sqlContext, plan, expr) + } + + private def constructColumn(otherValue: Any)(newExpr: Column => Expression): Column = { // Removes all the top level projection and subquery so we can get to the underlying plan. @tailrec def stripProject(p: LogicalPlan): LogicalPlan = p match { case Project(_, child) => stripProject(child) @@ -73,457 +81,448 @@ trait Column extends DataFrame with ExpressionApi { case _ => p } - def computableCol(baseCol: ComputableColumn, expr: Expression) = { - val plan = Project(Seq(expr match { - case named: NamedExpression => named - case unnamed: Expression => Alias(unnamed, "col")() - }), baseCol.plan) - Column(baseCol.sqlContext, plan, expr) - } - - (this, other) match { + (this, lit(otherValue)) match { case (left: ComputableColumn, right: ComputableColumn) => if (stripProject(left.plan).sameResult(stripProject(right.plan))) { - computableCol(right, newExpr) + computableCol(right, newExpr(right)) } else { - Column(newExpr) + Column(newExpr(right)) } - case (left: ComputableColumn, _) => computableCol(left, newExpr) - case (_, right: ComputableColumn) => computableCol(right, newExpr) - case (_, _) => Column(newExpr) + case (left: ComputableColumn, right) => computableCol(left, newExpr(right)) + case (_, right: ComputableColumn) => computableCol(right, newExpr(right)) + case (_, right) => Column(newExpr(right)) + } + } + + /** Creates a column based on the given expression. */ + private def exprToColumn(newExpr: Expression, computable: Boolean = true): Column = { + this match { + case c: ComputableColumn if computable => computableCol(c, newExpr) + case _ => Column(newExpr) } } /** * Unary minus, i.e. negate the expression. * {{{ - * // Select the amount column and negates all values. + * // Scala: select the amount column and negates all values. * df.select( -df("amount") ) + * + * // Java: + * import static org.apache.spark.sql.Dsl.*; + * df.select( negate(col("amount") ); * }}} */ - override def unary_- : Column = constructColumn(null) { UnaryMinus(expr) } - - /** - * Bitwise NOT. - * {{{ - * // Select the flags column and negate every bit. - * df.select( ~df("flags") ) - * }}} - */ - override def unary_~ : Column = constructColumn(null) { BitwiseNot(expr) } + def unary_- : Column = exprToColumn(UnaryMinus(expr)) /** * Inversion of boolean expression, i.e. NOT. * {{ - * // Select rows that are not active (isActive === false) - * df.select( !df("isActive") ) + * // Scala: select rows that are not active (isActive === false) + * df.filter( !df("isActive") ) + * + * // Java: + * import static org.apache.spark.sql.Dsl.*; + * df.filter( not(df.col("isActive")) ); * }} */ - override def unary_! : Column = constructColumn(null) { Not(expr) } + def unary_! : Column = exprToColumn(Not(expr)) /** - * Equality test with an expression. + * Equality test. * {{{ - * // The following two both select rows in which colA equals colB. - * df.select( df("colA") === df("colB") ) - * df.select( df("colA".equalTo(df("colB")) ) + * // Scala: + * df.filter( df("colA") === df("colB") ) + * + * // Java + * import static org.apache.spark.sql.Dsl.*; + * df.filter( col("colA").equalTo(col("colB")) ); * }}} */ - override def === (other: Column): Column = constructColumn(other) { - EqualTo(expr, other.expr) + def === (other: Any): Column = constructColumn(other) { o => + EqualTo(expr, o.expr) } /** - * Equality test with a literal value. - * {{{ - * // The following two both select rows in which colA is "Zaharia". - * df.select( df("colA") === "Zaharia") - * df.select( df("colA".equalTo("Zaharia") ) - * }}} - */ - override def === (literal: Any): Column = this === lit(literal) - - /** - * Equality test with an expression. - * {{{ - * // The following two both select rows in which colA equals colB. - * df.select( df("colA") === df("colB") ) - * df.select( df("colA".equalTo(df("colB")) ) - * }}} - */ - override def equalTo(other: Column): Column = this === other - - /** - * Equality test with a literal value. + * Equality test. * {{{ - * // The following two both select rows in which colA is "Zaharia". - * df.select( df("colA") === "Zaharia") - * df.select( df("colA".equalTo("Zaharia") ) + * // Scala: + * df.filter( df("colA") === df("colB") ) + * + * // Java + * import static org.apache.spark.sql.Dsl.*; + * df.filter( col("colA").equalTo(col("colB")) ); * }}} */ - override def equalTo(literal: Any): Column = this === literal + def equalTo(other: Any): Column = this === other /** - * Inequality test with an expression. + * Inequality test. * {{{ - * // The following two both select rows in which colA does not equal colB. + * // Scala: * df.select( df("colA") !== df("colB") ) * df.select( !(df("colA") === df("colB")) ) + * + * // Java: + * import static org.apache.spark.sql.Dsl.*; + * df.filter( not(col("colA").equalTo(col("colB"))) ); * }}} */ - override def !== (other: Column): Column = constructColumn(other) { - Not(EqualTo(expr, other.expr)) + def !== (other: Any): Column = constructColumn(other) { o => + Not(EqualTo(expr, o.expr)) } /** - * Inequality test with a literal value. - * {{{ - * // The following two both select rows in which colA does not equal equal 15. - * df.select( df("colA") !== 15 ) - * df.select( !(df("colA") === 15) ) - * }}} - */ - override def !== (literal: Any): Column = this !== lit(literal) - - /** - * Greater than an expression. + * Greater than. * {{{ - * // The following selects people older than 21. - * people.select( people("age") > Literal(21) ) + * // Scala: The following selects people older than 21. + * people.select( people("age") > 21 ) + * + * // Java: + * import static org.apache.spark.sql.Dsl.*; + * people.select( people("age").gt(21) ); * }}} */ - override def > (other: Column): Column = constructColumn(other) { - GreaterThan(expr, other.expr) + def > (other: Any): Column = constructColumn(other) { o => + GreaterThan(expr, o.expr) } /** - * Greater than a literal value. + * Greater than. * {{{ - * // The following selects people older than 21. - * people.select( people("age") > 21 ) + * // Scala: The following selects people older than 21. + * people.select( people("age") > lit(21) ) + * + * // Java: + * import static org.apache.spark.sql.Dsl.*; + * people.select( people("age").gt(21) ); * }}} */ - override def > (literal: Any): Column = this > lit(literal) + def gt(other: Any): Column = this > other /** - * Less than an expression. + * Less than. * {{{ - * // The following selects people younger than 21. - * people.select( people("age") < Literal(21) ) + * // Scala: The following selects people younger than 21. + * people.select( people("age") < 21 ) + * + * // Java: + * people.select( people("age").lt(21) ); * }}} */ - override def < (other: Column): Column = constructColumn(other) { - LessThan(expr, other.expr) + def < (other: Any): Column = constructColumn(other) { o => + LessThan(expr, o.expr) } /** - * Less than a literal value. + * Less than. * {{{ - * // The following selects people younger than 21. + * // Scala: The following selects people younger than 21. * people.select( people("age") < 21 ) + * + * // Java: + * people.select( people("age").lt(21) ); * }}} */ - override def < (literal: Any): Column = this < lit(literal) + def lt(other: Any): Column = this < other /** - * Less than or equal to an expression. + * Less than or equal to. * {{{ - * // The following selects people age 21 or younger than 21. - * people.select( people("age") <= Literal(21) ) + * // Scala: The following selects people age 21 or younger than 21. + * people.select( people("age") <= 21 ) + * + * // Java: + * people.select( people("age").leq(21) ); * }}} */ - override def <= (other: Column): Column = constructColumn(other) { - LessThanOrEqual(expr, other.expr) + def <= (other: Any): Column = constructColumn(other) { o => + LessThanOrEqual(expr, o.expr) } /** - * Less than or equal to a literal value. + * Less than or equal to. * {{{ - * // The following selects people age 21 or younger than 21. + * // Scala: The following selects people age 21 or younger than 21. * people.select( people("age") <= 21 ) + * + * // Java: + * people.select( people("age").leq(21) ); * }}} */ - override def <= (literal: Any): Column = this <= lit(literal) + def leq(other: Any): Column = this <= other /** * Greater than or equal to an expression. * {{{ - * // The following selects people age 21 or older than 21. - * people.select( people("age") >= Literal(21) ) + * // Scala: The following selects people age 21 or older than 21. + * people.select( people("age") >= 21 ) + * + * // Java: + * people.select( people("age").geq(21) ) * }}} */ - override def >= (other: Column): Column = constructColumn(other) { - GreaterThanOrEqual(expr, other.expr) + def >= (other: Any): Column = constructColumn(other) { o => + GreaterThanOrEqual(expr, o.expr) } /** - * Greater than or equal to a literal value. + * Greater than or equal to an expression. * {{{ - * // The following selects people age 21 or older than 21. + * // Scala: The following selects people age 21 or older than 21. * people.select( people("age") >= 21 ) + * + * // Java: + * people.select( people("age").geq(21) ) * }}} */ - override def >= (literal: Any): Column = this >= lit(literal) + def geq(other: Any): Column = this >= other /** - * Equality test with an expression that is safe for null values. + * Equality test that is safe for null values. */ - override def <=> (other: Column): Column = constructColumn(other) { - other match { - case null => EqualNullSafe(expr, lit(null).expr) - case _ => EqualNullSafe(expr, other.expr) - } + def <=> (other: Any): Column = constructColumn(other) { o => + EqualNullSafe(expr, o.expr) } /** - * Equality test with a literal value that is safe for null values. + * Equality test that is safe for null values. */ - override def <=> (literal: Any): Column = this <=> lit(literal) + def eqNullSafe(other: Any): Column = this <=> other /** * True if the current expression is null. */ - override def isNull: Column = constructColumn(null) { IsNull(expr) } + def isNull: Column = exprToColumn(IsNull(expr)) /** * True if the current expression is NOT null. */ - override def isNotNull: Column = constructColumn(null) { IsNotNull(expr) } + def isNotNull: Column = exprToColumn(IsNotNull(expr)) /** - * Boolean OR with an expression. + * Boolean OR. * {{{ - * // The following selects people that are in school or employed. - * people.select( people("inSchool") || people("isEmployed") ) + * // Scala: The following selects people that are in school or employed. + * people.filter( people("inSchool") || people("isEmployed") ) + * + * // Java: + * people.filter( people("inSchool").or(people("isEmployed")) ); * }}} */ - override def || (other: Column): Column = constructColumn(other) { - Or(expr, other.expr) + def || (other: Any): Column = constructColumn(other) { o => + Or(expr, o.expr) } /** - * Boolean OR with a literal value. + * Boolean OR. * {{{ - * // The following selects everything. - * people.select( people("inSchool") || true ) + * // Scala: The following selects people that are in school or employed. + * people.filter( people("inSchool") || people("isEmployed") ) + * + * // Java: + * people.filter( people("inSchool").or(people("isEmployed")) ); * }}} */ - override def || (literal: Boolean): Column = this || lit(literal) + def or(other: Column): Column = this || other /** - * Boolean AND with an expression. + * Boolean AND. * {{{ - * // The following selects people that are in school and employed at the same time. + * // Scala: The following selects people that are in school and employed at the same time. * people.select( people("inSchool") && people("isEmployed") ) + * + * // Java: + * people.select( people("inSchool").and(people("isEmployed")) ); * }}} */ - override def && (other: Column): Column = constructColumn(other) { - And(expr, other.expr) + def && (other: Any): Column = constructColumn(other) { o => + And(expr, o.expr) } /** - * Boolean AND with a literal value. + * Boolean AND. * {{{ - * // The following selects people that are in school. - * people.select( people("inSchool") && true ) + * // Scala: The following selects people that are in school and employed at the same time. + * people.select( people("inSchool") && people("isEmployed") ) + * + * // Java: + * people.select( people("inSchool").and(people("isEmployed")) ); * }}} */ - override def && (literal: Boolean): Column = this && lit(literal) - - /** - * Bitwise AND with an expression. - */ - override def & (other: Column): Column = constructColumn(other) { - BitwiseAnd(expr, other.expr) - } - - /** - * Bitwise AND with a literal value. - */ - override def & (literal: Any): Column = this & lit(literal) - - /** - * Bitwise OR with an expression. - */ - override def | (other: Column): Column = constructColumn(other) { - BitwiseOr(expr, other.expr) - } - - /** - * Bitwise OR with a literal value. - */ - override def | (literal: Any): Column = this | lit(literal) - - /** - * Bitwise XOR with an expression. - */ - override def ^ (other: Column): Column = constructColumn(other) { - BitwiseXor(expr, other.expr) - } - - /** - * Bitwise XOR with a literal value. - */ - override def ^ (literal: Any): Column = this ^ lit(literal) + def and(other: Column): Column = this && other /** * Sum of this expression and another expression. * {{{ - * // The following selects the sum of a person's height and weight. + * // Scala: The following selects the sum of a person's height and weight. * people.select( people("height") + people("weight") ) + * + * // Java: + * people.select( people("height").plus(people("weight")) ); * }}} */ - override def + (other: Column): Column = constructColumn(other) { - Add(expr, other.expr) + def + (other: Any): Column = constructColumn(other) { o => + Add(expr, o.expr) } /** * Sum of this expression and another expression. * {{{ - * // The following selects the sum of a person's height and 10. - * people.select( people("height") + 10 ) + * // Scala: The following selects the sum of a person's height and weight. + * people.select( people("height") + people("weight") ) + * + * // Java: + * people.select( people("height").plus(people("weight")) ); * }}} */ - override def + (literal: Any): Column = this + lit(literal) + def plus(other: Any): Column = this + other /** * Subtraction. Subtract the other expression from this expression. * {{{ - * // The following selects the difference between people's height and their weight. + * // Scala: The following selects the difference between people's height and their weight. * people.select( people("height") - people("weight") ) + * + * // Java: + * people.select( people("height").minus(people("weight")) ); * }}} */ - override def - (other: Column): Column = constructColumn(other) { - Subtract(expr, other.expr) + def - (other: Any): Column = constructColumn(other) { o => + Subtract(expr, o.expr) } /** - * Subtraction. Subtract a literal value from this expression. + * Subtraction. Subtract the other expression from this expression. * {{{ - * // The following selects a person's height and subtract it by 10. - * people.select( people("height") - 10 ) + * // Scala: The following selects the difference between people's height and their weight. + * people.select( people("height") - people("weight") ) + * + * // Java: + * people.select( people("height").minus(people("weight")) ); * }}} */ - override def - (literal: Any): Column = this - lit(literal) + def minus(other: Any): Column = this - other /** * Multiplication of this expression and another expression. * {{{ - * // The following multiplies a person's height by their weight. + * // Scala: The following multiplies a person's height by their weight. * people.select( people("height") * people("weight") ) + * + * // Java: + * people.select( people("height").multiply(people("weight")) ); * }}} */ - override def * (other: Column): Column = constructColumn(other) { - Multiply(expr, other.expr) + def * (other: Any): Column = constructColumn(other) { o => + Multiply(expr, o.expr) } /** - * Multiplication this expression and a literal value. + * Multiplication of this expression and another expression. * {{{ - * // The following multiplies a person's height by 10. - * people.select( people("height") * 10 ) + * // Scala: The following multiplies a person's height by their weight. + * people.select( people("height") * people("weight") ) + * + * // Java: + * people.select( people("height").multiply(people("weight")) ); * }}} */ - override def * (literal: Any): Column = this * lit(literal) + def multiply(other: Any): Column = this * other /** * Division this expression by another expression. * {{{ - * // The following divides a person's height by their weight. + * // Scala: The following divides a person's height by their weight. * people.select( people("height") / people("weight") ) + * + * // Java: + * people.select( people("height").divide(people("weight")) ); * }}} */ - override def / (other: Column): Column = constructColumn(other) { - Divide(expr, other.expr) + def / (other: Any): Column = constructColumn(other) { o => + Divide(expr, o.expr) } /** - * Division this expression by a literal value. + * Division this expression by another expression. * {{{ - * // The following divides a person's height by 10. - * people.select( people("height") / 10 ) + * // Scala: The following divides a person's height by their weight. + * people.select( people("height") / people("weight") ) + * + * // Java: + * people.select( people("height").divide(people("weight")) ); * }}} */ - override def / (literal: Any): Column = this / lit(literal) + def divide(other: Any): Column = this / other /** * Modulo (a.k.a. remainder) expression. */ - override def % (other: Column): Column = constructColumn(other) { - Remainder(expr, other.expr) + def % (other: Any): Column = constructColumn(other) { o => + Remainder(expr, o.expr) } /** * Modulo (a.k.a. remainder) expression. */ - override def % (literal: Any): Column = this % lit(literal) - + def mod(other: Any): Column = this % other /** * A boolean expression that is evaluated to true if the value of this expression is contained * by the evaluated values of the arguments. */ @scala.annotation.varargs - override def in(list: Column*): Column = { + def in(list: Column*): Column = { new IncomputableColumn(In(expr, list.map(_.expr))) } - override def like(literal: String): Column = constructColumn(null) { - Like(expr, lit(literal).expr) - } + def like(literal: String): Column = exprToColumn(Like(expr, lit(literal).expr)) - override def rlike(literal: String): Column = constructColumn(null) { - RLike(expr, lit(literal).expr) - } + def rlike(literal: String): Column = exprToColumn(RLike(expr, lit(literal).expr)) /** * An expression that gets an item at position `ordinal` out of an array. */ - override def getItem(ordinal: Int): Column = constructColumn(null) { - GetItem(expr, Literal(ordinal)) - } + def getItem(ordinal: Int): Column = exprToColumn(GetItem(expr, Literal(ordinal))) /** * An expression that gets a field by name in a [[StructField]]. */ - override def getField(fieldName: String): Column = constructColumn(null) { - GetField(expr, fieldName) - } + def getField(fieldName: String): Column = exprToColumn(GetField(expr, fieldName)) /** * An expression that returns a substring. * @param startPos expression for the starting position. * @param len expression for the length of the substring. */ - override def substr(startPos: Column, len: Column): Column = { - new IncomputableColumn(Substring(expr, startPos.expr, len.expr)) - } + def substr(startPos: Column, len: Column): Column = + exprToColumn(Substring(expr, startPos.expr, len.expr), computable = false) /** * An expression that returns a substring. * @param startPos starting position. * @param len length of the substring. */ - override def substr(startPos: Int, len: Int): Column = this.substr(lit(startPos), lit(len)) + def substr(startPos: Int, len: Int): Column = + exprToColumn(Substring(expr, lit(startPos).expr, lit(len).expr)) - override def contains(other: Column): Column = constructColumn(other) { - Contains(expr, other.expr) + def contains(other: Any): Column = constructColumn(other) { o => + Contains(expr, o.expr) } - override def contains(literal: Any): Column = this.contains(lit(literal)) - - override def startsWith(other: Column): Column = constructColumn(other) { - StartsWith(expr, other.expr) + def startsWith(other: Column): Column = constructColumn(other) { o => + StartsWith(expr, o.expr) } - override def startsWith(literal: String): Column = this.startsWith(lit(literal)) + def startsWith(literal: String): Column = this.startsWith(lit(literal)) - override def endsWith(other: Column): Column = constructColumn(other) { - EndsWith(expr, other.expr) + def endsWith(other: Column): Column = constructColumn(other) { o => + EndsWith(expr, o.expr) } - override def endsWith(literal: String): Column = this.endsWith(lit(literal)) + def endsWith(literal: String): Column = this.endsWith(lit(literal)) /** * Gives the column an alias. @@ -532,7 +531,7 @@ trait Column extends DataFrame with ExpressionApi { * df.select($"colA".as("colB")) * }}} */ - override def as(alias: String): Column = constructColumn(null) { Alias(expr, alias)() } + override def as(alias: String): Column = exprToColumn(Alias(expr, alias)()) /** * Casts the column to a different data type. @@ -545,7 +544,7 @@ trait Column extends DataFrame with ExpressionApi { * df.select(df("colA").cast("int")) * }}} */ - override def cast(to: DataType): Column = constructColumn(null) { Cast(expr, to) } + def cast(to: DataType): Column = exprToColumn(Cast(expr, to)) /** * Casts the column to a different data type, using the canonical string representation @@ -556,7 +555,7 @@ trait Column extends DataFrame with ExpressionApi { * df.select(df("colA").cast("int")) * }}} */ - override def cast(to: String): Column = constructColumn(null) { + def cast(to: String): Column = exprToColumn( Cast(expr, to.toLowerCase match { case "string" => StringType case "boolean" => BooleanType @@ -571,11 +570,11 @@ trait Column extends DataFrame with ExpressionApi { case "timestamp" => TimestampType case _ => throw new RuntimeException(s"""Unsupported cast type: "$to"""") }) - } + ) - override def desc: Column = constructColumn(null) { SortOrder(expr, Descending) } + def desc: Column = exprToColumn(SortOrder(expr, Descending), computable = false) - override def asc: Column = constructColumn(null) { SortOrder(expr, Ascending) } + def asc: Column = exprToColumn(SortOrder(expr, Ascending), computable = false) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 385e1ec74f5f7..f3bc07ae5238c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql -import java.util.{List => JList} - import scala.reflect.ClassTag import org.apache.spark.annotation.{DeveloperApi, Experimental} @@ -27,6 +25,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils private[sql] object DataFrame { @@ -73,7 +72,7 @@ private[sql] object DataFrame { * }}} */ // TODO: Improve documentation. -trait DataFrame extends DataFrameSpecificApi with RDDApi[Row] { +trait DataFrame extends RDDApi[Row] { val sqlContext: SQLContext @@ -82,7 +81,7 @@ trait DataFrame extends DataFrameSpecificApi with RDDApi[Row] { protected[sql] def logicalPlan: LogicalPlan - /** Left here for compatibility reasons. */ + /** Left here for backward compatibility. */ @deprecated("1.3.0", "use toDataFrame") def toSchemaRDD: DataFrame = this @@ -104,16 +103,16 @@ trait DataFrame extends DataFrameSpecificApi with RDDApi[Row] { def toDataFrame(colName: String, colNames: String*): DataFrame /** Returns the schema of this [[DataFrame]]. */ - override def schema: StructType + def schema: StructType /** Returns all column names and their data types as an array. */ - override def dtypes: Array[(String, String)] + def dtypes: Array[(String, String)] /** Returns all column names as an array. */ - override def columns: Array[String] = schema.fields.map(_.name) + def columns: Array[String] = schema.fields.map(_.name) /** Prints the schema to the console in a nice tree format. */ - override def printSchema(): Unit + def printSchema(): Unit /** * Cartesian join with another [[DataFrame]]. @@ -122,7 +121,7 @@ trait DataFrame extends DataFrameSpecificApi with RDDApi[Row] { * * @param right Right side of the join operation. */ - override def join(right: DataFrame): DataFrame + def join(right: DataFrame): DataFrame /** * Inner join with another [[DataFrame]], using the given join expression. @@ -133,21 +132,27 @@ trait DataFrame extends DataFrameSpecificApi with RDDApi[Row] { * df1.join(df2).where($"df1Key" === $"df2Key") * }}} */ - override def join(right: DataFrame, joinExprs: Column): DataFrame + def join(right: DataFrame, joinExprs: Column): DataFrame /** * Join with another [[DataFrame]], usin g the given join expression. The following performs * a full outer join between `df1` and `df2`. * * {{{ + * // Scala: + * import org.apache.spark.sql.dsl._ * df1.join(df2, "outer", $"df1Key" === $"df2Key") + * + * // Java: + * import static org.apache.spark.sql.Dsl.*; + * df1.join(df2, "outer", col("df1Key") === col("df2Key")); * }}} * * @param right Right side of the join. * @param joinExprs Join expression. * @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`. */ - override def join(right: DataFrame, joinExprs: Column, joinType: String): DataFrame + def join(right: DataFrame, joinExprs: Column, joinType: String): DataFrame /** * Returns a new [[DataFrame]] sorted by the specified column, all in ascending order. @@ -159,7 +164,7 @@ trait DataFrame extends DataFrameSpecificApi with RDDApi[Row] { * }}} */ @scala.annotation.varargs - override def sort(sortCol: String, sortCols: String*): DataFrame + def sort(sortCol: String, sortCols: String*): DataFrame /** * Returns a new [[DataFrame]] sorted by the given expressions. For example: @@ -168,26 +173,31 @@ trait DataFrame extends DataFrameSpecificApi with RDDApi[Row] { * }}} */ @scala.annotation.varargs - override def sort(sortExpr: Column, sortExprs: Column*): DataFrame + def sort(sortExpr: Column, sortExprs: Column*): DataFrame /** * Returns a new [[DataFrame]] sorted by the given expressions. * This is an alias of the `sort` function. */ @scala.annotation.varargs - override def orderBy(sortCol: String, sortCols: String*): DataFrame + def orderBy(sortCol: String, sortCols: String*): DataFrame /** * Returns a new [[DataFrame]] sorted by the given expressions. * This is an alias of the `sort` function. */ @scala.annotation.varargs - override def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame + def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame /** * Selects column based on the column name and return it as a [[Column]]. */ - override def apply(colName: String): Column + def apply(colName: String): Column = col(colName) + + /** + * Selects column based on the column name and return it as a [[Column]]. + */ + def col(colName: String): Column /** * Selects a set of expressions, wrapped in a Product. @@ -197,12 +207,12 @@ trait DataFrame extends DataFrameSpecificApi with RDDApi[Row] { * df.select($"colA", $"colB" + 1) * }}} */ - override def apply(projection: Product): DataFrame + def apply(projection: Product): DataFrame /** * Returns a new [[DataFrame]] with an alias set. */ - override def as(name: String): DataFrame + def as(name: String): DataFrame /** * Selects a set of expressions. @@ -211,7 +221,7 @@ trait DataFrame extends DataFrameSpecificApi with RDDApi[Row] { * }}} */ @scala.annotation.varargs - override def select(cols: Column*): DataFrame + def select(cols: Column*): DataFrame /** * Selects a set of columns. This is a variant of `select` that can only select @@ -224,7 +234,7 @@ trait DataFrame extends DataFrameSpecificApi with RDDApi[Row] { * }}} */ @scala.annotation.varargs - override def select(col: String, cols: String*): DataFrame + def select(col: String, cols: String*): DataFrame /** * Filters rows using the given condition. @@ -235,7 +245,7 @@ trait DataFrame extends DataFrameSpecificApi with RDDApi[Row] { * peopleDf($"age" > 15) * }}} */ - override def filter(condition: Column): DataFrame + def filter(condition: Column): DataFrame /** * Filters rows using the given condition. This is an alias for `filter`. @@ -246,7 +256,7 @@ trait DataFrame extends DataFrameSpecificApi with RDDApi[Row] { * peopleDf($"age" > 15) * }}} */ - override def where(condition: Column): DataFrame + def where(condition: Column): DataFrame /** * Filters rows using the given condition. This is a shorthand meant for Scala. @@ -257,7 +267,7 @@ trait DataFrame extends DataFrameSpecificApi with RDDApi[Row] { * peopleDf($"age" > 15) * }}} */ - override def apply(condition: Column): DataFrame + def apply(condition: Column): DataFrame /** * Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them. @@ -275,7 +285,7 @@ trait DataFrame extends DataFrameSpecificApi with RDDApi[Row] { * }}} */ @scala.annotation.varargs - override def groupBy(cols: Column*): GroupedDataFrame + def groupBy(cols: Column*): GroupedDataFrame /** * Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them. @@ -296,27 +306,44 @@ trait DataFrame extends DataFrameSpecificApi with RDDApi[Row] { * }}} */ @scala.annotation.varargs - override def groupBy(col1: String, cols: String*): GroupedDataFrame + def groupBy(col1: String, cols: String*): GroupedDataFrame /** - * Aggregates on the entire [[DataFrame]] without groups. + * (Scala-specific) Compute aggregates by specifying a map from column name to + * aggregate methods. The resulting [[DataFrame]] will also contain the grouping columns. + * + * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. + * {{{ + * // Selects the age of the oldest employee and the aggregate expense for each department + * df.groupBy("department").agg( + * "age" -> "max", + * "expense" -> "sum" + * ) + * }}} + */ + def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = { + groupBy().agg(aggExpr, aggExprs :_*) + } + + /** + * (Scala-specific) Aggregates on the entire [[DataFrame]] without groups. * {{ * // df.agg(...) is a shorthand for df.groupBy().agg(...) * df.agg(Map("age" -> "max", "salary" -> "avg")) * df.groupBy().agg(Map("age" -> "max", "salary" -> "avg")) * }} */ - override def agg(exprs: Map[String, String]): DataFrame + def agg(exprs: Map[String, String]): DataFrame = groupBy().agg(exprs) /** - * Aggregates on the entire [[DataFrame]] without groups. + * (Java-specific) Aggregates on the entire [[DataFrame]] without groups. * {{ * // df.agg(...) is a shorthand for df.groupBy().agg(...) * df.agg(Map("age" -> "max", "salary" -> "avg")) * df.groupBy().agg(Map("age" -> "max", "salary" -> "avg")) * }} */ - override def agg(exprs: java.util.Map[String, String]): DataFrame + def agg(exprs: java.util.Map[String, String]): DataFrame = groupBy().agg(exprs) /** * Aggregates on the entire [[DataFrame]] without groups. @@ -327,31 +354,31 @@ trait DataFrame extends DataFrameSpecificApi with RDDApi[Row] { * }} */ @scala.annotation.varargs - override def agg(expr: Column, exprs: Column*): DataFrame + def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs :_*) /** * Returns a new [[DataFrame]] by taking the first `n` rows. The difference between this function * and `head` is that `head` returns an array while `limit` returns a new [[DataFrame]]. */ - override def limit(n: Int): DataFrame + def limit(n: Int): DataFrame /** * Returns a new [[DataFrame]] containing union of rows in this frame and another frame. * This is equivalent to `UNION ALL` in SQL. */ - override def unionAll(other: DataFrame): DataFrame + def unionAll(other: DataFrame): DataFrame /** * Returns a new [[DataFrame]] containing rows only in both this frame and another frame. * This is equivalent to `INTERSECT` in SQL. */ - override def intersect(other: DataFrame): DataFrame + def intersect(other: DataFrame): DataFrame /** * Returns a new [[DataFrame]] containing rows in this frame but not in another frame. * This is equivalent to `EXCEPT` in SQL. */ - override def except(other: DataFrame): DataFrame + def except(other: DataFrame): DataFrame /** * Returns a new [[DataFrame]] by sampling a fraction of rows. @@ -360,7 +387,7 @@ trait DataFrame extends DataFrameSpecificApi with RDDApi[Row] { * @param fraction Fraction of rows to generate. * @param seed Seed for sampling. */ - override def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame + def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame /** * Returns a new [[DataFrame]] by sampling a fraction of rows, using a random seed. @@ -368,24 +395,26 @@ trait DataFrame extends DataFrameSpecificApi with RDDApi[Row] { * @param withReplacement Sample with replacement or not. * @param fraction Fraction of rows to generate. */ - override def sample(withReplacement: Boolean, fraction: Double): DataFrame + def sample(withReplacement: Boolean, fraction: Double): DataFrame = { + sample(withReplacement, fraction, Utils.random.nextLong) + } ///////////////////////////////////////////////////////////////////////////// /** * Returns a new [[DataFrame]] by adding a column. */ - override def addColumn(colName: String, col: Column): DataFrame + def addColumn(colName: String, col: Column): DataFrame /** * Returns the first `n` rows. */ - override def head(n: Int): Array[Row] + def head(n: Int): Array[Row] /** * Returns the first row. */ - override def head(): Row + def head(): Row /** * Returns the first row. Alias for head(). @@ -455,7 +484,17 @@ trait DataFrame extends DataFrameSpecificApi with RDDApi[Row] { /** * Returns the content of the [[DataFrame]] as an [[RDD]] of [[Row]]s. */ - override def rdd: RDD[Row] + def rdd: RDD[Row] + + /** + * Returns the content of the [[DataFrame]] as a [[JavaRDD]] of [[Row]]s. + */ + def toJavaRDD: JavaRDD[Row] = rdd.toJavaRDD() + + /** + * Returns the content of the [[DataFrame]] as a [[JavaRDD]] of [[Row]]s. + */ + def javaRDD: JavaRDD[Row] = toJavaRDD /** * Registers this RDD as a temporary table using the given name. The lifetime of this temporary @@ -463,14 +502,14 @@ trait DataFrame extends DataFrameSpecificApi with RDDApi[Row] { * * @group schema */ - override def registerTempTable(tableName: String): Unit + def registerTempTable(tableName: String): Unit /** * Saves the contents of this [[DataFrame]] as a parquet file, preserving the schema. * Files that are written out using this method can be read back in as a [[DataFrame]] * using the `parquetFile` function in [[SQLContext]]. */ - override def saveAsParquetFile(path: String): Unit + def saveAsParquetFile(path: String): Unit /** * :: Experimental :: @@ -483,19 +522,74 @@ trait DataFrame extends DataFrameSpecificApi with RDDApi[Row] { * be the target of an `insertInto`. */ @Experimental - override def saveAsTable(tableName: String): Unit + def saveAsTable(tableName: String): Unit + + /** + * :: Experimental :: + * Creates a table from the the contents of this DataFrame based on a given data source and + * a set of options. This will fail if the table already exists. + * + * Note that this currently only works with DataFrames that are created from a HiveContext as + * there is no notion of a persisted catalog in a standard SQL context. Instead you can write + * an RDD out to a parquet file, and then register that file as a table. This "table" can then + * be the target of an `insertInto`. + */ + @Experimental + def saveAsTable( + tableName: String, + dataSourceName: String, + option: (String, String), + options: (String, String)*): Unit + + /** + * :: Experimental :: + * Creates a table from the the contents of this DataFrame based on a given data source and + * a set of options. This will fail if the table already exists. + * + * Note that this currently only works with DataFrames that are created from a HiveContext as + * there is no notion of a persisted catalog in a standard SQL context. Instead you can write + * an RDD out to a parquet file, and then register that file as a table. This "table" can then + * be the target of an `insertInto`. + */ + @Experimental + def saveAsTable( + tableName: String, + dataSourceName: String, + options: java.util.Map[String, String]): Unit + + @Experimental + def save(path: String): Unit + + @Experimental + def save( + dataSourceName: String, + option: (String, String), + options: (String, String)*): Unit + + @Experimental + def save( + dataSourceName: String, + options: java.util.Map[String, String]): Unit /** * :: Experimental :: * Adds the rows from this RDD to the specified table, optionally overwriting the existing data. */ @Experimental - override def insertInto(tableName: String, overwrite: Boolean): Unit + def insertInto(tableName: String, overwrite: Boolean): Unit + + /** + * :: Experimental :: + * Adds the rows from this RDD to the specified table. + * Throws an exception if the table already exists. + */ + @Experimental + def insertInto(tableName: String): Unit = insertInto(tableName, overwrite = false) /** * Returns the content of the [[DataFrame]] as a RDD of JSON strings. */ - override def toJSON: RDD[String] + def toJSON: RDD[String] //////////////////////////////////////////////////////////////////////////// // for Python API diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala index f8fcc25569482..0b0623dc1fe75 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql -import java.util.{List => JList} - import scala.language.implicitConversions import scala.reflect.ClassTag import scala.collection.JavaConversions._ @@ -36,18 +34,22 @@ import org.apache.spark.sql.catalyst.plans.{JoinType, Inner} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.{LogicalRDD, EvaluatePython} import org.apache.spark.sql.json.JsonRDD +import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsLogicalPlan} import org.apache.spark.sql.types.{NumericType, StructType} -import org.apache.spark.util.Utils /** - * See [[DataFrame]] for documentation. + * Internal implementation of [[DataFrame]]. Users of the API should use [[DataFrame]] directly. */ private[sql] class DataFrameImpl protected[sql]( override val sqlContext: SQLContext, val queryExecution: SQLContext#QueryExecution) extends DataFrame { + /** + * A constructor that automatically analyzes the logical plan. This reports error eagerly + * as the [[DataFrame]] is constructed. + */ def this(sqlContext: SQLContext, logicalPlan: LogicalPlan) = { this(sqlContext, { val qe = sqlContext.executePlan(logicalPlan) @@ -145,7 +147,7 @@ private[sql] class DataFrameImpl protected[sql]( sort(sortExpr, sortExprs :_*) } - override def apply(colName: String): Column = colName match { + override def col(colName: String): Column = colName match { case "*" => Column(ResolvedStar(schema.fieldNames.map(resolve))) case _ => @@ -198,18 +200,6 @@ private[sql] class DataFrameImpl protected[sql]( new GroupedDataFrame(this, colNames.map(colName => resolve(colName))) } - override def agg(exprs: Map[String, String]): DataFrame = { - groupBy().agg(exprs) - } - - override def agg(exprs: java.util.Map[String, String]): DataFrame = { - agg(exprs.toMap) - } - - override def agg(expr: Column, exprs: Column*): DataFrame = { - groupBy().agg(expr, exprs :_*) - } - override def limit(n: Int): DataFrame = { Limit(Literal(n), logicalPlan) } @@ -230,10 +220,6 @@ private[sql] class DataFrameImpl protected[sql]( Sample(fraction, withReplacement, seed, logicalPlan) } - override def sample(withReplacement: Boolean, fraction: Double): DataFrame = { - sample(withReplacement, fraction, Utils.random.nextLong) - } - ///////////////////////////////////////////////////////////////////////////// override def addColumn(colName: String, col: Column): DataFrame = { @@ -303,8 +289,61 @@ private[sql] class DataFrameImpl protected[sql]( } override def saveAsTable(tableName: String): Unit = { - sqlContext.executePlan( - CreateTableAsSelect(None, tableName, logicalPlan, allowExisting = false)).toRdd + val dataSourceName = sqlContext.conf.defaultDataSourceName + val cmd = + CreateTableUsingAsLogicalPlan( + tableName, + dataSourceName, + temporary = false, + Map.empty, + allowExisting = false, + logicalPlan) + + sqlContext.executePlan(cmd).toRdd + } + + override def saveAsTable( + tableName: String, + dataSourceName: String, + option: (String, String), + options: (String, String)*): Unit = { + val cmd = + CreateTableUsingAsLogicalPlan( + tableName, + dataSourceName, + temporary = false, + (option +: options).toMap, + allowExisting = false, + logicalPlan) + + sqlContext.executePlan(cmd).toRdd + } + + override def saveAsTable( + tableName: String, + dataSourceName: String, + options: java.util.Map[String, String]): Unit = { + val opts = options.toSeq + saveAsTable(tableName, dataSourceName, opts.head, opts.tail:_*) + } + + override def save(path: String): Unit = { + val dataSourceName = sqlContext.conf.defaultDataSourceName + save(dataSourceName, ("path" -> path)) + } + + override def save( + dataSourceName: String, + option: (String, String), + options: (String, String)*): Unit = { + ResolvedDataSource(sqlContext, dataSourceName, (option +: options).toMap, this) + } + + override def save( + dataSourceName: String, + options: java.util.Map[String, String]): Unit = { + val opts = options.toSeq + save(dataSourceName, opts.head, opts.tail:_*) } override def insertInto(tableName: String, overwrite: Boolean): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala index b4279a32ffa21..71365c776d559 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala @@ -36,21 +36,6 @@ object Dsl { /** An implicit conversion that turns a Scala `Symbol` into a [[Column]]. */ implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name) - // /** - // * An implicit conversion that turns a RDD of product into a [[DataFrame]]. - // * - // * This method requires an implicit SQLContext in scope. For example: - // * {{{ - // * implicit val sqlContext: SQLContext = ... - // * val rdd: RDD[(Int, String)] = ... - // * rdd.toDataFrame // triggers the implicit here - // * }}} - // */ - // implicit def rddToDataFrame[A <: Product: TypeTag](rdd: RDD[A])(implicit context: SQLContext) - // : DataFrame = { - // context.createDataFrame(rdd) - // } - /** Converts $"col name" into an [[Column]]. */ implicit class StringToColumn(val sc: StringContext) extends AnyVal { def $(args: Any*): ColumnName = { @@ -72,10 +57,16 @@ object Dsl { /** * Creates a [[Column]] of literal value. + * + * The passed in object is returned directly if it is already a [[Column]]. + * If the object is a Scala Symbol, it is converted into a [[Column]] also. + * Otherwise, a new [[Column]] is created to represent the literal value. */ def lit(literal: Any): Column = { - if (literal.isInstanceOf[Symbol]) { - return new ColumnName(literal.asInstanceOf[Symbol].name) + literal match { + case c: Column => return c + case s: Symbol => return new ColumnName(literal.asInstanceOf[Symbol].name) + case _ => // continue } val literalExpr = literal match { @@ -100,27 +91,82 @@ object Dsl { Column(literalExpr) } + ////////////////////////////////////////////////////////////////////////////////////////////// + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** Aggregate function: returns the sum of all values in the expression. */ def sum(e: Column): Column = Sum(e.expr) + + /** Aggregate function: returns the sum of distinct values in the expression. */ def sumDistinct(e: Column): Column = SumDistinct(e.expr) + + /** Aggregate function: returns the number of items in a group. */ def count(e: Column): Column = Count(e.expr) + /** Aggregate function: returns the number of distinct items in a group. */ @scala.annotation.varargs def countDistinct(expr: Column, exprs: Column*): Column = CountDistinct((expr +: exprs).map(_.expr)) + /** Aggregate function: returns the approximate number of distinct items in a group. */ def approxCountDistinct(e: Column): Column = ApproxCountDistinct(e.expr) - def approxCountDistinct(e: Column, rsd: Double): Column = - ApproxCountDistinct(e.expr, rsd) + /** Aggregate function: returns the approximate number of distinct items in a group. */ + def approxCountDistinct(e: Column, rsd: Double): Column = ApproxCountDistinct(e.expr, rsd) + + /** Aggregate function: returns the average of the values in a group. */ def avg(e: Column): Column = Average(e.expr) + + /** Aggregate function: returns the first value in a group. */ def first(e: Column): Column = First(e.expr) + + /** Aggregate function: returns the last value in a group. */ def last(e: Column): Column = Last(e.expr) + + /** Aggregate function: returns the minimum value of the expression in a group. */ def min(e: Column): Column = Min(e.expr) + + /** Aggregate function: returns the maximum value of the expression in a group. */ def max(e: Column): Column = Max(e.expr) + ////////////////////////////////////////////////////////////////////////////////////////////// + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Unary minus, i.e. negate the expression. + * {{{ + * // Select the amount column and negates all values. + * // Scala: + * df.select( -df("amount") ) + * + * // Java: + * df.select( negate(df.col("amount")) ); + * }}} + */ + def negate(e: Column): Column = -e + + /** + * Inversion of boolean expression, i.e. NOT. + * {{ + * // Scala: select rows that are not active (isActive === false) + * df.filter( !df("isActive") ) + * + * // Java: + * df.filter( not(df.col("isActive")) ); + * }} + */ + def not(e: Column): Column = !e + + /** Converts a string expression to upper case. */ def upper(e: Column): Column = Upper(e.expr) + + /** Converts a string exprsesion to lower case. */ def lower(e: Column): Column = Lower(e.expr) + + /** Computes the square root of the specified float value. */ def sqrt(e: Column): Column = Sqrt(e.expr) + + /** Computes the absolutle value. */ def abs(e: Column): Column = Abs(e.expr) /** @@ -131,6 +177,9 @@ object Dsl { cols.toList.toSeq } + ////////////////////////////////////////////////////////////////////////////////////////////// + ////////////////////////////////////////////////////////////////////////////////////////////// + // scalastyle:off /* Use the following code to generate: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala index d3acd41bbf3eb..7963cb03126ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql -import java.util.{List => JList} - import scala.language.implicitConversions import scala.collection.JavaConversions._ @@ -30,8 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Aggregate /** * A set of methods for aggregations on a [[DataFrame]], created by [[DataFrame.groupBy]]. */ -class GroupedDataFrame protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expression]) - extends GroupedDataFrameApi { +class GroupedDataFrame protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expression]) { private[this] implicit def toDataFrame(aggExprs: Seq[NamedExpression]): DataFrame = { val namedGroupingExprs = groupingExprs.map { @@ -60,19 +57,36 @@ class GroupedDataFrame protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expr } /** - * Compute aggregates by specifying a map from column name to aggregate methods. The resulting - * [[DataFrame]] will also contain the grouping columns. + * (Scala-specific) Compute aggregates by specifying a map from column name to + * aggregate methods. The resulting [[DataFrame]] will also contain the grouping columns. + * + * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. + * {{{ + * // Selects the age of the oldest employee and the aggregate expense for each department + * df.groupBy("department").agg( + * "age" -> "max", + * "expense" -> "sum" + * ) + * }}} + */ + def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = { + agg((aggExpr +: aggExprs).toMap) + } + + /** + * (Scala-specific) Compute aggregates by specifying a map from column name to + * aggregate methods. The resulting [[DataFrame]] will also contain the grouping columns. * * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. * {{{ * // Selects the age of the oldest employee and the aggregate expense for each department * df.groupBy("department").agg(Map( - * "age" -> "max" - * "sum" -> "expense" + * "age" -> "max", + * "expense" -> "sum" * )) * }}} */ - override def agg(exprs: Map[String, String]): DataFrame = { + def agg(exprs: Map[String, String]): DataFrame = { exprs.map { case (colName, expr) => val a = strToExpr(expr)(df(colName).expr) Alias(a, a.toString)() @@ -80,16 +94,17 @@ class GroupedDataFrame protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expr } /** - * Compute aggregates by specifying a map from column name to aggregate methods. The resulting - * [[DataFrame]] will also contain the grouping columns. + * (Java-specific) Compute aggregates by specifying a map from column name to + * aggregate methods. The resulting [[DataFrame]] will also contain the grouping columns. * * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. * {{{ * // Selects the age of the oldest employee and the aggregate expense for each department - * df.groupBy("department").agg(Map( - * "age" -> "max" - * "sum" -> "expense" - * )) + * import com.google.common.collect.ImmutableMap; + * df.groupBy("department").agg(ImmutableMap.builder() + * .put("age", "max") + * .put("expense", "sum") + * .build()); * }}} */ def agg(exprs: java.util.Map[String, String]): DataFrame = { @@ -104,12 +119,18 @@ class GroupedDataFrame protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expr * * {{{ * // Selects the age of the oldest employee and the aggregate expense for each department + * + * // Scala: * import org.apache.spark.sql.dsl._ * df.groupBy("department").agg($"department", max($"age"), sum($"expense")) + * + * // Java: + * import static org.apache.spark.sql.Dsl.*; + * df.groupBy("department").agg(col("department"), max(col("age")), sum(col("expense"))); * }}} */ @scala.annotation.varargs - override def agg(expr: Column, exprs: Column*): DataFrame = { + def agg(expr: Column, exprs: Column*): DataFrame = { val aggExprs = (expr +: exprs).map(_.expr).map { case expr: NamedExpression => expr case expr: Expression => Alias(expr, expr.toString)() @@ -121,35 +142,35 @@ class GroupedDataFrame protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expr * Count the number of rows for each group. * The resulting [[DataFrame]] will also contain the grouping columns. */ - override def count(): DataFrame = Seq(Alias(Count(LiteralExpr(1)), "count")()) + def count(): DataFrame = Seq(Alias(Count(LiteralExpr(1)), "count")()) /** * Compute the average value for each numeric columns for each group. This is an alias for `avg`. * The resulting [[DataFrame]] will also contain the grouping columns. */ - override def mean(): DataFrame = aggregateNumericColumns(Average) + def mean(): DataFrame = aggregateNumericColumns(Average) /** * Compute the max value for each numeric columns for each group. * The resulting [[DataFrame]] will also contain the grouping columns. */ - override def max(): DataFrame = aggregateNumericColumns(Max) + def max(): DataFrame = aggregateNumericColumns(Max) /** * Compute the mean value for each numeric columns for each group. * The resulting [[DataFrame]] will also contain the grouping columns. */ - override def avg(): DataFrame = aggregateNumericColumns(Average) + def avg(): DataFrame = aggregateNumericColumns(Average) /** * Compute the min value for each numeric column for each group. * The resulting [[DataFrame]] will also contain the grouping columns. */ - override def min(): DataFrame = aggregateNumericColumns(Min) + def min(): DataFrame = aggregateNumericColumns(Min) /** * Compute the sum for each numeric columns for each group. * The resulting [[DataFrame]] will also contain the grouping columns. */ - override def sum(): DataFrame = aggregateNumericColumns(Sum) + def sum(): DataFrame = aggregateNumericColumns(Sum) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala index 2f8c695d5654b..ba5c7355b4b70 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala @@ -72,7 +72,7 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten override def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame = err() - override def apply(colName: String): Column = err() + override def col(colName: String): Column = err() override def apply(projection: Product): DataFrame = err() @@ -90,12 +90,6 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten override def groupBy(col1: String, cols: String*): GroupedDataFrame = err() - override def agg(exprs: Map[String, String]): DataFrame = err() - - override def agg(exprs: java.util.Map[String, String]): DataFrame = err() - - override def agg(expr: Column, exprs: Column*): DataFrame = err() - override def limit(n: Int): DataFrame = err() override def unionAll(other: DataFrame): DataFrame = err() @@ -106,8 +100,6 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten override def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = err() - override def sample(withReplacement: Boolean, fraction: Double): DataFrame = err() - ///////////////////////////////////////////////////////////////////////////// override def addColumn(colName: String, col: Column): DataFrame = err() @@ -152,6 +144,28 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten override def saveAsTable(tableName: String): Unit = err() + override def saveAsTable( + tableName: String, + dataSourceName: String, + option: (String, String), + options: (String, String)*): Unit = err() + + override def saveAsTable( + tableName: String, + dataSourceName: String, + options: java.util.Map[String, String]): Unit = err() + + override def save(path: String): Unit = err() + + override def save( + dataSourceName: String, + option: (String, String), + options: (String, String)*): Unit = err() + + override def save( + dataSourceName: String, + options: java.util.Map[String, String]): Unit = err() + override def insertInto(tableName: String, overwrite: Boolean): Unit = err() override def toJSON: RDD[String] = err() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RDDApi.scala b/sql/core/src/main/scala/org/apache/spark/sql/RDDApi.scala new file mode 100644 index 0000000000000..38e6382f171d5 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/RDDApi.scala @@ -0,0 +1,63 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import scala.reflect.ClassTag + +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel + + +/** + * An internal interface defining the RDD-like methods for [[DataFrame]]. + * Please use [[DataFrame]] directly, and do NOT use this. + */ +private[sql] trait RDDApi[T] { + + def cache(): this.type = persist() + + def persist(): this.type + + def persist(newLevel: StorageLevel): this.type + + def unpersist(): this.type = unpersist(blocking = false) + + def unpersist(blocking: Boolean): this.type + + def map[R: ClassTag](f: T => R): RDD[R] + + def flatMap[R: ClassTag](f: T => TraversableOnce[R]): RDD[R] + + def mapPartitions[R: ClassTag](f: Iterator[T] => Iterator[R]): RDD[R] + + def foreach(f: T => Unit): Unit + + def foreachPartition(f: Iterator[T] => Unit): Unit + + def take(n: Int): Array[T] + + def collect(): Array[T] + + def collectAsList(): java.util.List[T] + + def count(): Long + + def first(): T + + def repartition(numPartitions: Int): DataFrame +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 243dc997078f3..7fe17944a734e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -33,6 +33,7 @@ private[spark] object SQLConf { val DIALECT = "spark.sql.dialect" val PARQUET_BINARY_AS_STRING = "spark.sql.parquet.binaryAsString" + val PARQUET_INT96_AS_TIMESTAMP = "spark.sql.parquet.int96AsTimestamp" val PARQUET_CACHE_METADATA = "spark.sql.parquet.cacheMetadata" val PARQUET_COMPRESSION = "spark.sql.parquet.compression.codec" val PARQUET_FILTER_PUSHDOWN_ENABLED = "spark.sql.parquet.filterPushdown" @@ -47,6 +48,9 @@ private[spark] object SQLConf { // This is only used for the thriftserver val THRIFTSERVER_POOL = "spark.sql.thriftserver.scheduler.pool" + // This is used to set the default data source + val DEFAULT_DATA_SOURCE_NAME = "spark.sql.default.datasource" + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -140,6 +144,12 @@ private[sql] class SQLConf extends Serializable { private[spark] def isParquetBinaryAsString: Boolean = getConf(PARQUET_BINARY_AS_STRING, "false").toBoolean + /** + * When set to true, we always treat INT96Values in Parquet files as timestamp. + */ + private[spark] def isParquetINT96AsTimestamp: Boolean = + getConf(PARQUET_INT96_AS_TIMESTAMP, "true").toBoolean + /** * When set to true, partition pruning for in-memory columnar tables is enabled. */ @@ -155,6 +165,9 @@ private[sql] class SQLConf extends Serializable { private[spark] def broadcastTimeout: Int = getConf(BROADCAST_TIMEOUT, (5 * 60).toString).toInt + private[spark] def defaultDataSourceName: String = + getConf(DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.parquet") + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index f4692b3ff59d3..2697e780c05c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -21,6 +21,7 @@ import java.beans.Introspector import java.util.Properties import scala.collection.immutable +import scala.collection.JavaConversions._ import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag @@ -37,7 +38,7 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.execution._ import org.apache.spark.sql.json._ import org.apache.spark.sql.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} -import org.apache.spark.sql.sources.{LogicalRelation, BaseRelation, DDLParser, DataSourceStrategy} +import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -86,7 +87,7 @@ class SQLContext(@transient val sparkContext: SparkContext) protected[sql] lazy val catalog: Catalog = new SimpleCatalog(true) @transient - protected[sql] lazy val functionRegistry: FunctionRegistry = new SimpleFunctionRegistry + protected[sql] lazy val functionRegistry: FunctionRegistry = new SimpleFunctionRegistry(true) @transient protected[sql] lazy val analyzer: Analyzer = @@ -335,6 +336,29 @@ class SQLContext(@transient val sparkContext: SparkContext) applySchema(rowRDD, appliedSchema) } + @Experimental + def load(path: String): DataFrame = { + val dataSourceName = conf.defaultDataSourceName + load(dataSourceName, ("path", path)) + } + + @Experimental + def load( + dataSourceName: String, + option: (String, String), + options: (String, String)*): DataFrame = { + val resolved = ResolvedDataSource(this, None, dataSourceName, (option +: options).toMap) + DataFrame(this, LogicalRelation(resolved.relation)) + } + + @Experimental + def load( + dataSourceName: String, + options: java.util.Map[String, String]): DataFrame = { + val opts = options.toSeq + load(dataSourceName, opts.head, opts.tail:_*) + } + /** * :: Experimental :: * Construct an RDD representing the database table accessible via JDBC URL diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api.scala b/sql/core/src/main/scala/org/apache/spark/sql/api.scala deleted file mode 100644 index eb0eb3f32560c..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/api.scala +++ /dev/null @@ -1,299 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql - -import scala.reflect.ClassTag - -import org.apache.spark.annotation.Experimental -import org.apache.spark.api.java.JavaRDD -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.types.{DataType, StructType} -import org.apache.spark.storage.StorageLevel - - -/** - * An internal interface defining the RDD-like methods for [[DataFrame]]. - * Please use [[DataFrame]] directly, and do NOT use this. - */ -private[sql] trait RDDApi[T] { - - def cache(): this.type = persist() - - def persist(): this.type - - def persist(newLevel: StorageLevel): this.type - - def unpersist(): this.type = unpersist(blocking = false) - - def unpersist(blocking: Boolean): this.type - - def map[R: ClassTag](f: T => R): RDD[R] - - def flatMap[R: ClassTag](f: T => TraversableOnce[R]): RDD[R] - - def mapPartitions[R: ClassTag](f: Iterator[T] => Iterator[R]): RDD[R] - - def foreach(f: T => Unit): Unit - - def foreachPartition(f: Iterator[T] => Unit): Unit - - def take(n: Int): Array[T] - - def collect(): Array[T] - - def collectAsList(): java.util.List[T] - - def count(): Long - - def first(): T - - def repartition(numPartitions: Int): DataFrame -} - - -/** - * An internal interface defining data frame related methods in [[DataFrame]]. - * Please use [[DataFrame]] directly, and do NOT use this. - */ -private[sql] trait DataFrameSpecificApi { - - def schema: StructType - - def printSchema(): Unit - - def dtypes: Array[(String, String)] - - def columns: Array[String] - - def head(): Row - - def head(n: Int): Array[Row] - - ///////////////////////////////////////////////////////////////////////////// - // Relational operators - ///////////////////////////////////////////////////////////////////////////// - def apply(colName: String): Column - - def apply(projection: Product): DataFrame - - @scala.annotation.varargs - def select(cols: Column*): DataFrame - - @scala.annotation.varargs - def select(col: String, cols: String*): DataFrame - - def apply(condition: Column): DataFrame - - def as(name: String): DataFrame - - def filter(condition: Column): DataFrame - - def where(condition: Column): DataFrame - - @scala.annotation.varargs - def groupBy(cols: Column*): GroupedDataFrame - - @scala.annotation.varargs - def groupBy(col1: String, cols: String*): GroupedDataFrame - - def agg(exprs: Map[String, String]): DataFrame - - def agg(exprs: java.util.Map[String, String]): DataFrame - - @scala.annotation.varargs - def agg(expr: Column, exprs: Column*): DataFrame - - @scala.annotation.varargs - def sort(sortExpr: Column, sortExprs: Column*): DataFrame - - @scala.annotation.varargs - def sort(sortCol: String, sortCols: String*): DataFrame - - @scala.annotation.varargs - def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame - - @scala.annotation.varargs - def orderBy(sortCol: String, sortCols: String*): DataFrame - - def join(right: DataFrame): DataFrame - - def join(right: DataFrame, joinExprs: Column): DataFrame - - def join(right: DataFrame, joinExprs: Column, joinType: String): DataFrame - - def limit(n: Int): DataFrame - - def unionAll(other: DataFrame): DataFrame - - def intersect(other: DataFrame): DataFrame - - def except(other: DataFrame): DataFrame - - def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame - - def sample(withReplacement: Boolean, fraction: Double): DataFrame - - ///////////////////////////////////////////////////////////////////////////// - // Column mutation - ///////////////////////////////////////////////////////////////////////////// - def addColumn(colName: String, col: Column): DataFrame - - ///////////////////////////////////////////////////////////////////////////// - // I/O and interaction with other frameworks - ///////////////////////////////////////////////////////////////////////////// - - def rdd: RDD[Row] - - def toJavaRDD: JavaRDD[Row] = rdd.toJavaRDD() - - def toJSON: RDD[String] - - def registerTempTable(tableName: String): Unit - - def saveAsParquetFile(path: String): Unit - - @Experimental - def saveAsTable(tableName: String): Unit - - @Experimental - def insertInto(tableName: String, overwrite: Boolean): Unit - - @Experimental - def insertInto(tableName: String): Unit = insertInto(tableName, overwrite = false) - - ///////////////////////////////////////////////////////////////////////////// - // Stat functions - ///////////////////////////////////////////////////////////////////////////// -// def describe(): Unit -// -// def mean(): Unit -// -// def max(): Unit -// -// def min(): Unit -} - - -/** - * An internal interface defining expression APIs for [[DataFrame]]. - * Please use [[DataFrame]] and [[Column]] directly, and do NOT use this. - */ -private[sql] trait ExpressionApi { - - def isComputable: Boolean - - def unary_- : Column - def unary_! : Column - def unary_~ : Column - - def + (other: Column): Column - def + (other: Any): Column - def - (other: Column): Column - def - (other: Any): Column - def * (other: Column): Column - def * (other: Any): Column - def / (other: Column): Column - def / (other: Any): Column - def % (other: Column): Column - def % (other: Any): Column - def & (other: Column): Column - def & (other: Any): Column - def | (other: Column): Column - def | (other: Any): Column - def ^ (other: Column): Column - def ^ (other: Any): Column - - def && (other: Column): Column - def && (other: Boolean): Column - def || (other: Column): Column - def || (other: Boolean): Column - - def < (other: Column): Column - def < (other: Any): Column - def <= (other: Column): Column - def <= (other: Any): Column - def > (other: Column): Column - def > (other: Any): Column - def >= (other: Column): Column - def >= (other: Any): Column - def === (other: Column): Column - def === (other: Any): Column - def equalTo(other: Column): Column - def equalTo(other: Any): Column - def <=> (other: Column): Column - def <=> (other: Any): Column - def !== (other: Column): Column - def !== (other: Any): Column - - @scala.annotation.varargs - def in(list: Column*): Column - - def like(other: String): Column - def rlike(other: String): Column - - def contains(other: Column): Column - def contains(other: Any): Column - def startsWith(other: Column): Column - def startsWith(other: String): Column - def endsWith(other: Column): Column - def endsWith(other: String): Column - - def substr(startPos: Column, len: Column): Column - def substr(startPos: Int, len: Int): Column - - def isNull: Column - def isNotNull: Column - - def getItem(ordinal: Int): Column - def getField(fieldName: String): Column - - def cast(to: DataType): Column - def cast(to: String): Column - - def asc: Column - def desc: Column - - def as(alias: String): Column -} - - -/** - * An internal interface defining aggregation APIs for [[DataFrame]]. - * Please use [[DataFrame]] and [[GroupedDataFrame]] directly, and do NOT use this. - */ -private[sql] trait GroupedDataFrameApi { - - def agg(exprs: Map[String, String]): DataFrame - - @scala.annotation.varargs - def agg(expr: Column, exprs: Column*): DataFrame - - def avg(): DataFrame - - def mean(): DataFrame - - def min(): DataFrame - - def max(): DataFrame - - def sum(): DataFrame - - def count(): DataFrame - - // TODO: Add var, std -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index 391b3dae5c8ce..cad0667b46435 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.columnar -import java.sql.{Date, Timestamp} +import java.sql.Timestamp import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.{AttributeMap, Attribute, AttributeReference} @@ -215,22 +215,7 @@ private[sql] class StringColumnStats extends ColumnStats { def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) } -private[sql] class DateColumnStats extends ColumnStats { - protected var upper: Date = null - protected var lower: Date = null - - override def gatherStats(row: Row, ordinal: Int) { - super.gatherStats(row, ordinal) - if (!row.isNullAt(ordinal)) { - val value = row(ordinal).asInstanceOf[Date] - if (upper == null || value.compareTo(upper) > 0) upper = value - if (lower == null || value.compareTo(lower) < 0) lower = value - sizeInBytes += DATE.defaultSize - } - } - - def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) -} +private[sql] class DateColumnStats extends IntColumnStats private[sql] class TimestampColumnStats extends ColumnStats { protected var upper: Timestamp = null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index fcf2faa0914c0..db5bc0de363c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -335,21 +335,20 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) { } } -private[sql] object DATE extends NativeColumnType(DateType, 8, 8) { +private[sql] object DATE extends NativeColumnType(DateType, 8, 4) { override def extract(buffer: ByteBuffer) = { - val date = new Date(buffer.getLong()) - date + buffer.getInt } - override def append(v: Date, buffer: ByteBuffer): Unit = { - buffer.putLong(v.getTime) + override def append(v: Int, buffer: ByteBuffer): Unit = { + buffer.putInt(v) } override def getField(row: Row, ordinal: Int) = { - row(ordinal).asInstanceOf[Date] + row(ordinal).asInstanceOf[Int] } - override def setField(row: MutableRow, ordinal: Int, value: Date): Unit = { + def setField(row: MutableRow, ordinal: Int, value: Int): Unit = { row(ordinal) = value } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 0cc9d049c9640..ff0609d4b3b72 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} import org.apache.spark.sql.parquet._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.sources.{CreateTempTableUsing, CreateTableUsing} +import org.apache.spark.sql.sources._ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { @@ -314,12 +314,33 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object DDLStrategy extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case CreateTableUsing(tableName, userSpecifiedSchema, provider, true, options) => + case CreateTableUsing(tableName, userSpecifiedSchema, provider, true, opts, false) => ExecutedCommand( - CreateTempTableUsing(tableName, userSpecifiedSchema, provider, options)) :: Nil - - case CreateTableUsing(tableName, userSpecifiedSchema, provider, false, options) => + CreateTempTableUsing( + tableName, userSpecifiedSchema, provider, opts)) :: Nil + case c: CreateTableUsing if !c.temporary => + sys.error("Tables created with SQLContext must be TEMPORARY. Use a HiveContext instead.") + case c: CreateTableUsing if c.temporary && c.allowExisting => + sys.error("allowExisting should be set to false when creating a temporary table.") + + case CreateTableUsingAsSelect(tableName, provider, true, opts, false, query) => + val logicalPlan = sqlContext.parseSql(query) + val cmd = + CreateTempTableUsingAsSelect(tableName, provider, opts, logicalPlan) + ExecutedCommand(cmd) :: Nil + case c: CreateTableUsingAsSelect if !c.temporary => + sys.error("Tables created with SQLContext must be TEMPORARY. Use a HiveContext instead.") + case c: CreateTableUsingAsSelect if c.temporary && c.allowExisting => + sys.error("allowExisting should be set to false when creating a temporary table.") + + case CreateTableUsingAsLogicalPlan(tableName, provider, true, opts, false, query) => + val cmd = + CreateTempTableUsingAsSelect(tableName, provider, opts, query) + ExecutedCommand(cmd) :: Nil + case c: CreateTableUsingAsLogicalPlan if !c.temporary => sys.error("Tables created with SQLContext must be TEMPORARY. Use a HiveContext instead.") + case c: CreateTableUsingAsLogicalPlan if c.temporary && c.allowExisting => + sys.error("allowExisting should be set to false when creating a temporary table.") case _ => Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index b85021acc9d4c..3a2f8d75dac5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -135,6 +135,8 @@ object EvaluatePython { case (ud, udt: UserDefinedType[_]) => toJava(udt.serialize(ud), udt.sqlType) + case (date: Int, DateType) => DateUtils.toJavaDate(date) + // Pyrolite can handle Timestamp and Decimal case (other, _) => other } @@ -171,7 +173,7 @@ object EvaluatePython { }): Row case (c: java.util.Calendar, DateType) => - new java.sql.Date(c.getTime().getTime()) + DateUtils.fromJavaDate(new java.sql.Date(c.getTime().getTime())) case (c: java.util.Calendar, TimestampType) => new java.sql.Timestamp(c.getTime().getTime()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala index 1af96c28d5fdf..8372decbf8aa1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -17,21 +17,26 @@ package org.apache.spark.sql.json -import org.apache.spark.sql.SQLContext +import java.io.IOException + +import org.apache.hadoop.fs.Path +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row, SQLContext} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType -private[sql] class DefaultSource extends RelationProvider with SchemaRelationProvider { +private[sql] class DefaultSource + extends RelationProvider with SchemaRelationProvider with CreateableRelationProvider { /** Returns a new base relation with the parameters. */ override def createRelation( sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = { - val fileName = parameters.getOrElse("path", sys.error("Option 'path' not specified")) + val path = parameters.getOrElse("path", sys.error("Option 'path' not specified")) val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) - JSONRelation(fileName, samplingRatio, None)(sqlContext) + JSONRelation(path, samplingRatio, None)(sqlContext) } /** Returns a new base relation with the given schema and parameters. */ @@ -39,21 +44,37 @@ private[sql] class DefaultSource extends RelationProvider with SchemaRelationPro sqlContext: SQLContext, parameters: Map[String, String], schema: StructType): BaseRelation = { - val fileName = parameters.getOrElse("path", sys.error("Option 'path' not specified")) + val path = parameters.getOrElse("path", sys.error("Option 'path' not specified")) val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) - JSONRelation(fileName, samplingRatio, Some(schema))(sqlContext) + JSONRelation(path, samplingRatio, Some(schema))(sqlContext) + } + + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String], + data: DataFrame): BaseRelation = { + val path = parameters.getOrElse("path", sys.error("Option 'path' not specified")) + val filesystemPath = new Path(path) + val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + if (fs.exists(filesystemPath)) { + sys.error(s"path $path already exists.") + } + data.toJSON.saveAsTextFile(path) + + createRelation(sqlContext, parameters, data.schema) } } private[sql] case class JSONRelation( - fileName: String, + path: String, samplingRatio: Double, userSpecifiedSchema: Option[StructType])( @transient val sqlContext: SQLContext) - extends TableScan { + extends TableScan with InsertableRelation { - private def baseRDD = sqlContext.sparkContext.textFile(fileName) + // TODO: Support partitioned JSON relation. + private def baseRDD = sqlContext.sparkContext.textFile(path) override val schema = userSpecifiedSchema.getOrElse( JsonRDD.nullTypeToStringType( @@ -64,4 +85,24 @@ private[sql] case class JSONRelation( override def buildScan() = JsonRDD.jsonStringToRow(baseRDD, schema, sqlContext.conf.columnNameOfCorruptRecord) + + override def insert(data: DataFrame, overwrite: Boolean) = { + val filesystemPath = new Path(path) + val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + + if (overwrite) { + try { + fs.delete(filesystemPath, true) + } catch { + case e: IOException => + throw new IOException( + s"Unable to clear output directory ${filesystemPath.toString} prior" + + s" to INSERT OVERWRITE a JSON table:\n${e.toString}") + } + data.toJSON.saveAsTextFile(path) + } else { + // TODO: Support INSERT INTO + sys.error("JSON table only support INSERT OVERWRITE for now.") + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 9171939f7e8f7..33ce71b51b213 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -377,10 +377,12 @@ private[sql] object JsonRDD extends Logging { } } - private def toDate(value: Any): Date = { + private def toDate(value: Any): Int = { value match { // only support string as date - case value: java.lang.String => new Date(DataTypeConversions.stringToTime(value).getTime) + case value: java.lang.String => + DateUtils.millisToDays(DataTypeConversions.stringToTime(value).getTime) + case value: java.sql.Date => DateUtils.fromJavaDate(value) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 7c49b5220d607..02e5b015e8ec2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -38,4 +38,10 @@ package object sql { */ @DeveloperApi protected[sql] type Strategy = org.apache.spark.sql.catalyst.planning.GenericStrategy[SparkPlan] + + /** + * Type alias for [[DataFrame]]. Kept here for backward source compatibility for Scala. + */ + @deprecated("1.3.0", "use DataFrame") + type SchemaRDD = DataFrame } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index 10df8c3310092..d87ddfeabda77 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -17,8 +17,12 @@ package org.apache.spark.sql.parquet +import java.sql.Timestamp +import java.util.{TimeZone, Calendar} + import scala.collection.mutable.{Buffer, ArrayBuffer, HashMap} +import jodd.datetime.JDateTime import parquet.column.Dictionary import parquet.io.api.{PrimitiveConverter, GroupConverter, Binary, Converter} import parquet.schema.MessageType @@ -26,6 +30,7 @@ import parquet.schema.MessageType import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.parquet.CatalystConverter.FieldType import org.apache.spark.sql.types._ +import org.apache.spark.sql.parquet.timestamp.NanoTime /** * Collection of converters of Parquet types (group and primitive types) that @@ -123,6 +128,12 @@ private[sql] object CatalystConverter { parent.updateDecimal(fieldIndex, value, d) } } + case TimestampType => { + new CatalystPrimitiveConverter(parent, fieldIndex) { + override def addBinary(value: Binary): Unit = + parent.updateTimestamp(fieldIndex, value) + } + } // All other primitive types use the default converter case ctype: PrimitiveType => { // note: need the type tag here! new CatalystPrimitiveConverter(parent, fieldIndex) @@ -197,9 +208,11 @@ private[parquet] abstract class CatalystConverter extends GroupConverter { protected[parquet] def updateString(fieldIndex: Int, value: String): Unit = updateField(fieldIndex, value) - protected[parquet] def updateDecimal(fieldIndex: Int, value: Binary, ctype: DecimalType): Unit = { + protected[parquet] def updateTimestamp(fieldIndex: Int, value: Binary): Unit = + updateField(fieldIndex, readTimestamp(value)) + + protected[parquet] def updateDecimal(fieldIndex: Int, value: Binary, ctype: DecimalType): Unit = updateField(fieldIndex, readDecimal(new Decimal(), value, ctype)) - } protected[parquet] def isRootConverter: Boolean = parent == null @@ -232,6 +245,13 @@ private[parquet] abstract class CatalystConverter extends GroupConverter { unscaled = (unscaled << (64 - numBits)) >> (64 - numBits) dest.set(unscaled, precision, scale) } + + /** + * Read a Timestamp value from a Parquet Int96Value + */ + protected[parquet] def readTimestamp(value: Binary): Timestamp = { + CatalystTimestampConverter.convertToTimestamp(value) + } } /** @@ -384,6 +404,9 @@ private[parquet] class CatalystPrimitiveRowConverter( override protected[parquet] def updateString(fieldIndex: Int, value: String): Unit = current.setString(fieldIndex, value) + override protected[parquet] def updateTimestamp(fieldIndex: Int, value: Binary): Unit = + current.update(fieldIndex, readTimestamp(value)) + override protected[parquet] def updateDecimal( fieldIndex: Int, value: Binary, ctype: DecimalType): Unit = { var decimal = current(fieldIndex).asInstanceOf[Decimal] @@ -454,6 +477,73 @@ private[parquet] object CatalystArrayConverter { val INITIAL_ARRAY_SIZE = 20 } +private[parquet] object CatalystTimestampConverter { + // TODO most part of this comes from Hive-0.14 + // Hive code might have some issues, so we need to keep an eye on it. + // Also we use NanoTime and Int96Values from parquet-examples. + // We utilize jodd to convert between NanoTime and Timestamp + val parquetTsCalendar = new ThreadLocal[Calendar] + def getCalendar = { + // this is a cache for the calendar instance. + if (parquetTsCalendar.get == null) { + parquetTsCalendar.set(Calendar.getInstance(TimeZone.getTimeZone("GMT"))) + } + parquetTsCalendar.get + } + val NANOS_PER_SECOND: Long = 1000000000 + val SECONDS_PER_MINUTE: Long = 60 + val MINUTES_PER_HOUR: Long = 60 + val NANOS_PER_MILLI: Long = 1000000 + + def convertToTimestamp(value: Binary): Timestamp = { + val nt = NanoTime.fromBinary(value) + val timeOfDayNanos = nt.getTimeOfDayNanos + val julianDay = nt.getJulianDay + val jDateTime = new JDateTime(julianDay.toDouble) + val calendar = getCalendar + calendar.set(Calendar.YEAR, jDateTime.getYear) + calendar.set(Calendar.MONTH, jDateTime.getMonth - 1) + calendar.set(Calendar.DAY_OF_MONTH, jDateTime.getDay) + + // written in command style + var remainder = timeOfDayNanos + calendar.set( + Calendar.HOUR_OF_DAY, + (remainder / (NANOS_PER_SECOND * SECONDS_PER_MINUTE * MINUTES_PER_HOUR)).toInt) + remainder = remainder % (NANOS_PER_SECOND * SECONDS_PER_MINUTE * MINUTES_PER_HOUR) + calendar.set( + Calendar.MINUTE, (remainder / (NANOS_PER_SECOND * SECONDS_PER_MINUTE)).toInt) + remainder = remainder % (NANOS_PER_SECOND * SECONDS_PER_MINUTE) + calendar.set(Calendar.SECOND, (remainder / NANOS_PER_SECOND).toInt) + val nanos = remainder % NANOS_PER_SECOND + val ts = new Timestamp(calendar.getTimeInMillis) + ts.setNanos(nanos.toInt) + ts + } + + def convertFromTimestamp(ts: Timestamp): Binary = { + val calendar = getCalendar + calendar.setTime(ts) + val jDateTime = new JDateTime(calendar.get(Calendar.YEAR), + calendar.get(Calendar.MONTH) + 1, calendar.get(Calendar.DAY_OF_MONTH)) + // Hive-0.14 didn't set hour before get day number, while the day number should + // has something to do with hour, since julian day number grows at 12h GMT + // here we just follow what hive does. + val julianDay = jDateTime.getJulianDayNumber + + val hour = calendar.get(Calendar.HOUR_OF_DAY) + val minute = calendar.get(Calendar.MINUTE) + val second = calendar.get(Calendar.SECOND) + val nanos = ts.getNanos + // Hive-0.14 would use hours directly, that might be wrong, since the day starts + // from 12h in Julian. here we just follow what hive does. + val nanosOfDay = nanos + second * NANOS_PER_SECOND + + minute * NANOS_PER_SECOND * SECONDS_PER_MINUTE + + hour * NANOS_PER_SECOND * SECONDS_PER_MINUTE * MINUTES_PER_HOUR + NanoTime(julianDay, nanosOfDay).toBinary + } +} + /** * A `parquet.io.api.GroupConverter` that converts a single-element groups that * match the characteristics of an array (see diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index a54485e719dad..b0db9943a506c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -65,8 +65,8 @@ private[sql] case class ParquetRelation( ParquetTypesConverter.readSchemaFromFile( new Path(path.split(",").head), conf, - sqlContext.conf.isParquetBinaryAsString) - + sqlContext.conf.isParquetBinaryAsString, + sqlContext.conf.isParquetINT96AsTimestamp) lazy val attributeMap = AttributeMap(output.map(o => o -> o)) override def newInstance() = ParquetRelation(path, conf, sqlContext).asInstanceOf[this.type] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index fd63ad8144064..3fb1cc410521e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -83,7 +83,8 @@ private[parquet] class RowReadSupport extends ReadSupport[Row] with Logging { // TODO: Why it can be null? if (schema == null) { log.debug("falling back to Parquet read schema") - schema = ParquetTypesConverter.convertToAttributes(parquetSchema, false) + schema = ParquetTypesConverter.convertToAttributes( + parquetSchema, false, true) } log.debug(s"list of attributes that will be read: $schema") new RowRecordMaterializer(parquetSchema, schema) @@ -184,12 +185,12 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { case t @ StructType(_) => writeStruct( t, value.asInstanceOf[CatalystConverter.StructScalaType[_]]) - case _ => writePrimitive(schema.asInstanceOf[PrimitiveType], value) + case _ => writePrimitive(schema.asInstanceOf[NativeType], value) } } } - private[parquet] def writePrimitive(schema: PrimitiveType, value: Any): Unit = { + private[parquet] def writePrimitive(schema: DataType, value: Any): Unit = { if (value != null) { schema match { case StringType => writer.addBinary( @@ -202,6 +203,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { case IntegerType => writer.addInteger(value.asInstanceOf[Int]) case ShortType => writer.addInteger(value.asInstanceOf[Short]) case LongType => writer.addLong(value.asInstanceOf[Long]) + case TimestampType => writeTimestamp(value.asInstanceOf[java.sql.Timestamp]) case ByteType => writer.addInteger(value.asInstanceOf[Byte]) case DoubleType => writer.addDouble(value.asInstanceOf[Double]) case FloatType => writer.addFloat(value.asInstanceOf[Float]) @@ -307,6 +309,10 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { writer.addBinary(Binary.fromByteArray(scratchBytes, 0, numBytes)) } + private[parquet] def writeTimestamp(ts: java.sql.Timestamp): Unit = { + val binaryNanoTime = CatalystTimestampConverter.convertFromTimestamp(ts) + writer.addBinary(binaryNanoTime) + } } // Optimized for non-nested rows @@ -351,6 +357,7 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport { case DoubleType => writer.addDouble(record.getDouble(index)) case FloatType => writer.addFloat(record.getFloat(index)) case BooleanType => writer.addBoolean(record.getBoolean(index)) + case TimestampType => writeTimestamp(record(index).asInstanceOf[java.sql.Timestamp]) case d: DecimalType => if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) { sys.error(s"Unsupported datatype $d, cannot write to consumer") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala index d5993656e0225..e4a10aa2ae6c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.mapreduce.Job import org.apache.spark.sql.test.TestSQLContext import parquet.example.data.{GroupWriter, Group} -import parquet.example.data.simple.SimpleGroup +import parquet.example.data.simple.{NanoTime, SimpleGroup} import parquet.hadoop.{ParquetReader, ParquetFileReader, ParquetWriter} import parquet.hadoop.api.WriteSupport import parquet.hadoop.api.WriteSupport.WriteContext @@ -63,6 +63,7 @@ private[sql] object ParquetTestData { optional int64 mylong; optional float myfloat; optional double mydouble; + optional int96 mytimestamp; }""" // field names for test assertion error messages @@ -72,7 +73,8 @@ private[sql] object ParquetTestData { "mystring:String", "mylong:Long", "myfloat:Float", - "mydouble:Double" + "mydouble:Double", + "mytimestamp:Timestamp" ) val subTestSchema = @@ -98,6 +100,7 @@ private[sql] object ParquetTestData { optional int64 myoptlong; optional float myoptfloat; optional double myoptdouble; + optional int96 mytimestamp; } """ @@ -236,6 +239,7 @@ private[sql] object ParquetTestData { record.add(3, i.toLong << 33) record.add(4, 2.5F) record.add(5, 4.5D) + record.add(6, new NanoTime(1,2)) writer.write(record) } writer.close() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index 6d8c682ccced8..f1d4ff2387709 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -54,7 +54,8 @@ private[parquet] object ParquetTypesConverter extends Logging { def toPrimitiveDataType( parquetType: ParquetPrimitiveType, - binaryAsString: Boolean): DataType = { + binaryAsString: Boolean, + int96AsTimestamp: Boolean): DataType = { val originalType = parquetType.getOriginalType val decimalInfo = parquetType.getDecimalMetadata parquetType.getPrimitiveTypeName match { @@ -66,6 +67,7 @@ private[parquet] object ParquetTypesConverter extends Logging { case ParquetPrimitiveTypeName.FLOAT => FloatType case ParquetPrimitiveTypeName.INT32 => IntegerType case ParquetPrimitiveTypeName.INT64 => LongType + case ParquetPrimitiveTypeName.INT96 if int96AsTimestamp => TimestampType case ParquetPrimitiveTypeName.INT96 => // TODO: add BigInteger type? TODO(andre) use DecimalType instead???? sys.error("Potential loss of precision: cannot convert INT96") @@ -103,7 +105,9 @@ private[parquet] object ParquetTypesConverter extends Logging { * @param parquetType The type to convert. * @return The corresponding Catalyst type. */ - def toDataType(parquetType: ParquetType, isBinaryAsString: Boolean): DataType = { + def toDataType(parquetType: ParquetType, + isBinaryAsString: Boolean, + isInt96AsTimestamp: Boolean): DataType = { def correspondsToMap(groupType: ParquetGroupType): Boolean = { if (groupType.getFieldCount != 1 || groupType.getFields.apply(0).isPrimitive) { false @@ -125,7 +129,7 @@ private[parquet] object ParquetTypesConverter extends Logging { } if (parquetType.isPrimitive) { - toPrimitiveDataType(parquetType.asPrimitiveType, isBinaryAsString) + toPrimitiveDataType(parquetType.asPrimitiveType, isBinaryAsString, isInt96AsTimestamp) } else { val groupType = parquetType.asGroupType() parquetType.getOriginalType match { @@ -137,9 +141,12 @@ private[parquet] object ParquetTypesConverter extends Logging { if (field.getName == CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME) { val bag = field.asGroupType() assert(bag.getFieldCount == 1) - ArrayType(toDataType(bag.getFields.apply(0), isBinaryAsString), containsNull = true) + ArrayType( + toDataType(bag.getFields.apply(0), isBinaryAsString, isInt96AsTimestamp), + containsNull = true) } else { - ArrayType(toDataType(field, isBinaryAsString), containsNull = false) + ArrayType( + toDataType(field, isBinaryAsString, isInt96AsTimestamp), containsNull = false) } } case ParquetOriginalType.MAP => { @@ -152,8 +159,10 @@ private[parquet] object ParquetTypesConverter extends Logging { "Parquet Map type malformatted: nested group should have 2 (key, value) fields!") assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED) - val keyType = toDataType(keyValueGroup.getFields.apply(0), isBinaryAsString) - val valueType = toDataType(keyValueGroup.getFields.apply(1), isBinaryAsString) + val keyType = + toDataType(keyValueGroup.getFields.apply(0), isBinaryAsString, isInt96AsTimestamp) + val valueType = + toDataType(keyValueGroup.getFields.apply(1), isBinaryAsString, isInt96AsTimestamp) MapType(keyType, valueType, keyValueGroup.getFields.apply(1).getRepetition != Repetition.REQUIRED) } @@ -163,8 +172,10 @@ private[parquet] object ParquetTypesConverter extends Logging { val keyValueGroup = groupType.getFields.apply(0).asGroupType() assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED) - val keyType = toDataType(keyValueGroup.getFields.apply(0), isBinaryAsString) - val valueType = toDataType(keyValueGroup.getFields.apply(1), isBinaryAsString) + val keyType = + toDataType(keyValueGroup.getFields.apply(0), isBinaryAsString, isInt96AsTimestamp) + val valueType = + toDataType(keyValueGroup.getFields.apply(1), isBinaryAsString, isInt96AsTimestamp) MapType(keyType, valueType, keyValueGroup.getFields.apply(1).getRepetition != Repetition.REQUIRED) } else if (correspondsToArray(groupType)) { // ArrayType @@ -172,16 +183,19 @@ private[parquet] object ParquetTypesConverter extends Logging { if (field.getName == CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME) { val bag = field.asGroupType() assert(bag.getFieldCount == 1) - ArrayType(toDataType(bag.getFields.apply(0), isBinaryAsString), containsNull = true) + ArrayType( + toDataType(bag.getFields.apply(0), isBinaryAsString, isInt96AsTimestamp), + containsNull = true) } else { - ArrayType(toDataType(field, isBinaryAsString), containsNull = false) + ArrayType( + toDataType(field, isBinaryAsString, isInt96AsTimestamp), containsNull = false) } } else { // everything else: StructType val fields = groupType .getFields .map(ptype => new StructField( ptype.getName, - toDataType(ptype, isBinaryAsString), + toDataType(ptype, isBinaryAsString, isInt96AsTimestamp), ptype.getRepetition != Repetition.REQUIRED)) StructType(fields) } @@ -210,6 +224,7 @@ private[parquet] object ParquetTypesConverter extends Logging { case ShortType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32)) case ByteType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32)) case LongType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT64)) + case TimestampType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT96)) case DecimalType.Fixed(precision, scale) if precision <= 18 => // TODO: for now, our writer only supports decimals that fit in a Long Some(ParquetTypeInfo(ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, @@ -345,7 +360,9 @@ private[parquet] object ParquetTypesConverter extends Logging { } } - def convertToAttributes(parquetSchema: ParquetType, isBinaryAsString: Boolean): Seq[Attribute] = { + def convertToAttributes(parquetSchema: ParquetType, + isBinaryAsString: Boolean, + isInt96AsTimestamp: Boolean): Seq[Attribute] = { parquetSchema .asGroupType() .getFields @@ -353,7 +370,7 @@ private[parquet] object ParquetTypesConverter extends Logging { field => new AttributeReference( field.getName, - toDataType(field, isBinaryAsString), + toDataType(field, isBinaryAsString, isInt96AsTimestamp), field.getRepetition != Repetition.REQUIRED)()) } @@ -476,7 +493,8 @@ private[parquet] object ParquetTypesConverter extends Logging { def readSchemaFromFile( origPath: Path, conf: Option[Configuration], - isBinaryAsString: Boolean): Seq[Attribute] = { + isBinaryAsString: Boolean, + isInt96AsTimestamp: Boolean): Seq[Attribute] = { val keyValueMetadata: java.util.Map[String, String] = readMetaData(origPath, conf) .getFileMetaData @@ -485,7 +503,9 @@ private[parquet] object ParquetTypesConverter extends Logging { convertFromString(keyValueMetadata.get(RowReadSupport.SPARK_METADATA_KEY)) } else { val attributes = convertToAttributes( - readMetaData(origPath, conf).getFileMetaData.getSchema, isBinaryAsString) + readMetaData(origPath, conf).getFileMetaData.getSchema, + isBinaryAsString, + isInt96AsTimestamp) log.info(s"Falling back to schema conversion from Parquet types; result: $attributes") attributes } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 1e794cad73936..179c0d6b22239 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -136,7 +136,8 @@ case class ParquetRelation2(path: String)(@transient val sqlContext: SQLContext) ParquetTypesConverter.readSchemaFromFile( partitions.head.files.head.getPath, Some(sparkContext.hadoopConfiguration), - sqlContext.conf.isParquetBinaryAsString)) + sqlContext.conf.isParquetBinaryAsString, + sqlContext.conf.isParquetINT96AsTimestamp)) val dataIncludesKey = partitionKeys.headOption.map(dataSchema.fieldNames.contains(_)).getOrElse(true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala new file mode 100644 index 0000000000000..887161684429f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parquet.timestamp + +import java.nio.{ByteBuffer, ByteOrder} + +import parquet.Preconditions +import parquet.io.api.{Binary, RecordConsumer} + +private[parquet] class NanoTime extends Serializable { + private var julianDay = 0 + private var timeOfDayNanos = 0L + + def set(julianDay: Int, timeOfDayNanos: Long) = { + this.julianDay = julianDay + this.timeOfDayNanos = timeOfDayNanos + this + } + + def getJulianDay: Int = julianDay + + def getTimeOfDayNanos: Long = timeOfDayNanos + + def toBinary: Binary = { + val buf = ByteBuffer.allocate(12) + buf.order(ByteOrder.LITTLE_ENDIAN) + buf.putLong(timeOfDayNanos) + buf.putInt(julianDay) + buf.flip() + Binary.fromByteBuffer(buf) + } + + def writeValue(recordConsumer: RecordConsumer) { + recordConsumer.addBinary(toBinary) + } + + override def toString = + "NanoTime{julianDay=" + julianDay + ", timeOfDayNanos=" + timeOfDayNanos + "}" +} + +object NanoTime { + def fromBinary(bytes: Binary): NanoTime = { + Preconditions.checkArgument(bytes.length() == 12, "Must be 12 bytes") + val buf = bytes.toByteBuffer + buf.order(ByteOrder.LITTLE_ENDIAN) + val timeOfDayNanos = buf.getLong + val julianDay = buf.getInt + new NanoTime().set(julianDay, timeOfDayNanos) + } + + def apply(julianDay: Int, timeOfDayNanos: Long): NanoTime = { + new NanoTime().set(julianDay, timeOfDayNanos) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala index d13f2ce2a5e1d..386ff2452f1a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.{Row, Strategy} import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, AttributeSet, Expression, NamedExpression} import org.apache.spark.sql.catalyst.planning.PhysicalOperation -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, InsertIntoTable => LogicalInsertIntoTable} import org.apache.spark.sql.execution /** @@ -54,6 +54,13 @@ private[sql] object DataSourceStrategy extends Strategy { case l @ LogicalRelation(t: TableScan) => execution.PhysicalRDD(l.output, t.buildScan()) :: Nil + case i @ LogicalInsertIntoTable( + l @ LogicalRelation(t: InsertableRelation), partition, query, overwrite) => + if (partition.nonEmpty) { + sys.error(s"Insert into a partition is not allowed because $l is not partitioned.") + } + execution.ExecutedCommand(InsertIntoRelation(t, query, overwrite)) :: Nil + case _ => Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala new file mode 100644 index 0000000000000..d7942dc30934b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.sources + +import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.RunnableCommand + +private[sql] case class InsertIntoRelation( + relation: InsertableRelation, + query: LogicalPlan, + overwrite: Boolean) + extends RunnableCommand { + + override def run(sqlContext: SQLContext) = { + relation.insert(DataFrame(sqlContext, query), overwrite) + + Seq.empty[Row] + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index b1bbe0f89af73..ead827728cf4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -36,6 +36,7 @@ private[sql] class DDLParser extends AbstractSparkSQLParser with Logging { try { Some(apply(input)) } catch { + case ddlException: DDLException => throw ddlException case _ if !exceptionOnError => None case x: Throwable => throw x } @@ -45,8 +46,7 @@ private[sql] class DDLParser extends AbstractSparkSQLParser with Logging { lexical.initialize(reservedWords) phrase(dataType)(new lexical.Scanner(input)) match { case Success(r, x) => r - case x => - sys.error(s"Unsupported dataType: $x") + case x => throw new DDLException(s"Unsupported dataType: $x") } } @@ -56,8 +56,12 @@ private[sql] class DDLParser extends AbstractSparkSQLParser with Logging { protected val CREATE = Keyword("CREATE") protected val TEMPORARY = Keyword("TEMPORARY") protected val TABLE = Keyword("TABLE") + protected val IF = Keyword("IF") + protected val NOT = Keyword("NOT") + protected val EXISTS = Keyword("EXISTS") protected val USING = Keyword("USING") protected val OPTIONS = Keyword("OPTIONS") + protected val AS = Keyword("AS") protected val COMMENT = Keyword("COMMENT") // Data types. @@ -83,22 +87,51 @@ private[sql] class DDLParser extends AbstractSparkSQLParser with Logging { protected def start: Parser[LogicalPlan] = ddl /** - * `CREATE [TEMPORARY] TABLE avroTable + * `CREATE [TEMPORARY] TABLE avroTable [IF NOT EXISTS] * USING org.apache.spark.sql.avro * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` * or - * `CREATE [TEMPORARY] TABLE avroTable(intField int, stringField string...) + * `CREATE [TEMPORARY] TABLE avroTable(intField int, stringField string...) [IF NOT EXISTS] * USING org.apache.spark.sql.avro * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` + * or + * `CREATE [TEMPORARY] TABLE avroTable [IF NOT EXISTS] + * USING org.apache.spark.sql.avro + * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` + * AS SELECT ... */ protected lazy val createTable: Parser[LogicalPlan] = ( - (CREATE ~> TEMPORARY.? <~ TABLE) ~ ident - ~ (tableCols).? ~ (USING ~> className) ~ (OPTIONS ~> options) ^^ { - case temp ~ tableName ~ columns ~ provider ~ opts => - val userSpecifiedSchema = columns.flatMap(fields => Some(StructType(fields))) - CreateTableUsing(tableName, userSpecifiedSchema, provider, temp.isDefined, opts) - } + (CREATE ~> TEMPORARY.? <~ TABLE) ~ (IF ~> NOT <~ EXISTS).? ~ ident + ~ (tableCols).? ~ (USING ~> className) ~ (OPTIONS ~> options) ~ (AS ~> restInput).? ^^ { + case temp ~ allowExisting ~ tableName ~ columns ~ provider ~ opts ~ query => + if (temp.isDefined && allowExisting.isDefined) { + throw new DDLException( + "a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause.") + } + + if (query.isDefined) { + if (columns.isDefined) { + throw new DDLException( + "a CREATE TABLE AS SELECT statement does not allow column definitions.") + } + CreateTableUsingAsSelect(tableName, + provider, + temp.isDefined, + opts, + allowExisting.isDefined, + query.get) + } else { + val userSpecifiedSchema = columns.flatMap(fields => Some(StructType(fields))) + CreateTableUsing( + tableName, + userSpecifiedSchema, + provider, + temp.isDefined, + opts, + allowExisting.isDefined) + } + } ) protected lazy val tableCols: Parser[Seq[StructField]] = "(" ~> repsep(column, ",") <~ ")" @@ -193,7 +226,7 @@ object ResolvedDataSource { dataSource .asInstanceOf[org.apache.spark.sql.sources.SchemaRelationProvider] .createRelation(sqlContext, new CaseInsensitiveMap(options), schema) - case _ => + case dataSource: org.apache.spark.sql.sources.RelationProvider => sys.error(s"${clazz.getCanonicalName} does not allow user-specified schemas.") } } @@ -203,7 +236,7 @@ object ResolvedDataSource { dataSource .asInstanceOf[org.apache.spark.sql.sources.RelationProvider] .createRelation(sqlContext, new CaseInsensitiveMap(options)) - case _ => + case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider => sys.error(s"A schema needs to be specified when using ${clazz.getCanonicalName}.") } } @@ -211,6 +244,32 @@ object ResolvedDataSource { new ResolvedDataSource(clazz, relation) } + + def apply( + sqlContext: SQLContext, + provider: String, + options: Map[String, String], + data: DataFrame): ResolvedDataSource = { + val loader = Utils.getContextOrSparkClassLoader + val clazz: Class[_] = try loader.loadClass(provider) catch { + case cnf: java.lang.ClassNotFoundException => + try loader.loadClass(provider + ".DefaultSource") catch { + case cnf: java.lang.ClassNotFoundException => + sys.error(s"Failed to load class for data source: $provider") + } + } + + val relation = clazz.newInstance match { + case dataSource: org.apache.spark.sql.sources.CreateableRelationProvider => + dataSource + .asInstanceOf[org.apache.spark.sql.sources.CreateableRelationProvider] + .createRelation(sqlContext, options, data) + case _ => + sys.error(s"${clazz.getCanonicalName} does not allow create table as select.") + } + + new ResolvedDataSource(clazz, relation) + } } private[sql] case class ResolvedDataSource(provider: Class[_], relation: BaseRelation) @@ -220,13 +279,30 @@ private[sql] case class CreateTableUsing( userSpecifiedSchema: Option[StructType], provider: String, temporary: Boolean, - options: Map[String, String]) extends Command + options: Map[String, String], + allowExisting: Boolean) extends Command + +private[sql] case class CreateTableUsingAsSelect( + tableName: String, + provider: String, + temporary: Boolean, + options: Map[String, String], + allowExisting: Boolean, + query: String) extends Command + +private[sql] case class CreateTableUsingAsLogicalPlan( + tableName: String, + provider: String, + temporary: Boolean, + options: Map[String, String], + allowExisting: Boolean, + query: LogicalPlan) extends Command private [sql] case class CreateTempTableUsing( tableName: String, userSpecifiedSchema: Option[StructType], provider: String, - options: Map[String, String]) extends RunnableCommand { + options: Map[String, String]) extends RunnableCommand { def run(sqlContext: SQLContext) = { val resolved = ResolvedDataSource(sqlContext, userSpecifiedSchema, provider, options) @@ -236,6 +312,22 @@ private [sql] case class CreateTempTableUsing( } } +private [sql] case class CreateTempTableUsingAsSelect( + tableName: String, + provider: String, + options: Map[String, String], + query: LogicalPlan) extends RunnableCommand { + + def run(sqlContext: SQLContext) = { + val df = DataFrame(sqlContext, query) + val resolved = ResolvedDataSource(sqlContext, provider, options, df) + sqlContext.registerRDDAsTable( + DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName) + + Seq.empty + } +} + /** * Builds a map in which keys are case insensitive */ @@ -253,3 +345,9 @@ protected class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, override def -(key: String): Map[String, String] = baseMap - key.toLowerCase() } + +/** + * The exception thrown from the DDL parser. + * @param message + */ +protected[sql] class DDLException(message: String) extends Exception(message) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index cd82cc6ecb61b..ad0a35b91ebc2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.sources import org.apache.spark.annotation.{Experimental, DeveloperApi} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.{DataFrame, Row, SQLContext} import org.apache.spark.sql.catalyst.expressions.{Expression, Attribute} import org.apache.spark.sql.types.StructType @@ -77,6 +77,14 @@ trait SchemaRelationProvider { schema: StructType): BaseRelation } +@DeveloperApi +trait CreateableRelationProvider { + def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String], + data: DataFrame): BaseRelation +} + /** * ::DeveloperApi:: * Represents a collection of tuples with a known schema. Classes that extend BaseRelation must @@ -108,7 +116,7 @@ abstract class BaseRelation { * A BaseRelation that can produce all of its tuples as an RDD of Row objects. */ @DeveloperApi -abstract class TableScan extends BaseRelation { +trait TableScan extends BaseRelation { def buildScan(): RDD[Row] } @@ -118,7 +126,7 @@ abstract class TableScan extends BaseRelation { * containing all of its tuples as Row objects. */ @DeveloperApi -abstract class PrunedScan extends BaseRelation { +trait PrunedScan extends BaseRelation { def buildScan(requiredColumns: Array[String]): RDD[Row] } @@ -132,7 +140,7 @@ abstract class PrunedScan extends BaseRelation { * as filtering partitions based on a bloom filter. */ @DeveloperApi -abstract class PrunedFilteredScan extends BaseRelation { +trait PrunedFilteredScan extends BaseRelation { def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] } @@ -145,6 +153,11 @@ abstract class PrunedFilteredScan extends BaseRelation { * for experimentation. */ @Experimental -abstract class CatalystScan extends BaseRelation { +trait CatalystScan extends BaseRelation { def buildScan(requiredColumns: Seq[Attribute], filters: Seq[Expression]): RDD[Row] } + +@DeveloperApi +trait InsertableRelation extends BaseRelation { + def insert(data: DataFrame, overwrite: Boolean): Unit +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala index 006b16fbe07bd..e6f622e87f7a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala @@ -37,7 +37,7 @@ private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] { override def sqlType: DataType = ArrayType(DoubleType, false) - override def pyUDT: String = "pyspark.tests.ExamplePointUDT" + override def pyUDT: String = "pyspark.sql_tests.ExamplePointUDT" override def serialize(obj: Any): Seq[Double] = { obj match { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index e18ba287e8683..0501b47f080d7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -38,7 +38,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { var origZone: TimeZone = _ override protected def beforeAll() { origZone = TimeZone.getDefault - TimeZone.setDefault(TimeZone.getTimeZone("UTC")) + TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) } override protected def afterAll() { @@ -143,26 +143,26 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("SPARK-3173 Timestamp support in the parser") { checkAnswer(sql( - "SELECT time FROM timestamps WHERE time=CAST('1970-01-01 00:00:00.001' AS TIMESTAMP)"), - Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001"))) + "SELECT time FROM timestamps WHERE time=CAST('1969-12-31 16:00:00.001' AS TIMESTAMP)"), + Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.001"))) checkAnswer(sql( - "SELECT time FROM timestamps WHERE time='1970-01-01 00:00:00.001'"), - Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001"))) + "SELECT time FROM timestamps WHERE time='1969-12-31 16:00:00.001'"), + Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.001"))) checkAnswer(sql( - "SELECT time FROM timestamps WHERE '1970-01-01 00:00:00.001'=time"), - Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001"))) + "SELECT time FROM timestamps WHERE '1969-12-31 16:00:00.001'=time"), + Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.001"))) checkAnswer(sql( - """SELECT time FROM timestamps WHERE time<'1970-01-01 00:00:00.003' - AND time>'1970-01-01 00:00:00.001'"""), - Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.002"))) + """SELECT time FROM timestamps WHERE time<'1969-12-31 16:00:00.003' + AND time>'1969-12-31 16:00:00.001'"""), + Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.002"))) checkAnswer(sql( - "SELECT time FROM timestamps WHERE time IN ('1970-01-01 00:00:00.001','1970-01-01 00:00:00.002')"), - Seq(Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001")), - Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.002")))) + "SELECT time FROM timestamps WHERE time IN ('1969-12-31 16:00:00.001','1969-12-31 16:00:00.002')"), + Seq(Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.001")), + Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.002")))) checkAnswer(sql( "SELECT time FROM timestamps WHERE time='123'"), @@ -296,6 +296,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { mapData.collect().take(1).map(Row.fromTuple).toSeq) } + test("date row") { + checkAnswer(sql( + """select cast("2015-01-28" as date) from testData limit 1"""), + Row(java.sql.Date.valueOf("2015-01-28")) + ) + } + test("from follow multiple brackets") { checkAnswer(sql( "select key from ((select * from testData limit 1) union all (select * from testData limit 1)) x limit 1"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index a015884bae282..f26fcc0385b68 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -83,7 +83,8 @@ class ScalaReflectionRelationSuite extends FunSuite { assert(sql("SELECT * FROM reflectData").collect().head === Row("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, - new java.math.BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1,2,3))) + new java.math.BigDecimal(1), new Date(70, 0, 1), // This is 1970-01-01 + new Timestamp(12345), Seq(1,2,3))) } test("query case class RDD with nulls") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index be2b34de077c9..581fccf8ee613 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -30,7 +30,7 @@ class ColumnStatsSuite extends FunSuite { testColumnStats(classOf[FloatColumnStats], FLOAT, Row(Float.MaxValue, Float.MinValue, 0)) testColumnStats(classOf[DoubleColumnStats], DOUBLE, Row(Double.MaxValue, Double.MinValue, 0)) testColumnStats(classOf[StringColumnStats], STRING, Row(null, null, 0)) - testColumnStats(classOf[DateColumnStats], DATE, Row(null, null, 0)) + testColumnStats(classOf[DateColumnStats], DATE, Row(Int.MaxValue, Int.MinValue, 0)) testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, Row(null, null, 0)) def testColumnStats[T <: NativeType, U <: ColumnStats]( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index 87e608a8853dc..9ce845912f1c7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer -import java.sql.{Date, Timestamp} +import java.sql.Timestamp import org.scalatest.FunSuite @@ -34,7 +34,7 @@ class ColumnTypeSuite extends FunSuite with Logging { test("defaultSize") { val checks = Map( INT -> 4, SHORT -> 2, LONG -> 8, BYTE -> 1, DOUBLE -> 8, FLOAT -> 4, BOOLEAN -> 1, - STRING -> 8, DATE -> 8, TIMESTAMP -> 12, BINARY -> 16, GENERIC -> 16) + STRING -> 8, DATE -> 4, TIMESTAMP -> 12, BINARY -> 16, GENERIC -> 16) checks.foreach { case (columnType, expectedSize) => assertResult(expectedSize, s"Wrong defaultSize for $columnType") { @@ -64,7 +64,7 @@ class ColumnTypeSuite extends FunSuite with Logging { checkActualSize(FLOAT, Float.MaxValue, 4) checkActualSize(BOOLEAN, true, 1) checkActualSize(STRING, "hello", 4 + "hello".getBytes("utf-8").length) - checkActualSize(DATE, new Date(0L), 8) + checkActualSize(DATE, 0, 4) checkActualSize(TIMESTAMP, new Timestamp(0L), 12) val binary = Array.fill[Byte](4)(0: Byte) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala index f941465fa3e35..60ed28cc97bf1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.columnar +import java.sql.Timestamp + import scala.collection.immutable.HashSet import scala.util.Random -import java.sql.{Date, Timestamp} - import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.types.{DataType, NativeType} @@ -50,7 +50,7 @@ object ColumnarTestUtils { case STRING => Random.nextString(Random.nextInt(32)) case BOOLEAN => Random.nextBoolean() case BINARY => randomBytes(Random.nextInt(32)) - case DATE => new Date(Random.nextLong()) + case DATE => Random.nextInt() case TIMESTAMP => val timestamp = new Timestamp(Random.nextLong()) timestamp.setNanos(Random.nextInt(999999999)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index cb615388da0c7..1396c6b7246d1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -67,14 +67,15 @@ class JsonSuite extends QueryTest { checkTypePromotion(Timestamp.valueOf(strTime), enforceCorrectType(strTime, TimestampType)) val strDate = "2014-10-15" - checkTypePromotion(Date.valueOf(strDate), enforceCorrectType(strDate, DateType)) + checkTypePromotion( + DateUtils.fromJavaDate(Date.valueOf(strDate)), enforceCorrectType(strDate, DateType)) val ISO8601Time1 = "1970-01-01T01:00:01.0Z" checkTypePromotion(new Timestamp(3601000), enforceCorrectType(ISO8601Time1, TimestampType)) - checkTypePromotion(new Date(3601000), enforceCorrectType(ISO8601Time1, DateType)) + checkTypePromotion(DateUtils.millisToDays(3601000), enforceCorrectType(ISO8601Time1, DateType)) val ISO8601Time2 = "1970-01-01T02:00:01-01:00" checkTypePromotion(new Timestamp(10801000), enforceCorrectType(ISO8601Time2, TimestampType)) - checkTypePromotion(new Date(10801000), enforceCorrectType(ISO8601Time2, DateType)) + checkTypePromotion(DateUtils.millisToDays(10801000), enforceCorrectType(ISO8601Time2, DateType)) } test("Get compatible type") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 3d82f4bce7778..5ec7a156d9353 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -37,11 +37,21 @@ class ParquetQuerySuite extends QueryTest with ParquetTest { test("appending") { val data = (0 until 10).map(i => (i, i.toString)) withParquetTable(data, "t") { - sql("INSERT INTO t SELECT * FROM t") + sql("INSERT INTO TABLE t SELECT * FROM t") checkAnswer(table("t"), (data ++ data).map(Row.fromTuple)) } } + // This test case will trigger the NPE mentioned in + // https://issues.apache.org/jira/browse/PARQUET-151. + ignore("overwriting") { + val data = (0 until 10).map(i => (i, i.toString)) + withParquetTable(data, "t") { + sql("INSERT OVERWRITE TABLE t SELECT * FROM t") + checkAnswer(table("t"), data.map(Row.fromTuple)) + } + } + test("self-join") { // 4 rows, cells of column 1 of row 2 and row 4 are null val data = (1 to 4).map { i => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala new file mode 100644 index 0000000000000..b02389978b625 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import java.io.File + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.catalyst.util +import org.apache.spark.util.Utils + +class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { + + import caseInsensisitiveContext._ + + var path: File = null + + override def beforeAll(): Unit = { + path = util.getTempFilePath("jsonCTAS").getCanonicalFile + val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) + jsonRDD(rdd).registerTempTable("jt") + } + + override def afterAll(): Unit = { + dropTempTable("jt") + } + + after { + if (path.exists()) Utils.deleteRecursively(path) + } + + test("CREATE TEMPORARY TABLE AS SELECT") { + sql( + s""" + |CREATE TEMPORARY TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '${path.toString}' + |) AS + |SELECT a, b FROM jt + """.stripMargin) + + checkAnswer( + sql("SELECT a, b FROM jsonTable"), + sql("SELECT a, b FROM jt").collect()) + + dropTempTable("jsonTable") + } + + test("create a table, drop it and create another one with the same name") { + sql( + s""" + |CREATE TEMPORARY TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '${path.toString}' + |) AS + |SELECT a, b FROM jt + """.stripMargin) + + checkAnswer( + sql("SELECT a, b FROM jsonTable"), + sql("SELECT a, b FROM jt").collect()) + + dropTempTable("jsonTable") + + val message = intercept[RuntimeException]{ + sql( + s""" + |CREATE TEMPORARY TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '${path.toString}' + |) AS + |SELECT a * 4 FROM jt + """.stripMargin) + }.getMessage + assert( + message.contains(s"path ${path.toString} already exists."), + "CREATE TEMPORARY TABLE IF NOT EXISTS should not be allowed.") + + // Explicitly delete it. + if (path.exists()) Utils.deleteRecursively(path) + + sql( + s""" + |CREATE TEMPORARY TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '${path.toString}' + |) AS + |SELECT a * 4 FROM jt + """.stripMargin) + + checkAnswer( + sql("SELECT * FROM jsonTable"), + sql("SELECT a * 4 FROM jt").collect()) + + dropTempTable("jsonTable") + } + + test("CREATE TEMPORARY TABLE AS SELECT with IF NOT EXISTS is not allowed") { + val message = intercept[DDLException]{ + sql( + s""" + |CREATE TEMPORARY TABLE IF NOT EXISTS jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '${path.toString}' + |) AS + |SELECT b FROM jt + """.stripMargin) + }.getMessage + assert( + message.contains("a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause."), + "CREATE TEMPORARY TABLE IF NOT EXISTS should not be allowed.") + } + + test("a CTAS statement with column definitions is not allowed") { + intercept[DDLException]{ + sql( + s""" + |CREATE TEMPORARY TABLE jsonTable (a int, b string) + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '${path.toString}' + |) AS + |SELECT a, b FROM jt + """.stripMargin) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertIntoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertIntoSuite.scala new file mode 100644 index 0000000000000..f91cea6a37060 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertIntoSuite.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import java.io.File + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.util +import org.apache.spark.util.Utils + +class InsertIntoSuite extends DataSourceTest with BeforeAndAfterAll { + + import caseInsensisitiveContext._ + + var path: File = null + + override def beforeAll: Unit = { + path = util.getTempFilePath("jsonCTAS").getCanonicalFile + val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) + jsonRDD(rdd).registerTempTable("jt") + sql( + s""" + |CREATE TEMPORARY TABLE jsonTable (a int, b string) + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '${path.toString}' + |) + """.stripMargin) + } + + override def afterAll: Unit = { + dropTempTable("jsonTable") + dropTempTable("jt") + if (path.exists()) Utils.deleteRecursively(path) + } + + test("Simple INSERT OVERWRITE a JSONRelation") { + sql( + s""" + |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt + """.stripMargin) + + checkAnswer( + sql("SELECT a, b FROM jsonTable"), + (1 to 10).map(i => Row(i, s"str$i")) + ) + } + + test("INSERT OVERWRITE a JSONRelation multiple times") { + sql( + s""" + |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt + """.stripMargin) + + sql( + s""" + |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt + """.stripMargin) + + sql( + s""" + |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt + """.stripMargin) + + checkAnswer( + sql("SELECT a, b FROM jsonTable"), + (1 to 10).map(i => Row(i, s"str$i")) + ) + } + + test("INSERT INTO not supported for JSONRelation for now") { + intercept[RuntimeException]{ + sql( + s""" + |INSERT INTO TABLE jsonTable SELECT a, b FROM jt + """.stripMargin) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala new file mode 100644 index 0000000000000..fe2f76cc397f5 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import java.io.File + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.DataFrame +import org.apache.spark.util.Utils + +import org.apache.spark.sql.catalyst.util + +class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll { + + import caseInsensisitiveContext._ + + var originalDefaultSource: String = null + + var path: File = null + + var df: DataFrame = null + + override def beforeAll(): Unit = { + originalDefaultSource = conf.defaultDataSourceName + conf.setConf("spark.sql.default.datasource", "org.apache.spark.sql.json") + + path = util.getTempFilePath("datasource").getCanonicalFile + + val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) + df = jsonRDD(rdd) + } + + override def afterAll(): Unit = { + conf.setConf("spark.sql.default.datasource", originalDefaultSource) + } + + after { + if (path.exists()) Utils.deleteRecursively(path) + } + + def checkLoad(): Unit = { + checkAnswer(load(path.toString), df.collect()) + checkAnswer(load("org.apache.spark.sql.json", ("path", path.toString)), df.collect()) + } + + test("save with overwrite and load") { + df.save(path.toString) + checkLoad + } + + test("save with data source and options, and load") { + df.save("org.apache.spark.sql.json", ("path", path.toString)) + checkLoad + } + + test("save and save again") { + df.save(path.toString) + + val message = intercept[RuntimeException] { + df.save(path.toString) + }.getMessage + + assert( + message.contains("already exists"), + "We should complain that the path already exists.") + + if (path.exists()) Utils.deleteRecursively(path) + + df.save(path.toString) + checkLoad + } +} \ No newline at end of file diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 7c2c5feef26f6..2f5e9be6f8d6b 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -354,6 +354,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "database_drop", "database_location", "database_properties", + "date_1", "date_2", "date_3", "date_4", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 5efc3b1e30774..1921bf6e5e1a6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.hive import java.io.{BufferedReader, InputStreamReader, PrintStream} -import java.sql.{Date, Timestamp} +import java.sql.Timestamp import scala.collection.JavaConversions._ import scala.language.implicitConversions @@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.analysis.{Analyzer, EliminateAnalysisOperat import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUdfs, SetCommand, QueryExecutionException} import org.apache.spark.sql.hive.execution.{HiveNativeCommand, DescribeHiveTableCommand} -import org.apache.spark.sql.sources.DataSourceStrategy +import org.apache.spark.sql.sources.{CreateTableUsing, DataSourceStrategy} import org.apache.spark.sql.types._ /** @@ -86,6 +86,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { * @param allowExisting When false, an exception will be thrown if the table already exists. * @tparam A A case class that is used to describe the schema of the table to be created. */ + @Deprecated def createTable[A <: Product : TypeTag](tableName: String, allowExisting: Boolean = true) { catalog.createTable("default", tableName, ScalaReflection.attributesFor[A], allowExisting) } @@ -106,6 +107,70 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { catalog.invalidateTable("default", tableName) } + @Experimental + def createTable(tableName: String, path: String, allowExisting: Boolean): Unit = { + val dataSourceName = conf.defaultDataSourceName + createTable(tableName, dataSourceName, allowExisting, ("path", path)) + } + + @Experimental + def createTable( + tableName: String, + dataSourceName: String, + allowExisting: Boolean, + option: (String, String), + options: (String, String)*): Unit = { + val cmd = + CreateTableUsing( + tableName, + userSpecifiedSchema = None, + dataSourceName, + temporary = false, + (option +: options).toMap, + allowExisting) + executePlan(cmd).toRdd + } + + @Experimental + def createTable( + tableName: String, + dataSourceName: String, + schema: StructType, + allowExisting: Boolean, + option: (String, String), + options: (String, String)*): Unit = { + val cmd = + CreateTableUsing( + tableName, + userSpecifiedSchema = Some(schema), + dataSourceName, + temporary = false, + (option +: options).toMap, + allowExisting) + executePlan(cmd).toRdd + } + + @Experimental + def createTable( + tableName: String, + dataSourceName: String, + allowExisting: Boolean, + options: java.util.Map[String, String]): Unit = { + val opts = options.toSeq + createTable(tableName, dataSourceName, allowExisting, opts.head, opts.tail:_*) + } + + @Experimental + def createTable( + tableName: String, + dataSourceName: String, + schema: StructType, + allowExisting: Boolean, + options: java.util.Map[String, String]): Unit = { + val opts = options.toSeq + createTable(tableName, dataSourceName, schema, allowExisting, opts.head, opts.tail:_*) + } + /** * Analyzes the given table in the current database to generate statistics, which will be * used in query optimizations. @@ -246,7 +311,9 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { // Note that HiveUDFs will be overridden by functions registered in this context. @transient override protected[sql] lazy val functionRegistry = - new HiveFunctionRegistry with OverrideFunctionRegistry + new HiveFunctionRegistry with OverrideFunctionRegistry { + def caseSensitive = false + } /* An analyzer that uses the Hive metastore. */ @transient @@ -256,6 +323,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { catalog.CreateTables :: catalog.PreInsertionCasts :: ExtractPythonUdfs :: + ResolveUdtfsAlias :: Nil } @@ -410,7 +478,7 @@ private object HiveContext { toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) }.toSeq.sorted.mkString("{", ",", "}") case (null, _) => "NULL" - case (d: Date, DateType) => new DateWritable(d).toString + case (d: Int, DateType) => new DateWritable(d).toString case (t: Timestamp, TimestampType) => new TimestampWritable(t).toString case (bin: Array[Byte], BinaryType) => new String(bin, "UTF-8") case (decimal: java.math.BigDecimal, DecimalType()) => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 82dba99900df9..4afa2e71d77cc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -267,7 +267,8 @@ private[hive] trait HiveInspectors { val temp = new Array[Byte](writable.getLength) System.arraycopy(writable.getBytes, 0, temp, 0, temp.length) temp - case poi: WritableConstantDateObjectInspector => poi.getWritableConstantValue.get() + case poi: WritableConstantDateObjectInspector => + DateUtils.fromJavaDate(poi.getWritableConstantValue.get()) case mi: StandardConstantMapObjectInspector => // take the value from the map inspector object, rather than the input data mi.getWritableConstantValue.map { case (k, v) => @@ -304,7 +305,8 @@ private[hive] trait HiveInspectors { System.arraycopy(bw.getBytes(), 0, result, 0, bw.getLength()) result case x: DateObjectInspector if x.preferWritable() => - x.getPrimitiveWritableObject(data).get() + DateUtils.fromJavaDate(x.getPrimitiveWritableObject(data).get()) + case x: DateObjectInspector => DateUtils.fromJavaDate(x.getPrimitiveJavaObject(data)) // org.apache.hadoop.hive.serde2.io.TimestampWritable.set will reset current time object // if next timestamp is null, so Timestamp object is cloned case x: TimestampObjectInspector if x.preferWritable() => @@ -343,6 +345,9 @@ private[hive] trait HiveInspectors { case _: JavaHiveDecimalObjectInspector => (o: Any) => HiveShim.createDecimal(o.asInstanceOf[Decimal].toJavaBigDecimal) + case _: JavaDateObjectInspector => + (o: Any) => DateUtils.toJavaDate(o.asInstanceOf[Int]) + case soi: StandardStructObjectInspector => val wrappers = soi.getAllStructFieldRefs.map(ref => wrapperFor(ref.getFieldObjectInspector)) (o: Any) => { @@ -426,7 +431,7 @@ private[hive] trait HiveInspectors { case _: BinaryObjectInspector if x.preferWritable() => HiveShim.getBinaryWritable(a) case _: BinaryObjectInspector => a.asInstanceOf[Array[Byte]] case _: DateObjectInspector if x.preferWritable() => HiveShim.getDateWritable(a) - case _: DateObjectInspector => a.asInstanceOf[java.sql.Date] + case _: DateObjectInspector => DateUtils.toJavaDate(a.asInstanceOf[Int]) case _: TimestampObjectInspector if x.preferWritable() => HiveShim.getTimestampWritable(a) case _: TimestampObjectInspector => a.asInstanceOf[java.sql.Timestamp] } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index d910ee950904d..48bea6c1bd685 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -23,10 +23,9 @@ import java.util.{List => JList} import com.google.common.cache.{LoadingCache, CacheLoader, CacheBuilder} import org.apache.hadoop.util.ReflectionUtils -import org.apache.hadoop.hive.metastore.TableType -import org.apache.hadoop.hive.metastore.api.{Table => TTable, Partition => TPartition, FieldSchema} -import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table, HiveException} -import org.apache.hadoop.hive.ql.metadata.InvalidTableException +import org.apache.hadoop.hive.metastore.{Warehouse, TableType} +import org.apache.hadoop.hive.metastore.api.{Table => TTable, Partition => TPartition, AlreadyExistsException, FieldSchema} +import org.apache.hadoop.hive.ql.metadata._ import org.apache.hadoop.hive.ql.plan.CreateTableDesc import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.{Deserializer, SerDeException} @@ -52,6 +51,8 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with /** Connection to hive metastore. Usages should lock on `this`. */ protected[hive] val client = Hive.get(hive.hiveconf) + protected[hive] lazy val hiveWarehouse = new Warehouse(hive.hiveconf) + // TODO: Use this everywhere instead of tuples or databaseName, tableName,. /** A fully qualified identifier for a table (i.e., database.tableName) */ case class QualifiedTableName(database: String, name: String) { @@ -99,11 +100,22 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with val caseSensitive: Boolean = false + /** * + * Creates a data source table (a table created with USING clause) in Hive's metastore. + * Returns true when the table has been created. Otherwise, false. + * @param tableName + * @param userSpecifiedSchema + * @param provider + * @param options + * @param isExternal + * @return + */ def createDataSourceTable( tableName: String, userSpecifiedSchema: Option[StructType], provider: String, - options: Map[String, String]) = { + options: Map[String, String], + isExternal: Boolean): Unit = { val (dbName, tblName) = processDatabaseAndTableName("default", tableName) val tbl = new Table(dbName, tblName) @@ -113,8 +125,13 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with } options.foreach { case (key, value) => tbl.setSerdeParam(key, value) } - tbl.setProperty("EXTERNAL", "TRUE") - tbl.setTableType(TableType.EXTERNAL_TABLE) + if (isExternal) { + tbl.setProperty("EXTERNAL", "TRUE") + tbl.setTableType(TableType.EXTERNAL_TABLE) + } else { + tbl.setProperty("EXTERNAL", "FALSE") + tbl.setTableType(TableType.MANAGED_TABLE) + } // create the table synchronized { @@ -122,6 +139,10 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with } } + def hiveDefaultTableFilePath(tableName: String): String = { + hiveWarehouse.getTablePath(client.getDatabaseCurrent, tableName).toString + } + def tableExists(tableIdentifier: Seq[String]): Boolean = { val tableIdent = processTableIdentifier(tableIdentifier) val databaseName = tableIdent.lift(tableIdent.size - 2).getOrElse( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 2bf3635a54442..62f2e7f4d659f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -984,14 +984,21 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } protected def selExprNodeToExpr(node: Node): Option[Expression] = node match { - case Token("TOK_SELEXPR", - e :: Nil) => + case Token("TOK_SELEXPR", e :: Nil) => Some(nodeToExpr(e)) - case Token("TOK_SELEXPR", - e :: Token(alias, Nil) :: Nil) => + case Token("TOK_SELEXPR", e :: Token(alias, Nil) :: Nil) => Some(Alias(nodeToExpr(e), cleanIdentifier(alias))()) + case Token("TOK_SELEXPR", e :: aliasChildren) => + var aliasNames = ArrayBuffer[String]() + aliasChildren.foreach { _ match { + case Token(name, Nil) => aliasNames += cleanIdentifier(name) + case _ => + } + } + Some(MultiAlias(nodeToExpr(e), aliasNames)) + /* Hints are ignored */ case Token("TOK_HINTLIST", _) => None diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index fa997288a2848..d89111094b9ff 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeComman import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive.execution._ import org.apache.spark.sql.parquet.ParquetRelation -import org.apache.spark.sql.sources.CreateTableUsing +import org.apache.spark.sql.sources.{CreateTableUsingAsLogicalPlan, CreateTableUsingAsSelect, CreateTableUsing} import org.apache.spark.sql.types.StringType @@ -212,9 +212,21 @@ private[hive] trait HiveStrategies { object HiveDDLStrategy extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case CreateTableUsing(tableName, userSpecifiedSchema, provider, false, options) => + case CreateTableUsing(tableName, userSpecifiedSchema, provider, false, opts, allowExisting) => ExecutedCommand( - CreateMetastoreDataSource(tableName, userSpecifiedSchema, provider, options)) :: Nil + CreateMetastoreDataSource( + tableName, userSpecifiedSchema, provider, opts, allowExisting)) :: Nil + + case CreateTableUsingAsSelect(tableName, provider, false, opts, allowExisting, query) => + val logicalPlan = hiveContext.parseSql(query) + val cmd = + CreateMetastoreDataSourceAsSelect(tableName, provider, opts, allowExisting, logicalPlan) + ExecutedCommand(cmd) :: Nil + + case CreateTableUsingAsLogicalPlan(tableName, provider, false, opts, allowExisting, query) => + val cmd = + CreateMetastoreDataSourceAsSelect(tableName, provider, opts, allowExisting, query) + ExecutedCommand(cmd) :: Nil case _ => Nil } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index c368715f7c6f5..effaa5a443512 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -34,6 +34,7 @@ import org.apache.spark.SerializableWritable import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.DateUtils /** * A trait for subclasses that handle table scans. @@ -306,7 +307,7 @@ private[hive] object HadoopTableReader extends HiveInspectors { row.update(ordinal, oi.getPrimitiveJavaObject(value).clone()) case oi: DateObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => - row.update(ordinal, oi.getPrimitiveJavaObject(value)) + row.update(ordinal, DateUtils.fromJavaDate(oi.getPrimitiveJavaObject(value))) case oi: BinaryObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => row.update(ordinal, oi.getPrimitiveJavaObject(value)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 4814cb7ebfe51..95dcaccefdc54 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -18,8 +18,10 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.sources.ResolvedDataSource +import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.types.StructType @@ -102,11 +104,77 @@ case class CreateMetastoreDataSource( tableName: String, userSpecifiedSchema: Option[StructType], provider: String, - options: Map[String, String]) extends RunnableCommand { + options: Map[String, String], + allowExisting: Boolean) extends RunnableCommand { - override def run(sqlContext: SQLContext) = { + override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] - hiveContext.catalog.createDataSourceTable(tableName, userSpecifiedSchema, provider, options) + + if (hiveContext.catalog.tableExists(tableName :: Nil)) { + if (allowExisting) { + return Seq.empty[Row] + } else { + sys.error(s"Table $tableName already exists.") + } + } + + var isExternal = true + val optionsWithPath = + if (!options.contains("path")) { + isExternal = false + options + ("path" -> hiveContext.catalog.hiveDefaultTableFilePath(tableName)) + } else { + options + } + + hiveContext.catalog.createDataSourceTable( + tableName, + userSpecifiedSchema, + provider, + optionsWithPath, + isExternal) + + Seq.empty[Row] + } +} + +case class CreateMetastoreDataSourceAsSelect( + tableName: String, + provider: String, + options: Map[String, String], + allowExisting: Boolean, + query: LogicalPlan) extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + val hiveContext = sqlContext.asInstanceOf[HiveContext] + + if (hiveContext.catalog.tableExists(tableName :: Nil)) { + if (allowExisting) { + return Seq.empty[Row] + } else { + sys.error(s"Table $tableName already exists.") + } + } + + val df = DataFrame(hiveContext, query) + var isExternal = true + val optionsWithPath = + if (!options.contains("path")) { + isExternal = false + options + ("path" -> hiveContext.catalog.hiveDefaultTableFilePath(tableName)) + } else { + options + } + + // Create the relation based on the data of df. + ResolvedDataSource(sqlContext, provider, optionsWithPath, df) + + hiveContext.catalog.createDataSourceTable( + tableName, + None, + provider, + optionsWithPath, + isExternal) Seq.empty[Row] } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 76d2140372197..34c21c11761ae 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -33,8 +33,11 @@ import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._ import org.apache.spark.Logging import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{Generate, Project, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils.getContextOrSparkClassLoader +import org.apache.spark.sql.catalyst.analysis.MultiAlias +import org.apache.spark.sql.catalyst.errors.TreeNodeException /* Implicit conversions */ import scala.collection.JavaConversions._ @@ -321,6 +324,20 @@ private[hive] case class HiveGenericUdtf( override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } +/** + * Resolve Udtfs Alias. + */ +private[spark] object ResolveUdtfsAlias extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan) = plan transform { + case p @ Project(projectList, _) + if projectList.exists(_.isInstanceOf[MultiAlias]) && projectList.size != 1 => + throw new TreeNodeException(p, "only single Generator supported for SELECT clause") + + case Project(Seq(MultiAlias(udtf @ HiveGenericUdtf(_, _, _), names)), child) => + Generate(udtf.copy(aliasNames = names), join = false, outer = false, None, child) + } +} + private[hive] case class HiveUdafFunction( funcWrapper: HiveFunctionWrapper, exprs: Seq[Expression], diff --git a/sql/hive/src/test/resources/golden/Date cast-0-a7cd69b80c77a771a2c955db666be53d b/sql/hive/src/test/resources/golden/Date cast-0-a7cd69b80c77a771a2c955db666be53d new file mode 100644 index 0000000000000..98da82fa89386 --- /dev/null +++ b/sql/hive/src/test/resources/golden/Date cast-0-a7cd69b80c77a771a2c955db666be53d @@ -0,0 +1 @@ +1970-01-01 1970-01-01 1969-12-31 16:00:00 1969-12-31 16:00:00 1970-01-01 00:00:00 diff --git a/sql/hive/src/test/resources/golden/Date comparison test 1-0-bde89be08a12361073ff658fef768b7e b/sql/hive/src/test/resources/golden/Date comparison test 1-0-bde89be08a12361073ff658fef768b7e new file mode 100644 index 0000000000000..27ba77ddaf615 --- /dev/null +++ b/sql/hive/src/test/resources/golden/Date comparison test 1-0-bde89be08a12361073ff658fef768b7e @@ -0,0 +1 @@ +true diff --git a/sql/hive/src/test/resources/golden/Date comparison test 2-0-dc1b267f1d79d49e6675afe4fd2a34a5 b/sql/hive/src/test/resources/golden/Date comparison test 2-0-dc1b267f1d79d49e6675afe4fd2a34a5 new file mode 100644 index 0000000000000..27ba77ddaf615 --- /dev/null +++ b/sql/hive/src/test/resources/golden/Date comparison test 2-0-dc1b267f1d79d49e6675afe4fd2a34a5 @@ -0,0 +1 @@ +true diff --git a/sql/hive/src/test/resources/golden/date_1-0-50131c0ba7b7a6b65c789a5a8497bada b/sql/hive/src/test/resources/golden/date_1-0-50131c0ba7b7a6b65c789a5a8497bada new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_1-0-50131c0ba7b7a6b65c789a5a8497bada @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/date_1-0-23edf29bf7376c70d5ecf12720f4b1eb b/sql/hive/src/test/resources/golden/date_1-1-23edf29bf7376c70d5ecf12720f4b1eb similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-0-23edf29bf7376c70d5ecf12720f4b1eb rename to sql/hive/src/test/resources/golden/date_1-1-23edf29bf7376c70d5ecf12720f4b1eb diff --git a/sql/hive/src/test/resources/golden/date_1-3-df16364a220ff96a6ea1cd478cbc1d0b b/sql/hive/src/test/resources/golden/date_1-10-df16364a220ff96a6ea1cd478cbc1d0b similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-3-df16364a220ff96a6ea1cd478cbc1d0b rename to sql/hive/src/test/resources/golden/date_1-10-df16364a220ff96a6ea1cd478cbc1d0b diff --git a/sql/hive/src/test/resources/golden/date_1-10-d964bec7e5632091ab5cb6f6786dbbf9 b/sql/hive/src/test/resources/golden/date_1-11-d964bec7e5632091ab5cb6f6786dbbf9 similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-10-d964bec7e5632091ab5cb6f6786dbbf9 rename to sql/hive/src/test/resources/golden/date_1-11-d964bec7e5632091ab5cb6f6786dbbf9 diff --git a/sql/hive/src/test/resources/golden/date_1-11-480c5f024a28232b7857be327c992509 b/sql/hive/src/test/resources/golden/date_1-12-480c5f024a28232b7857be327c992509 similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-11-480c5f024a28232b7857be327c992509 rename to sql/hive/src/test/resources/golden/date_1-12-480c5f024a28232b7857be327c992509 diff --git a/sql/hive/src/test/resources/golden/date_1-12-4c0ed7fcb75770d8790575b586bf14f4 b/sql/hive/src/test/resources/golden/date_1-13-4c0ed7fcb75770d8790575b586bf14f4 similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-12-4c0ed7fcb75770d8790575b586bf14f4 rename to sql/hive/src/test/resources/golden/date_1-13-4c0ed7fcb75770d8790575b586bf14f4 diff --git a/sql/hive/src/test/resources/golden/date_1-13-44fc74c1993062c0a9522199ff27fea b/sql/hive/src/test/resources/golden/date_1-14-44fc74c1993062c0a9522199ff27fea similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-13-44fc74c1993062c0a9522199ff27fea rename to sql/hive/src/test/resources/golden/date_1-14-44fc74c1993062c0a9522199ff27fea diff --git a/sql/hive/src/test/resources/golden/date_1-14-4855a66124b16d1d0d003235995ac06b b/sql/hive/src/test/resources/golden/date_1-15-4855a66124b16d1d0d003235995ac06b similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-14-4855a66124b16d1d0d003235995ac06b rename to sql/hive/src/test/resources/golden/date_1-15-4855a66124b16d1d0d003235995ac06b diff --git a/sql/hive/src/test/resources/golden/date_1-15-8bc190dba0f641840b5e1e198a14c55b b/sql/hive/src/test/resources/golden/date_1-16-8bc190dba0f641840b5e1e198a14c55b similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-15-8bc190dba0f641840b5e1e198a14c55b rename to sql/hive/src/test/resources/golden/date_1-16-8bc190dba0f641840b5e1e198a14c55b diff --git a/sql/hive/src/test/resources/golden/date_1-1-4ebe3571c13a8b0c03096fbd972b7f1b b/sql/hive/src/test/resources/golden/date_1-17-23edf29bf7376c70d5ecf12720f4b1eb similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-1-4ebe3571c13a8b0c03096fbd972b7f1b rename to sql/hive/src/test/resources/golden/date_1-17-23edf29bf7376c70d5ecf12720f4b1eb diff --git a/sql/hive/src/test/resources/golden/date_1-16-23edf29bf7376c70d5ecf12720f4b1eb b/sql/hive/src/test/resources/golden/date_1-2-4ebe3571c13a8b0c03096fbd972b7f1b similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-16-23edf29bf7376c70d5ecf12720f4b1eb rename to sql/hive/src/test/resources/golden/date_1-2-4ebe3571c13a8b0c03096fbd972b7f1b diff --git a/sql/hive/src/test/resources/golden/date_1-2-abdce0c0d14d3fc7441b7c134b02f99a b/sql/hive/src/test/resources/golden/date_1-3-26b5c291400dfde455b3c1b878b71d0 similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-2-abdce0c0d14d3fc7441b7c134b02f99a rename to sql/hive/src/test/resources/golden/date_1-3-26b5c291400dfde455b3c1b878b71d0 diff --git a/sql/hive/src/test/resources/golden/date_1-6-df16364a220ff96a6ea1cd478cbc1d0b b/sql/hive/src/test/resources/golden/date_1-4-df16364a220ff96a6ea1cd478cbc1d0b similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-6-df16364a220ff96a6ea1cd478cbc1d0b rename to sql/hive/src/test/resources/golden/date_1-4-df16364a220ff96a6ea1cd478cbc1d0b diff --git a/sql/hive/src/test/resources/golden/date_1-4-d964bec7e5632091ab5cb6f6786dbbf9 b/sql/hive/src/test/resources/golden/date_1-5-d964bec7e5632091ab5cb6f6786dbbf9 similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-4-d964bec7e5632091ab5cb6f6786dbbf9 rename to sql/hive/src/test/resources/golden/date_1-5-d964bec7e5632091ab5cb6f6786dbbf9 diff --git a/sql/hive/src/test/resources/golden/date_1-5-5e70fc74158fbfca38134174360de12d b/sql/hive/src/test/resources/golden/date_1-6-559d01fb0b42c42f0c4927fa0f9deac4 similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-5-5e70fc74158fbfca38134174360de12d rename to sql/hive/src/test/resources/golden/date_1-6-559d01fb0b42c42f0c4927fa0f9deac4 diff --git a/sql/hive/src/test/resources/golden/date_1-9-df16364a220ff96a6ea1cd478cbc1d0b b/sql/hive/src/test/resources/golden/date_1-7-df16364a220ff96a6ea1cd478cbc1d0b similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-9-df16364a220ff96a6ea1cd478cbc1d0b rename to sql/hive/src/test/resources/golden/date_1-7-df16364a220ff96a6ea1cd478cbc1d0b diff --git a/sql/hive/src/test/resources/golden/date_1-7-d964bec7e5632091ab5cb6f6786dbbf9 b/sql/hive/src/test/resources/golden/date_1-8-d964bec7e5632091ab5cb6f6786dbbf9 similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-7-d964bec7e5632091ab5cb6f6786dbbf9 rename to sql/hive/src/test/resources/golden/date_1-8-d964bec7e5632091ab5cb6f6786dbbf9 diff --git a/sql/hive/src/test/resources/golden/date_1-8-1d5c58095cd52ea539d869f2ab1ab67d b/sql/hive/src/test/resources/golden/date_1-9-8306558e0eabe936ac33dabaaa17fea4 similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-8-1d5c58095cd52ea539d869f2ab1ab67d rename to sql/hive/src/test/resources/golden/date_1-9-8306558e0eabe936ac33dabaaa17fea4 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index 2d3ff680125ad..09bbd5c867e4e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.hive import java.util -import java.sql.Date import java.util.{Locale, TimeZone} import org.apache.hadoop.hive.ql.udf.UDAFPercentile @@ -76,7 +75,7 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { Literal(0.asInstanceOf[Float]) :: Literal(0.asInstanceOf[Double]) :: Literal("0") :: - Literal(new Date(2014, 9, 23)) :: + Literal(new java.sql.Date(114, 8, 23)) :: Literal(Decimal(BigDecimal(123.123))) :: Literal(new java.sql.Timestamp(123123)) :: Literal(Array[Byte](1,2,3)) :: @@ -143,7 +142,6 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { case (r1: Array[Byte], r2: Array[Byte]) if r1 != null && r2 != null && r1.length == r2.length => r1.zip(r2).map { case (b1, b2) => assert(b1 === b2) } - case (r1: Date, r2: Date) => assert(r1.compareTo(r2) === 0) case (r1, r2) => assert(r1 === r2) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 7408c7ffd69e8..85795acb658e2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -22,7 +22,9 @@ import java.io.File import org.scalatest.BeforeAndAfterEach import org.apache.commons.io.FileUtils +import org.apache.hadoop.fs.Path +import org.apache.spark.sql.catalyst.util import org.apache.spark.sql._ import org.apache.spark.util.Utils import org.apache.spark.sql.types._ @@ -36,9 +38,11 @@ import org.apache.spark.sql.hive.test.TestHive._ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { override def afterEach(): Unit = { reset() + if (ctasPath.exists()) Utils.deleteRecursively(ctasPath) } val filePath = Utils.getSparkClassLoader.getResource("sample.json").getFile + var ctasPath: File = util.getTempFilePath("jsonCTAS").getCanonicalFile test ("persistent JSON table") { sql( @@ -94,7 +98,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { StructField("", innerStruct, true) :: StructField("b", StringType, true) :: Nil) - assert(expectedSchema == table("jsonTable").schema) + assert(expectedSchema === table("jsonTable").schema) jsonFile(filePath).registerTempTable("expectedJsonTable") @@ -137,6 +141,11 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { intercept[Exception] { sql("SELECT * FROM jsonTable").collect() } + + assert( + (new File(filePath)).exists(), + "The table with specified path is considered as an external table, " + + "its data should not deleted after DROP TABLE.") } test("check change without refresh") { @@ -240,7 +249,144 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { invalidateTable("jsonTable") val expectedSchema = StructType(StructField("c_!@(3)", IntegerType, true) :: Nil) - assert(expectedSchema == table("jsonTable").schema) + assert(expectedSchema === table("jsonTable").schema) + } + + test("CTAS") { + sql( + s""" + |CREATE TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '${filePath}' + |) + """.stripMargin) + + sql( + s""" + |CREATE TABLE ctasJsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '${ctasPath}' + |) AS + |SELECT * FROM jsonTable + """.stripMargin) + + assert(table("ctasJsonTable").schema === table("jsonTable").schema) + + checkAnswer( + sql("SELECT * FROM ctasJsonTable"), + sql("SELECT * FROM jsonTable").collect()) + } + + test("CTAS with IF NOT EXISTS") { + sql( + s""" + |CREATE TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '${filePath}' + |) + """.stripMargin) + + sql( + s""" + |CREATE TABLE ctasJsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '${ctasPath}' + |) AS + |SELECT * FROM jsonTable + """.stripMargin) + + // Create the table again should trigger a AlreadyExistsException. + val message = intercept[RuntimeException] { + sql( + s""" + |CREATE TABLE ctasJsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '${ctasPath}' + |) AS + |SELECT * FROM jsonTable + """.stripMargin) + }.getMessage + assert(message.contains("Table ctasJsonTable already exists."), + "We should complain that ctasJsonTable already exists") + + // The following statement should be fine if it has IF NOT EXISTS. + // It tries to create a table ctasJsonTable with a new schema. + // The actual table's schema and data should not be changed. + sql( + s""" + |CREATE TABLE IF NOT EXISTS ctasJsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '${ctasPath}' + |) AS + |SELECT a FROM jsonTable + """.stripMargin) + + // Discard the cached relation. + invalidateTable("ctasJsonTable") + + // Schema should not be changed. + assert(table("ctasJsonTable").schema === table("jsonTable").schema) + // Table data should not be changed. + checkAnswer( + sql("SELECT * FROM ctasJsonTable"), + sql("SELECT * FROM jsonTable").collect()) + } + + test("CTAS a managed table") { + sql( + s""" + |CREATE TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '${filePath}' + |) + """.stripMargin) + + new Path("/Users/yhuai/Desktop/whatever") + + + val expectedPath = catalog.hiveDefaultTableFilePath("ctasJsonTable") + val filesystemPath = new Path(expectedPath) + val fs = filesystemPath.getFileSystem(sparkContext.hadoopConfiguration) + if (fs.exists(filesystemPath)) fs.delete(filesystemPath, true) + + // It is a managed table when we do not specify the location. + sql( + s""" + |CREATE TABLE ctasJsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | + |) AS + |SELECT * FROM jsonTable + """.stripMargin) + + assert(fs.exists(filesystemPath), s"$expectedPath should exist after we create the table.") + + sql( + s""" + |CREATE TABLE loadedTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '${expectedPath}' + |) + """.stripMargin) + + assert(table("ctasJsonTable").schema === table("loadedTable").schema) + + checkAnswer( + sql("SELECT * FROM ctasJsonTable"), + sql("SELECT * FROM loadedTable").collect() + ) + + sql("DROP TABLE ctasJsonTable") + assert(!fs.exists(filesystemPath), s"$expectedPath should not exist after we drop the table.") } test("SPARK-5286 Fail to drop an invalid table when using the data source API") { @@ -255,4 +401,39 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { sql("DROP TABLE jsonTable").collect().foreach(println) } + + test("save and load table") { + val originalDefaultSource = conf.defaultDataSourceName + conf.setConf("spark.sql.default.datasource", "org.apache.spark.sql.json") + + val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) + val df = jsonRDD(rdd) + + df.saveAsTable("savedJsonTable") + + checkAnswer( + sql("SELECT * FROM savedJsonTable"), + df.collect()) + + createTable("createdJsonTable", catalog.hiveDefaultTableFilePath("savedJsonTable"), false) + assert(table("createdJsonTable").schema === df.schema) + checkAnswer( + sql("SELECT * FROM createdJsonTable"), + df.collect()) + + val message = intercept[RuntimeException] { + createTable("createdJsonTable", filePath.toString, false) + }.getMessage + assert(message.contains("Table createdJsonTable already exists."), + "We should complain that ctasJsonTable already exists") + + createTable("createdJsonTable", filePath.toString, true) + // createdJsonTable should be not changed. + assert(table("createdJsonTable").schema === df.schema) + checkAnswer( + sql("SELECT * FROM createdJsonTable"), + df.collect()) + + conf.setConf("spark.sql.default.datasource", originalDefaultSource) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala new file mode 100644 index 0000000000000..85b6bc93d7122 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +/* Implicits */ + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.hive.test.TestHive._ + +case class FunctionResult(f1: String, f2: String) + +class UDFSuite extends QueryTest { + test("UDF case insensitive") { + udf.register("random0", () => { Math.random()}) + udf.register("RANDOM1", () => { Math.random()}) + udf.register("strlenScala", (_: String).length + (_:Int)) + assert(sql("SELECT RANDOM0() FROM src LIMIT 1").head().getDouble(0) >= 0.0) + assert(sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0) + assert(sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 4c53b10ba96e9..82efadb28e890 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -253,8 +253,30 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { createQueryTest("Cast Timestamp to Timestamp in UDF", """ - | SELECT DATEDIFF(CAST(value AS timestamp), CAST('2002-03-21 00:00:00' AS timestamp)) - | FROM src LIMIT 1 + | SELECT DATEDIFF(CAST(value AS timestamp), CAST('2002-03-21 00:00:00' AS timestamp)) + | FROM src LIMIT 1 + """.stripMargin) + + createQueryTest("Date comparison test 1", + """ + | SELECT + | CAST(CAST('1970-01-01 22:00:00' AS timestamp) AS date) == + | CAST(CAST('1970-01-01 23:00:00' AS timestamp) AS date) + | FROM src LIMIT 1 + """.stripMargin) + + createQueryTest("Date comparison test 2", + "SELECT CAST(CAST(0 AS timestamp) AS date) > CAST(0 AS timestamp) FROM src LIMIT 1") + + createQueryTest("Date cast", + """ + | SELECT + | CAST(CAST(0 AS timestamp) AS date), + | CAST(CAST(CAST(0 AS timestamp) AS date) AS string), + | CAST(0 AS timestamp), + | CAST(CAST(0 AS timestamp) AS string), + | CAST(CAST(CAST('1970-01-01 23:00:00' AS timestamp) AS date) AS timestamp) + | FROM src LIMIT 1 """.stripMargin) createQueryTest("Simple Average", @@ -583,6 +605,18 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { assert(sql("select key from src having key > 490").collect().size < 100) } + test("SPARK-5383 alias for udfs with multi output columns") { + assert( + sql("select stack(2, key, value, key, value) as (a, b) from src limit 5") + .collect() + .size == 5) + + assert( + sql("select a, b from (select stack(2, key, value, key, value) as (a, b) from src) t limit 5") + .collect() + .size == 5) + } + test("SPARK-5367: resolve star expression in udf") { assert(sql("select concat(*) from src limit 5").collect().size == 5) assert(sql("select array(*) from src limit 5").collect().size == 5) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index eb7a7750af02d..4efe0c5e0cd44 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -170,7 +170,7 @@ class SQLQuerySuite extends QueryTest { sql("CREATE TABLE test2 (key INT, value STRING)") testData.insertInto("test2") testData.insertInto("test2") - sql("SELECT COUNT(a.value) FROM test1 a JOIN test2 b ON a.key = b.key").saveAsTable("test") + sql("CREATE TABLE test AS SELECT COUNT(a.value) FROM test1 a JOIN test2 b ON a.key = b.key") checkAnswer( table("test"), sql("SELECT COUNT(a.value) FROM test1 a JOIN test2 b ON a.key = b.key").collect().toSeq) diff --git a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala index 254919e8f6fdc..b5a0754ff61f9 100644 --- a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala +++ b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala @@ -160,7 +160,7 @@ private[hive] object HiveShim { if (value == null) null else new hadoopIo.BytesWritable(value.asInstanceOf[Array[Byte]]) def getDateWritable(value: Any): hiveIo.DateWritable = - if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[java.sql.Date]) + if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[Int]) def getTimestampWritable(value: Any): hiveIo.TimestampWritable = if (value == null) { diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala index 45ca59ae56a38..e4c1809c8bb21 100644 --- a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala +++ b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala @@ -263,7 +263,7 @@ private[hive] object HiveShim { } def getDateWritable(value: Any): hiveIo.DateWritable = - if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[java.sql.Date]) + if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[Int]) def getTimestampWritable(value: Any): hiveIo.TimestampWritable = if (value == null) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 0e0f5bd3b9db4..b3ffc71904c76 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -73,7 +73,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { logDebug("Stopping JobScheduler") // First, stop receiving - receiverTracker.stop() + receiverTracker.stop(processAllReceivedData) // Second, stop generating jobs. If it has to process all received data, // then this will wait for all the processing through JobScheduler to be over. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 4f998869731ed..00456ab2a0c92 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -86,10 +86,10 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false } /** Stop the receiver execution thread. */ - def stop() = synchronized { + def stop(graceful: Boolean) = synchronized { if (!receiverInputStreams.isEmpty && actor != null) { // First, stop the receivers - if (!skipReceiverLaunch) receiverExecutor.stop() + if (!skipReceiverLaunch) receiverExecutor.stop(graceful) // Finally, stop the actor ssc.env.actorSystem.stop(actor) @@ -218,6 +218,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false /** This thread class runs all the receivers on the cluster. */ class ReceiverLauncher { @transient val env = ssc.env + @volatile @transient private var running = false @transient val thread = new Thread() { override def run() { try { @@ -233,7 +234,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false thread.start() } - def stop() { + def stop(graceful: Boolean) { // Send the stop signal to all the receivers stopReceivers() @@ -241,6 +242,16 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false // That is, for the receivers to quit gracefully. thread.join(10000) + if (graceful) { + val pollTime = 100 + def done = { receiverInfo.isEmpty && !running } + logInfo("Waiting for receiver job to terminate gracefully") + while(!done) { + Thread.sleep(pollTime) + } + logInfo("Waited for receiver job to terminate gracefully") + } + // Check if all the receivers have been deregistered or not if (!receiverInfo.isEmpty) { logWarning("All of the receivers have not deregistered, " + receiverInfo) @@ -295,7 +306,9 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false // Distribute the receivers and start them logInfo("Starting " + receivers.length + " receivers") + running = true ssc.sparkContext.runJob(tempRDD, ssc.sparkContext.clean(startReceiver)) + running = false logInfo("All of the receivers have been terminated") } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 9f352bdcb0893..0b5af25e0f7cc 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -205,6 +205,32 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w } } + test("stop slow receiver gracefully") { + val conf = new SparkConf().setMaster(master).setAppName(appName) + conf.set("spark.streaming.gracefulStopTimeout", "20000") + sc = new SparkContext(conf) + logInfo("==================================\n\n\n") + ssc = new StreamingContext(sc, Milliseconds(100)) + var runningCount = 0 + SlowTestReceiver.receivedAllRecords = false + //Create test receiver that sleeps in onStop() + val totalNumRecords = 15 + val recordsPerSecond = 1 + val input = ssc.receiverStream(new SlowTestReceiver(totalNumRecords, recordsPerSecond)) + input.count().foreachRDD { rdd => + val count = rdd.first() + runningCount += count.toInt + logInfo("Count = " + count + ", Running count = " + runningCount) + } + ssc.start() + ssc.awaitTermination(500) + ssc.stop(stopSparkContext = false, stopGracefully = true) + logInfo("Running count = " + runningCount) + assert(runningCount > 0) + assert(runningCount == totalNumRecords) + Thread.sleep(100) + } + test("awaitTermination") { ssc = new StreamingContext(master, appName, batchDuration) val inputStream = addInputStream(ssc) @@ -319,6 +345,38 @@ object TestReceiver { val counter = new AtomicInteger(1) } +/** Custom receiver for testing whether a slow receiver can be shutdown gracefully or not */ +class SlowTestReceiver(totalRecords: Int, recordsPerSecond: Int) extends Receiver[Int](StorageLevel.MEMORY_ONLY) with Logging { + + var receivingThreadOption: Option[Thread] = None + + def onStart() { + val thread = new Thread() { + override def run() { + logInfo("Receiving started") + for(i <- 1 to totalRecords) { + Thread.sleep(1000 / recordsPerSecond) + store(i) + } + SlowTestReceiver.receivedAllRecords = true + logInfo(s"Received all $totalRecords records") + } + } + receivingThreadOption = Some(thread) + thread.start() + } + + def onStop() { + // Simulate slow receiver by waiting for all records to be produced + while(!SlowTestReceiver.receivedAllRecords) Thread.sleep(100) + // no cleanup to be done, the receiving thread should stop on it own + } +} + +object SlowTestReceiver { + var receivedAllRecords = false +} + /** Streaming application for testing DStream and RDD creation sites */ package object testPackage extends Assertions { def test() {